From 0d9ca7f1d26e05e748784de977fb18c1e2c45942 Mon Sep 17 00:00:00 2001 From: Maria Santamaria <maria.santamaria_gomez@nokia.com> Date: Mon, 13 May 2024 06:39:44 +0000 Subject: [PATCH] JVET-AH0096: training --- .gitattributes | 1 + .../NN_Adaptive_Filtering/README.md | 312 +++ .../NN_Adaptive_Filtering/config.py | 101 + .../conversion/__init__.py | 0 .../conversion/onnx2sadl_modified.py | 2096 +++++++++++++++++ .../conversion/sadl2torch.py | 539 +++++ .../conversion/torch2sadl.py | 130 + .../create_dataset_dirs.py | 57 + .../NN_Adaptive_Filtering/create_env.sh | 17 + .../NN_Adaptive_Filtering/data_preparation.py | 121 + .../NN_Adaptive_Filtering/environment.yml | 163 ++ .../NN_Adaptive_Filtering/launch_pipeline.py | 164 ++ .../NN_Adaptive_Filtering/models/__init__.py | 61 + .../models/model_lop2.py | 360 +++ .../models/model_lop2_nnr.py | 279 +++ .../models/model_lop2_with_multiplier.py | 468 ++++ .../NN_Adaptive_Filtering/models/models.json | 51 + .../overfitting_pipeline.py | 396 ++++ .../resources/config.json | 36 + .../resources/datasets/jvet.json | 212 ++ .../resources/datasets/jvet_labels.json | 25 + .../resources/num_layers.json | 163 ++ .../NN_Adaptive_Filtering/run/TMLogger.py | 252 ++ .../NN_Adaptive_Filtering/run/__init__.py | 0 .../NN_Adaptive_Filtering/segment_on_off.py | 325 +++ .../NN_Adaptive_Filtering/trainer/__init__.py | 0 .../trainer/nn_filter.py | 429 ++++ .../NN_Adaptive_Filtering/util/__init__.py | 63 + .../NN_Adaptive_Filtering/util/dataset_bin.py | 153 ++ .../NN_Adaptive_Filtering/util/dataset_yuv.py | 533 +++++ .../NN_Adaptive_Filtering/util/file_system.py | 292 +++ .../NN_Adaptive_Filtering/util/image_ops.py | 149 ++ .../NN_Adaptive_Filtering/util/logger.py | 202 ++ .../NN_Adaptive_Filtering/util/metrics.py | 164 ++ .../NN_Adaptive_Filtering/util/regex.py | 127 + .../NN_Adaptive_Filtering/wu_decoding.py | 131 ++ .../NN_Adaptive_Filtering/wu_encoding.py | 253 ++ 37 files changed, 8825 insertions(+) create mode 100644 training/training_scripts/NN_Adaptive_Filtering/README.md create mode 100755 training/training_scripts/NN_Adaptive_Filtering/config.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/conversion/__init__.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/conversion/onnx2sadl_modified.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/conversion/sadl2torch.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/conversion/torch2sadl.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/create_dataset_dirs.py create mode 100755 training/training_scripts/NN_Adaptive_Filtering/create_env.sh create mode 100644 training/training_scripts/NN_Adaptive_Filtering/data_preparation.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/environment.yml create mode 100644 training/training_scripts/NN_Adaptive_Filtering/launch_pipeline.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/models/__init__.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/models/model_lop2.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_nnr.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_with_multiplier.py create mode 100755 training/training_scripts/NN_Adaptive_Filtering/models/models.json create mode 100644 training/training_scripts/NN_Adaptive_Filtering/overfitting_pipeline.py create mode 100644 training/training_scripts/NN_Adaptive_Filtering/resources/config.json create 
mode 100644 training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet.json
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet_labels.json
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/resources/num_layers.json
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/run/TMLogger.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/run/__init__.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/segment_on_off.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/trainer/__init__.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/trainer/nn_filter.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/__init__.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/dataset_bin.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/dataset_yuv.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/file_system.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/image_ops.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/logger.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/metrics.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/util/regex.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/wu_decoding.py
create mode 100644 training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py
diff --git a/.gitattributes b/.gitattributes
index acf3ec92dd..b6c4a19bc7 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -10,3 +10,4 @@
 *.index filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.data-* filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
diff --git a/training/training_scripts/NN_Adaptive_Filtering/README.md b/training/training_scripts/NN_Adaptive_Filtering/README.md
new file mode 100644
index 0000000000..685419a81f
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/README.md
@@ -0,0 +1,312 @@
+# Torch Over-fitting LOP2
+
+## Instructions
+
+Content-adaptive LOP2 via NN overfitting on each intra-period.
+
+This document explains how to conduct the training and generate the models for inference.
+The steps are:
+
+0. Requirements
+1. Set up
+2. Data Preparation
+3. Convert base models from SADL to Torch
+4. Overfitting pipeline
+5. Inference
+
+## 0. Requirements
+
+* Torch 1.12.1
+* [MPEG NCTM repository](https://git.mpeg.expert/MPEG/Video/NNCoding/NCTM). This is used to code the weight-updates that result from the overfitting process.
+* Executable file ```naive_quantization```. This is used to convert a float SADL model to an int SADL model.
+
+**Note:** The executable file ```naive_quantization``` can be obtained by building the SADL project. Then copy it to this workspace.
+
+To get access to the NCTM repository, please follow these steps:
+
+1. Request an account at [https://git.mpeg.expert/](https://git.mpeg.expert/)
+2. Inform the AG/WG Convenor to request the account approval
+3. Once the account has been approved, send a request to join the NCTM project to Werner Bailer (werner.bailer@joanneum.at)
+
+The environment can be created with venv or anaconda:
+
+### venv setup
+
+`create_env.sh` can be used to create the virtual environment, called `lop2_overf`.
+
+Activate it with:
+
+```shell
+source ${HOME}/lop2_overf/bin/activate
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/lop2_overf/lib
+export PYTHONPATH=${PYTHONPATH}:${PWD}
+```
+
+### anaconda setup
+
+```shell
+conda env create -f environment.yml
+conda activate lop2_overf
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/anaconda3/envs/lop2_overf/lib
+export PYTHONPATH=${PYTHONPATH}:${PWD}
+```
+
+Clone the NCTM repository and apply the patch provided:
+
+```shell
+git clone https://git.mpeg.expert/MPEG/Video/NNCoding/NCTM.git
+cd NCTM
+git checkout v1-v2-harmonization
+```
+
+Then, install NCTM as a library on top of the active environment:
+
+```shell
+python setup.py build
+python setup.py install
+```
+
+The NCTM directory can be removed now.
+
+## 1. Set up
+
+All the parameters for training are taken from the [resources/config.json](resources/config.json) config file.
+Mandatory updates based on your system:
+
+* `orig_path`: abs path to original data (will be created)
+* `deco_path`: abs path to decoded data (will be created)
+* `output_path`: abs path to save the output files from the overfitting (will be created)
+
+The working directory is `<repo_path>/training/training_scripts/NN_Adaptive_Filtering`.
+
+The overfitting data is organised in a specific way. Create the directory structure with the following command:
+
+```shell
+python create_dataset_dirs.py
+```
+
+As a result, the original data file structure will look as shown below. There is one directory for each video
+sequence; the names can be seen in [resources/datasets/jvet_labels.json](resources/datasets/jvet_labels.json)
+
+```shell
+orig
+├── A1_CampfireParty
+└── ...
+```
+
+The decoded data file structure will look like this:
+
+```shell
+deco
+├── A1_CampfireParty
+│   ├── 22
+│   ├── 27
+│   ├── 32
+│   ├── 37
+│   └── 42
+└── ...
+```
+
+## 2. Data preparation
+
+The video data is in yuv420p and yuv420p10le formats.
+
+The following script saves the slice QP of each slice in a video sequence to a JSON file:
+
+```shell
+python data_preparation.py
+```
+
+### 2.1. Original data
+
+Copy the JVET RA NNVC CTC mandatory sequences into the corresponding directory under `orig_path`.
+
+### 2.2. Video encoding and decoding
+
+Use [NNVC-8.0rc3](https://vcgit.hhi.fraunhofer.de/jvet-ahg-nnvc/VVCSoftware_VTM/-/tree/NNVC-8.0rc3/) and enable both
+the NN intra tool and the NN LOP2.0 loop-filter.
+
+**NOTE**: The decoded data should follow the file structure shown earlier. The output files can be re-organised
+after the decoding process takes place.
+
+Save the following files:
+ * Decoder log
+ * Reconstruction before the deblocking filter
+ * Boundary strength
+ * Prediction
+ * Block Prediction Mode
+
+Across the different QPs, these files will have the same names, prefixed with the sequence name.
+Check the file names in [resources/config.json](resources/config.json) under `nnvc`.
+
+## 3. SADL model to Torch model
+
+The Torch model for training is generated from the original LOP2 model in int16 SADL format by using the following command:
+
+```shell
+python conversion/sadl2torch.py
+```
+
+The same command will also create two JSON files: one with the model quantisers and one that maps the parameters of
+the original LOP2 model to the model that contains the multiplier parameters.
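+
+For intuition, the sketch below shows how an int16 SADL parameter and its quantiser would typically map back to a
+float Torch tensor, assuming SADL's power-of-two fixed-point convention (`float = int / 2**q`). The helper and the
+JSON entry shown here are illustrative, not the exact names used by `sadl2torch.py`:
+
+```python
+import numpy as np
+import torch
+
+
+def dequantize_int16(raw: np.ndarray, q: int) -> torch.Tensor:
+    # Fixed-point convention assumed here: float value = integer value / 2**q
+    return torch.from_numpy(raw.astype(np.float32) / (1 << q))
+
+
+# Illustrative: per-parameter quantisers, as saved in nnlf_lop2_quantizers.json
+quantizers = {"conv1.weight": 11}  # hypothetical entry
+weights_int16 = np.zeros((16, 8, 3, 3), dtype=np.int16)  # placeholder data
+w = dequantize_int16(weights_int16, quantizers["conv1.weight"])
+```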
+
+The three output files will be located under the `resources` directory as follows:
+
+```shell
+resources
+├── nnlf_lop2_model.pt
+├── nnlf_lop2_quantizers.json
+└── nnlf_lop2_params_to_new_model.json
+```
+
+## 4. Overfitting pipeline
+
+The overfitting pipeline includes the following steps:
+
+1. Training
+2. Weight-update compression
+3. Conversion from Torch model to SADL model
+4. Quantisation of the SADL float model
+
+### 4.1 Overfitting
+
+For each video sequence, the overfitting is done per intra-period, using layer-wise optimisation.
+
+### 4.2 Weight-update compression
+
+Here the weight-update is computed and compressed, generating an NNR/NNC bitstream. The reconstructed Torch model is saved as well.
+
+### 4.3 Convert Torch models to SADL
+
+The Torch model is first converted into ONNX, then into a float32 SADL model, and finally quantised.
+
+Each of the sample commands below launches the overfitting pipeline in a single Python process:
+
+```shell
+CUDA_VISIBLE_DEVICES=2 python launch_pipeline.py -vs A1_CampfireParty A1_FoodMarket A1_Tango --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=3 python launch_pipeline.py -vs A2_CatRobot A2_DaylightRoad A2_ParkRunning --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=4 python launch_pipeline.py -vs B_BasketBallDrive B_BQTerrace B_Cactus B_MarketPlace B_RitualDance --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=0 python launch_pipeline.py -vs C_BasketballDrill C_PartyScene C_BQMall C_RaceHorses_big --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=1 python launch_pipeline.py -vs D_BasketBallPass D_BQSquare D_BlowingBubbles D_RaceHorses_s --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=5 python launch_pipeline.py -vs F_ArenaOfValor F_BBDrillText F_SlideEditing F_SlideShow --qps 22 27 32 37 42
+```
+
+The outputs are located as follows:
+
+* Overfitting results: `<output_path>/overfittings`
+* Weight-update compression results: `<output_path>/nnr_models`
+* Float SADL models: `<output_path>/sadl_float_models`
+* Int SADL models: `<output_path>/sadl_int_models`
+
+The directory structure will look as shown below (class_seqQP_partPartIdx):
+
+```shell
+<output_path>
+├── overfittings
+│   ├── D_BlowingBubbles_42_part0
+│   ├── D_BlowingBubbles_42_part1
+│   ├── D_BlowingBubbles_42_part2
+│   └── ...
+├── nnr_models
+│   ├── nnr
+│   │   ├── A1_Tango_22_part0.nnr
+│   │   └── ...
+│   ├── nnr_A1_Tango_22_part0.pt
+│   └── ...
+├── sadl_float_models
+│   ├── nnr_A1_Tango_22_part0.sadl
+│   └── ...
+├── quantizers
+│   ├── Q_nnr_A1_Tango_22_part0.log
+│   └── ...
+└── sadl_int_models
+    ├── nnr_A1_Tango_22_part0.sadl
+    └── ...
+```
+
+## 5. Inference
+
+The content-adaptive LOP2 is run only on the RA config, and **each intra period is encoded separately**. Then the separate
+bitstreams are merged and the resulting bitstream is decoded.
+
+The NNR bitstream is signalled within a Neural Network Filter Update (NNFU) APS. One APS is signalled within the first
+B-frame of an intra-period.
+
+The following example shows the NNFU encoder parameters for sequence D_BlowingBubbles, QP 22 and intra period 0.
+
+```shell
+NnfuEnabled              : 1
+NumNnfus                 : 1
+NnfuPayloadFileName0     : nnr/D_BlowingBubbles_22_part0.nnr
+NnfuModelFileName0       : overfitted_models/nnr_D_BlowingBubbles_22_part0.sadl
+```
+
+### 5.1. Encoder
+
+Use the default `encoder_randomaccess_nnvc.cfg` as well as the NNFU APS parameters.
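+
+Since each intra period is encoded as its own segment, the per-part encoder arguments can be derived as sketched
+below. This helper is hypothetical (not part of the scripts in this patch) and assumes an intra period of 64 frames
+with one trailing overlap frame per segment, matching the example command that follows:
+
+```python
+def segment_args(part_idx: int, total_frames: int, intra_period: int = 64) -> dict:
+    # First frame of this segment, and the number of frames to encode
+    # (one intra period plus one trailing frame, clipped at the sequence end).
+    frame_skip = part_idx * intra_period
+    frames_to_encode = min(intra_period + 1, total_frames - frame_skip)
+    return {
+        "FrameSkip": frame_skip,
+        "FramesToBeEncoded": frames_to_encode,
+        "IntraPeriod": intra_period,
+    }
+
+
+# Part 0 of a 500-frame sequence -> FrameSkip=0, FramesToBeEncoded=65
+print(segment_args(0, total_frames=500))
+```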
+
+The first intra period of D_BlowingBubbles QP 22 can be encoded as follows:
+
+```shell
+./EncoderAppStatic -c encoder_randomaccess_nnvc.cfg --IntraPeriod=64 --FramesToBeEncoded=65 --NnfuEnabled=1 --NumNnfus=1 --NnfuPayloadFileName0=nnr/D_BlowingBubbles_22_part0.nnr --NnfuModelFileName0=overfitted_models/nnr_D_BlowingBubbles_22_part0.sadl
+```
+
+The final bitstream is obtained by merging the best RA segment bitstreams. For that, the script
+`segment_on_off.py` is used. This script takes the first-pass (data generation with the original LOP2 model)
+and second-pass (overfitted model) encoding results and creates
+scripts that (1) merge the final segment bitstreams and (2) call the decoder.
+The script is tailored to the proponents' environment, but it can be adapted to other setups.
+
+Sample script:
+
+```shell
+python3 segment_on_off.py \
+-vs D_BasketBallPass D_BQSquare D_BlowingBubbles D_RaceHorses_s \
+--qps 22 27 32 37 42 \
+--pass1 /pass1_results \
+--pass2 /pass2_results \
+--output_dir /segment_on_off \
+--intra_models_dir /src/models/intra \
+--lop2_model /src/models/nnlf_lop2_model_int16.sadl \
+--parcat_bin /src/bin/parcatStatic \
+--decoder_bin /src/bin/DecoderAppStatic
+```
+
+### 5.2. Decoder
+
+The reconstruction of the overfitted SADL model is done within the NNVC decoder, which calls the `wu_decoding.py`
+script when an NNFU APS is decoded. The parameter `NnfuOutputFileStem` is required, and it is used as the stem for
+the NNR bitstream file names and the reconstructed overfitted SADL models.
+
+The following setup is required to run the decoder correctly:
+
+1. Create an environment variable that points to this directory (where the `wu_decoding.py` script is located)
+and add it to the Python path:
+
+```shell
+export WU_CODE=$PWD
+export PYTHONPATH=${PYTHONPATH}:${PWD}
+```
+
+2. Activate the Python environment with
+
+```shell
+source ${HOME}/lop2_overf/bin/activate
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/lop2_overf/lib
+```
+
+or
+
+```shell
+conda activate lop2_overf
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/anaconda3/envs/lop2_overf/lib
+```
+
+3. Disable Python warnings:
+
+```shell
+export TF_CPP_MIN_LOG_LEVEL=2
+export PYTHONWARNINGS="ignore"
+```
+
+4. Build the SADL sample and copy the `naive_quantization` binary to the `$WU_CODE` directory.
+
+Sample decoder script:
+
+```shell
+./DecoderAppStatic -b D_BlowingBubbles_22.bin --NnfuOutputFileStem=D_BlowingBubbles_22
+```
diff --git a/training/training_scripts/NN_Adaptive_Filtering/config.py b/training/training_scripts/NN_Adaptive_Filtering/config.py
new file mode 100755
index 0000000000..d093bb5553
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/config.py
@@ -0,0 +1,101 @@
+"""
+The copyright in this software is being made available under this Software
+Copyright License. This software may be subject to other third party and
+contributor rights, including patent rights, and no such rights are
+granted under this license.
+Copyright (c) 1995 - 2021 Fraunhofer-Gesellschaft zur Förderung der
+angewandten Forschung e.V. (Fraunhofer)
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted for purpose of testing the functionalities of
+this software provided that the following conditions are met:
+* Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. +* Neither the names of the copyright holders nor the names of its +contributors may be used to endorse or promote products derived from this +software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +NO EXPRESS OR IMPLIED LICENSES TO ANY PATENT CLAIMS, INCLUDING +WITHOUT LIMITATION THE PATENTS OF THE COPYRIGHT HOLDERS AND +CONTRIBUTORS, ARE GRANTED BY THIS SOFTWARE LICENSE. THE +COPYRIGHT HOLDERS AND CONTRIBUTORS PROVIDE NO WARRANTY OF PATENT +NON-INFRINGEMENT WITH RESPECT TO THIS SOFTWARE. +""" + +# Format changed + +import sys + +assert sys.version_info >= (3, 6) + + +# def PUT_SYNTAX(): return False +def PUT_SYNTAX(): + return True + + +# def ROW_SKIP(): return False +def ROW_SKIP(): + return True + + +# def TEMPORAL_CONTEXT(): return False +def TEMPORAL_CONTEXT(): + return True + + +# def OPT_QP(): return False +def OPT_QP(): + return True + + +# def SPARSE(): return False +def SPARSE(): + return True + + +# Use center PUT for temporal contexts in UC14A +def UC14A_CENTER_PUT(): + return False + + +# def UC14A_CENTER_PUT(): return True + + +# stochastic binary / ternary quantization +# def SBT(): return False +def SBT(): + return True + + +print("Config:") +print(" PUT_SYNTAX: ", str(PUT_SYNTAX())) +print(" ROW_SKIP: ", str(ROW_SKIP())) +print(" TEMPORAL_CONTEXT: ", str(TEMPORAL_CONTEXT())) +print(" OPT_QP: ", str(OPT_QP())) +print(" SPARSE: ", str(SPARSE())) +print(" UC14A_CENTER_PUT: ", str(UC14A_CENTER_PUT())) +print(" SBT: ", str(SBT())) + +# check tool requirements +if TEMPORAL_CONTEXT(): + assert PUT_SYNTAX() + +if UC14A_CENTER_PUT(): + assert PUT_SYNTAX() + assert TEMPORAL_CONTEXT() diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/onnx2sadl_modified.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/onnx2sadl_modified.py new file mode 100644 index 0000000000..f023336cf2 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/conversion/onnx2sadl_modified.py @@ -0,0 +1,2096 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. 
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from __future__ import print_function
+
+import argparse
+import copy
+import struct
+import sys
+from collections import OrderedDict
+from enum import IntEnum
+from pathlib import Path
+
+import numpy as np
+import onnx
+
+from util.file_system import read_json_file
+
+# file format:
+# MAGIC: SADL0004 [char[8]]
+# type_model [int32_t] 0:int32, 1:float, 2:int16
+# nb_layers [int32_t]
+# nb_inputs [int32_t]
+# inputs_id [int32_t[nb_inputs]]
+# nb_outputs [int32_t]
+# outputs_id [int32_t[nb_outputs]]
+# (for all layers:)
+# layer_id [int32_t]
+# op_id [int32_t]
+# name_size [int32_t]
+# name [char[name_size]]
+# nb_inputs [int32_t]
+# input_ids [int32_t[nb_inputs]]
+#
+# (additional information)
+# Const_layer:
+# length_dim [int32_t]
+# dim [int32_t[length_dim]]
+# type [int32_t] 0:int32, 1:float32, 2:int16
+# (if integer: quantizer [int32])
+# data [type[prod(dim)]]
+#
+# Conv2DTranspose
+# nb_dim_strides [int32_t]
+# strides [int32_t[nb_dim_strides]]
+# quantizer [int32_t]
+#
+# Conv2D
+# nb_dim_strides [int32_t]
+# strides [int32_t[nb_dim_strides]]
+# quantizer [int32_t]
+#
+# MatMul
+# quantizer [int32_t]
+#
+# Mul
+# quantizer [int32_t]
+#
+# PlaceHolder
+# length_dim [int32_t]
+# dim [int32_t[length_dim]]
+# quantizer [int32_t]
+#
+# MaxPool
+# nb_dim_strides [int32_t]
+# strides [int32_t[nb_dim_strides]]
+# nb_dim_kernel [int32_t]
+# kernel_dim [int32_t[nb_dim_kernel]]
+
+
+class OPTYPE(IntEnum):
+    Const = (1,)
+    Placeholder = (2,)
+    Identity = (3,)
+    BiasAdd = (4,)
+    MaxPool = (5,)
+    MatMul = (6,)
+    Reshape = (7,)
+    Relu = (8,)
+    Conv2D = (9,)
+    Add = (10,)
+    ConcatV2 = (11,)
+    Mul = (12,)
+    Maximum = (13,)
+    LeakyReLU = (14,)
+    Transpose = (15,)
+    Flatten = (16,)
+    Shape = (17,)
+    Expand = (18,)
+    Conv2DTranspose = (19,)
+    Slice = (
+        20,
+    )  # Currently slicing across depth is supported with default step size of 1
+    PReLU = (21,)
+    # In "tf2cpp", the same layer performs the matrix multiplication
+    # and the matrix multiplication by batches.
+ BatchMatMul = (6,) + ScatterND = (22,) + GridSample = (23,) + Resize = (24,) + Compare = (25,) + Where = (26,) + Minimum = (27,) + + # "BatchMatMulV2" did not exist in Tensorflow 1.9. It exists in + # Tensorflow 1.15. + BatchMatMulV2 = 6 + Count = 28 + + def __repr__(self): + return self.name + + def __str__(self): + return self.name + + +class DTYPE_SADL(IntEnum): + FLOAT = (1,) # float + INT8 = (3,) # int8_t + INT16 = (2,) # int16_t + INT32 = 0 # int32_t + + def __repr__(self): + return self.name + + def __str__(self): + return self.name + + +class DTYPE_ONNX(IntEnum): + # https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L483-L485 + FLOAT = (1,) # float + INT8 = (3,) # int8_t + INT16 = (4,) # int16_t + INT32 = (6,) # int32_t + INT64 = 7 # int64_t + + def __repr__(self): + return self.name + + def __str__(self): + return self.name + + +class Node_Annotation: + to_remove = False + add_transpose_before = False + add_transpose_after = False + to_transpose = False + layout_onnx = None + + def __repr__(self): + return "to_remove={}, to_transpose={}, layout_onnx={}, add_transpose_before={} add_transpose_after={}".format( + self.to_remove, + self.to_transpose, + self.layout_onnx, + self.add_transpose_before, + self.add_transpose_after, + ) + + +# get attribute name in node +def getAttribute(node, attr): + for a in node.attribute: + if a.name == attr: + return a + return None + + +def transpose_tensor(raw_data, dims): + """ + When convert TF2 to ONNX, ONNX weight's are not represent in the same way as TF2 weight's + """ + # print(dims) + tmp = [] + tmp.append(dims[2]) + tmp.append(dims[3]) + tmp.append(dims[1]) + tmp.append(dims[0]) + + x = np.frombuffer(raw_data, dtype=np.float32) + x = x.reshape(tmp[3], tmp[2], tmp[0] * tmp[1]).transpose().flatten() + return x.tobytes(), tmp + + +def transpose_matrix(raw_data, dims): + x = np.frombuffer(raw_data, dtype=np.float32) + tmp = [] + tmp.append(dims[1]) + tmp.append(dims[0]) + x = x.reshape(dims[0], dims[1]) + x = np.transpose(x) # moveaxis(x, -2, -1) + return x.flatten().tobytes(), tmp + + +def toList(ii): + d = [] + for i in ii: + d.append(i) + return d + + +def is_constant(name, onnx_initializer): + for n in onnx_initializer: + if n.name == name: + return True + return False + + +def is_output(name, onnx_output): + for out in onnx_output: + if out.name == name: + return True + return False + + +def parse_graph_input_node(input_node, map_onnx_to_myGraph, to_transpose): + map_onnx_to_myGraph[input_node.name] = input_node.name + struct = {} + struct["inputs"] = [] + struct["additional"] = {} + if ( + to_transpose + ): # data_layout == 'nchw' and len(input_node.type.tensor_type.shape.dim)==4: + struct["additional"]["dims"] = [ + input_node.type.tensor_type.shape.dim[0].dim_value, + input_node.type.tensor_type.shape.dim[2].dim_value, + input_node.type.tensor_type.shape.dim[3].dim_value, + input_node.type.tensor_type.shape.dim[1].dim_value, + ] + else: + struct["additional"]["dims"] = [ + d.dim_value for d in input_node.type.tensor_type.shape.dim + ] + struct["op_type"] = OPTYPE.Placeholder + return struct + + +def extract_additional_data_from_node(data, to_transpose): + tmp = {} + if data.dims == []: + tmp["dims"] = [1] + else: + tmp["dims"] = [dim for dim in data.dims] + + tmp["raw_data"] = data.raw_data + + if data.data_type == DTYPE_ONNX.FLOAT: + tmp["dtype"] = DTYPE_SADL.FLOAT + elif data.data_type == DTYPE_ONNX.INT8: + tmp["dtype"] = DTYPE_SADL.INT8 + elif data.data_type == DTYPE_ONNX.INT16: + tmp["dtype"] = DTYPE_SADL.INT16 + elif 
data.data_type == DTYPE_ONNX.INT32: + tmp["dtype"] = DTYPE_SADL.INT32 + elif data.data_type == DTYPE_ONNX.INT64: + + def convert_int64_to_int32(binary_data): + x = np.frombuffer(binary_data, dtype=np.int64) + x = x.astype(np.int32) + return x.tobytes() + + tmp["dtype"] = DTYPE_SADL.INT32 + tmp["raw_data"] = convert_int64_to_int32(tmp["raw_data"]) + else: + raise ValueError("extract_additional_data: Unknown dtype") + + if to_transpose: + if len(tmp["dims"]) == 4: + tmp["raw_data"], tmp["dims"] = transpose_tensor( + tmp["raw_data"], tmp["dims"] + ) + elif len(tmp["dims"]) == 2: # and data_layout == "nchw": + tmp["raw_data"], tmp["dims"] = transpose_matrix( + tmp["raw_data"], tmp["dims"] + ) + + return tmp["dims"], tmp["raw_data"], tmp["dtype"] + + +def extract_additional_data(name, to_transpose, onnx_graph, verbose): + if verbose: + print("[INFO] {} transpose={}".format(name, to_transpose)) + + for init in onnx_graph.initializer: + if name == init.name: + return extract_additional_data_from_node(init, to_transpose) + for node in onnx_graph.node: # not found in initializaer, search in Constant + if name == node.output[0]: + return extract_additional_data_from_node(node.attribute[0].t, to_transpose) + quit("[ERROR] unable to extract data in {}".format(name)) + + +def extract_dims(name, onnx_graph): + for init in onnx_graph.initializer: + if name == init.name: + return init.dims + for node in onnx_graph.node: # not found in initializaer, search in Constant + if name == node.output[0]: + a = getAttribute(node, "value") + if a is not None: + return a.t.dims + else: + return None + for node in onnx_graph.input: # not found in initializaer, search in Constant + if name == node.name: + return node.type.tensor_type.shape.dim + quit("[ERROR] unable to extract dims in {}".format(name)) + + +# get the nodes with name as input +def getNodesWithInput(name, model): + L = [] + for node in model.graph.node: + for inp in node.input: + if inp == name: + L.append(node) + return L + + +# get the nodes with name as output +def getNodesWithOutput(name, model): + for node in model.graph.node: + for out in node.output: + if out == name: + return node + for node in model.graph.initializer: + if node.name == name: + return node + for node in model.graph.input: + if node.name == name: + return node + quit("[ERROR] not found: {}".format(name)) + + +# get the nodes with name as output +def getNodesWithOutputNotConst(name, model): + for node in model.graph.node: + for out in node.output: + if out == name: + return node + for node in model.graph.input: + if node.name == name: + return node + return None + + +# get dims from data +def getDims(node): + if node.data_type != DTYPE_ONNX.INT64: + quit("[ERROR] bad node type fpr getDims {}".format(node)) + + x = np.frombuffer(node.raw_data, dtype=np.int64) + dims = x.tolist() + return dims + + +def getInitializer(name, model_onnx): + for node in model_onnx.graph.initializer: + if node.name == name: + return node + return None + + +def add_transpose(node, myGraph, map_onnx_to_myGraph): + # Transpose inserted + # Const + reshape_coef_name = node.input[0] + "_COEF_TRANSPOSE_NOT_IN_GRAPH" + myGraph[reshape_coef_name] = {} + myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const + myGraph[reshape_coef_name]["inputs"] = [] + additional = {} + additional["dims"] = [4] + additional["raw_data"] = np.array( + [0, 3, 1, 2], dtype=np.int32 + ).tobytes() # nhwc -> nchw + additional["dtype"] = DTYPE_SADL.INT32 + additional["data"] = node + myGraph[reshape_coef_name]["additional"] = additional + 
map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name + + nname = node.input[0] + "_TRANSPOSE_NOT_IN_GRAPH" + myGraph[nname] = {} + myGraph[nname]["op_type"] = OPTYPE.Transpose + myGraph[nname]["inputs"] = [map_onnx_to_myGraph[node.input[0]], reshape_coef_name] + map_onnx_to_myGraph[nname] = nname + return nname + + +def add_transpose_after(node, myGraph, map_onnx_to_myGraph): + # Transpose inserted + # Const + reshape_coef_name = node.output[0] + "_COEF_TRANSPOSE_AFTER_NOT_IN_GRAPH" + myGraph[reshape_coef_name] = {} + myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const + myGraph[reshape_coef_name]["inputs"] = [] + additional = {} + additional["dims"] = [4] + additional["raw_data"] = np.array( + [0, 2, 3, 1], dtype=np.int32 + ).tobytes() # nchw -> nhwc + additional["dtype"] = DTYPE_SADL.INT32 + additional["data"] = node + myGraph[reshape_coef_name]["additional"] = additional + map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name + + nname = node.output[0] + "_TRANSPOSE_AFTER_NOT_IN_GRAPH" + myGraph[nname] = {} + myGraph[nname]["op_type"] = OPTYPE.Transpose + myGraph[nname]["inputs"] = [map_onnx_to_myGraph[node.output[0]], reshape_coef_name] + map_onnx_to_myGraph[nname] = nname + map_onnx_to_myGraph[node.output[0]] = nname + return nname + + +def parse_graph_node( + node, model_onnx, myGraph, node_annotation, map_onnx_to_myGraph, verbose +): + if verbose > 1: + print("parse node", node.name) + + if node_annotation[ + node.name + ].add_transpose_before: # layout_onnx == 'nchw' : # need to go back to original layout before reshape + n0name = add_transpose(node, myGraph, map_onnx_to_myGraph) + else: + if len(node.input) >= 1: + n0name = node.input[0] + else: + n0name = None + + if ( + node.op_type == "Conv" + or node.op_type == "Gemm" + or node.op_type == "ConvTranspose" + ): + nb_inputs = len(node.input) + if (nb_inputs != 3) and (nb_inputs != 2): + raise Exception("parse_graph_node: Error on node type") + additional = {} + # Const: weight + additional["data"] = node + n2 = getNodesWithOutput(node.input[1], model_onnx) + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data( + node.input[1], + node_annotation[n2.name].to_transpose, + model_onnx.graph, + verbose, + ) + map_onnx_to_myGraph[node.input[1]] = node.input[1] + + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + + # Conv2d + inputs, additional = [], {} + inputs = [map_onnx_to_myGraph[n0name]] + [map_onnx_to_myGraph[node.input[1]]] + + additional["data"] = node + if node.op_type == "Conv" or node.op_type == "ConvTranspose": + a = getAttribute(node, "strides") + additional["strides"] = a.ints + if node.op_type == "Conv": + a = getAttribute(node, "group") + additional["group"] = a.i + a = getAttribute(node, "pads") + # if pads is unavailable then no padding + if a: + additional["pads"] = a.ints + else: + additional["pads"] = [0, 0, 0, 0] + if node.op_type == "ConvTranspose": + a = getAttribute(node, "output_padding") + if a: + additional["output_padding"] = a.ints + else: + additional["output_padding"] = [0, 0] + + if nb_inputs == 2: + map_onnx_to_myGraph[node.output[0]] = node.output[0] + elif nb_inputs == 3: + map_onnx_to_myGraph[node.output[0]] = node.output[0] + "_NOT_IN_GRAPH" + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["inputs"] = inputs + myGraph[node.output[0]]["additional"] = additional + if node.op_type == "Conv": + 
myGraph[node.output[0]]["op_type"] = OPTYPE.Conv2D + elif node.op_type == "ConvTranspose": + myGraph[node.output[0]]["op_type"] = OPTYPE.Conv2DTranspose + elif node.op_type == "Gemm": + myGraph[node.output[0]]["op_type"] = OPTYPE.MatMul + + if nb_inputs == 3: + additional = {} + # Const: bias + additional["data"] = node + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(node.input[2], False, model_onnx.graph, verbose) + map_onnx_to_myGraph[node.input[2]] = node.input[2] + myGraph[node.input[2]] = {} + myGraph[node.input[2]]["inputs"] = [] + myGraph[node.input[2]]["additional"] = additional + myGraph[node.input[2]]["op_type"] = OPTYPE.Const + # BiasAdd + inputs, additional = [], {} + inputs = [node.output[0]] + [map_onnx_to_myGraph[node.input[2]]] + additional["data"] = node + map_onnx_to_myGraph[node.output[0] + "_NOT_IN_GRAPH"] = None + myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {} + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = inputs + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.BiasAdd + + elif node.op_type == "Relu": + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Relu + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + # Constant node value has to be skipped for slicing + # This node is not used by any other onnx model in utests/models directory + # Const value of slice is taken care inside "Slice" condition + elif node.op_type == "Constant": # ~ like an initializer + pass + + elif node.op_type == "Add": + swap_inputs = False + if is_constant(n0name, model_onnx.graph.initializer): + additional = {} + additional["data"] = node + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(n0name, False, model_onnx.graph, verbose) + map_onnx_to_myGraph[n0name] = n0name + myGraph[n0name] = {} + myGraph[n0name]["inputs"] = [] + myGraph[n0name]["additional"] = additional + myGraph[n0name]["op_type"] = OPTYPE.Const + swap_inputs = True + if is_constant(node.input[1], model_onnx.graph.initializer): + additional = {} + additional["data"] = node + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose) + map_onnx_to_myGraph[node.input[1]] = node.input[1] + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Add + if not swap_inputs: + D1 = extract_dims(n0name, model_onnx.graph) + D2 = extract_dims(node.input[1], model_onnx.graph) + if D1 is not None and D2 is not None and len(D1) < len(D2): + swap_inputs = True + + if swap_inputs: + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[node.input[1]], + map_onnx_to_myGraph[n0name], + ] + else: + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], + map_onnx_to_myGraph[node.input[1]], + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "MaxPool": + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = 
OPTYPE.MaxPool + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + myGraph[node.output[0]]["additional"] = {} + a = getAttribute(node, "strides") + myGraph[node.output[0]]["additional"]["strides"] = [1, a.ints[0], a.ints[1], 1] + a = getAttribute(node, "pads") + if a is None: + pp = [0, 0, 0, 0] + else: + pp = a.ints + myGraph[node.output[0]]["additional"]["pads"] = pp + a = getAttribute(node, "kernel_shape") + myGraph[node.output[0]]["additional"]["kernel_shape"] = [ + 1, + a.ints[0], + a.ints[1], + 1, + ] + myGraph[node.output[0]]["additional"]["data"] = node + # todo: check pads? + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Mul": + # check the inputs + if is_constant(n0name, model_onnx.graph.initializer) and is_constant( + node.input[1], model_onnx.graph.initializer + ): + quit("[ERROR] unsupported double constants Mul", node) + swap_inputs = False + if is_constant(n0name, model_onnx.graph.initializer): + additional = {} + additional["data"] = node + n2 = getNodesWithOutput(n0name, model_onnx) + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data( + n0name, node_annotation[n2.name].to_transpose, model_onnx.graph, verbose + ) + map_onnx_to_myGraph[n0name] = n0name + myGraph[n0name] = {} + myGraph[n0name]["inputs"] = [] + myGraph[n0name]["additional"] = additional + myGraph[n0name]["op_type"] = OPTYPE.Const + swap_inputs = True + if is_constant(node.input[1], model_onnx.graph.initializer): + additional = {} + additional["data"] = node + n2 = getNodesWithOutput(node.input[1], model_onnx) + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data( + node.input[1], + node_annotation[n2.name].to_transpose, + model_onnx.graph, + verbose, + ) + map_onnx_to_myGraph[node.input[1]] = node.input[1] + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Mul + if swap_inputs: + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[node.input[1]], + map_onnx_to_myGraph[n0name], + ] + else: + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], + map_onnx_to_myGraph[node.input[1]], + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Identity" or node.op_type == "Cast": + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Identity + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "LeakyRelu": + # leaky coef + additional = {} + additional["data"] = node + additional["dims"] = [1] + additional["raw_data"] = np.array( + float(node.attribute[0].f), dtype=np.float32 + ).tobytes() + additional["dtype"] = DTYPE_SADL.FLOAT + map_onnx_to_myGraph[node.output[0] + "_COEF_NOT_IN_GRAPH"] = None + myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {} + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = [] + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.Const + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] 
= OPTYPE.LeakyReLU + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], + node.output[0] + "_NOT_IN_GRAPH", + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "PRelu": + additional = {} + additional["data"] = node + n2 = getNodesWithOutput(node.input[1], model_onnx) + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose) + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + map_onnx_to_myGraph[node.input[1]] = node.input[1] + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.PReLU + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [ + map_onnx_to_myGraph[node.input[1]] + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Flatten": + inputs, additional = [], {} + inputs = [map_onnx_to_myGraph[n0name]] + additional["data"] = node + a = getAttribute(node, "axis") + additional["axis"] = a.i + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["inputs"] = inputs + myGraph[node.output[0]]["additional"] = additional + myGraph[node.output[0]]["op_type"] = OPTYPE.Flatten + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Shape": + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Shape + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Expand": + inputs, additional = [], {} + inputs = [map_onnx_to_myGraph[n0name], map_onnx_to_myGraph[node.input[1]]] + additional["data"] = node + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["inputs"] = inputs + myGraph[node.output[0]]["additional"] = additional + myGraph[node.output[0]]["op_type"] = OPTYPE.Expand + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Reshape" or node.op_type == "MatMul": + # Const + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.input[1]]["inputs"] = [] + additional = {} + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose) + additional["data"] = node + myGraph[node.input[1]]["additional"] = additional + map_onnx_to_myGraph[node.input[1]] = node.input[1] + n2 = getNodesWithOutput(node.input[0], model_onnx) + # Reshape + inputs, additional = [], {} + inputs = [map_onnx_to_myGraph[n0name], node.input[1]] + additional["data"] = node + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["inputs"] = inputs + myGraph[node.output[0]]["additional"] = additional + + if node.op_type == "Reshape": + myGraph[node.output[0]]["op_type"] = OPTYPE.Reshape + elif node.op_type == "MatMul": + myGraph[node.output[0]]["op_type"] = OPTYPE.MatMul + + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Concat": + # Const + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Const + myGraph[node.output[0]]["inputs"] = [] + additional = 
{} + additional["dims"] = [1] + additional["raw_data"] = np.array(node.attribute[0].i, dtype=np.int32).tobytes() + additional["dtype"] = DTYPE_SADL.INT32 + additional["data"] = node + myGraph[node.output[0]]["additional"] = additional + map_onnx_to_myGraph[node.output[0] + "_NOT_IN_GRAPH"] = None + + # Concatenate + inputs, additional = [], {} + for inp in node.input: + inputs.append(map_onnx_to_myGraph[inp]) + inputs.append(node.output[0]) + additional["data"] = node + myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {} + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = inputs + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional + myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.ConcatV2 + map_onnx_to_myGraph[node.output[0]] = node.output[0] + "_NOT_IN_GRAPH" + + elif node.op_type == "Max": + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Maximum + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], + map_onnx_to_myGraph[node.input[1]], + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Min": + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Minimum + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], + map_onnx_to_myGraph[node.input[1]], + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Unsqueeze": + # No need to parse Unsqueeze as SADL can handle it. + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Transpose": + # Const + reshape_coef_name = node.output[0] + "_COEF_TRANSPOSE" + myGraph[reshape_coef_name] = {} + myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const + myGraph[reshape_coef_name]["inputs"] = [] + additional = {} + d = toList(getAttribute(node, "perm").ints) + additional["dims"] = [len(d)] + additional["raw_data"] = np.array(d, dtype=np.int32).tobytes() + additional["dtype"] = DTYPE_SADL.INT32 + additional["data"] = node + myGraph[reshape_coef_name]["additional"] = additional + map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Transpose + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], + reshape_coef_name, + ] + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Slice": + # Slice + if len(node.input) == 5: # PyTorch + initializer = getInitializer(node.input[3], model_onnx) + # Case: In pytorch, Slice is not in model_onnx.graph.initializer but in model_onnx.graph.node + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[3], model_onnx), "value" + ) + initializer = attribute.t + axes = getDims(initializer) + + initializer = getInitializer(node.input[4], model_onnx) + # Case: In pytorch, Slice is not in model_onnx.graph.initializer but in model_onnx.graph.node + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[4], model_onnx), "value" + ) + initializer = attribute.t + steps = getDims(initializer) + + if len(axes) != 1: + quit( + "[ERROR] currently sadl slicing support lenght of axes equal to one" + ) + if axes[0] == 0: + quit("[ERROR] currently slicing not supported for first dimension") + if not (len(steps) == 1 and steps[0] == 1): + quit("[ERROR] currently step 
has to be default one") + + # Currently slicing support only across width is added + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Slice + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + # assume depth is the last one, assume axes are always 0, 1, 2, etc. + + initializer = getInitializer(node.input[1], model_onnx) + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[1], model_onnx), "value" + ) + initializer = attribute.t + start = getDims(initializer) + + initializer = getInitializer(node.input[2], model_onnx) + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[2], model_onnx), "value" + ) + initializer = attribute.t + end = getDims(initializer) + additional = {} + dim_keys = ["b", "h", "w", "c"] + for i in range(1, len(dim_keys)): + additional[f"start_{dim_keys[i]}"] = 0 + additional[f"end_{dim_keys[i]}"] = 2147483647 + + # model_onnx got from tensorflow has length of start and end equal to 4 + # model_onnx got from pytorch has length of start and end equal to 1. The dimension of slicing + # i.e., if slicing is done across C or H or W is controlled by axes + if len(start) > 1: # TensorFlow to onnx models + if start[0] != 0: + quit("[ERROR] currently slicing not supported for first dimension") + if end[0] != 2147483647: + quit("[ERROR] currently slicing not supported for first dimension") + for i in range(1, len(start)): + initializer = getInitializer(node.input[1], model_onnx) + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[1], model_onnx), "value" + ) + initializer = attribute.t + start_d = getDims(initializer)[i] + + initializer = getInitializer(node.input[2], model_onnx) + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[2], model_onnx), "value" + ) + initializer = attribute.t + end_d = getDims(initializer)[i] + if ( + end_d > 2147483647 + ): # The default infinity number in PyTorch INT64 ONNX is 9223372036854775807. + end_d = 2147483647 + additional[f"start_{dim_keys[i]}"] = start_d + additional[f"end_{dim_keys[i]}"] = end_d + else: # PyTorch to onnx models + dim_keys_torch = ["b", "c", "h", "w"] + for i in range(len(end) - 1): + if start[i] != 0: + quit("[ERROR] currently slicing not supported for first dimension") + if end[i] < 2147483647: + quit("[ERROR] currently slicing only supported for last channel") + + initializer = getInitializer(node.input[1], model_onnx) + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[1], model_onnx), "value" + ) + initializer = attribute.t + start_d = getDims(initializer)[-1] + + initializer = getInitializer(node.input[2], model_onnx) + if initializer is None: + attribute = getAttribute( + getNodesWithOutput(node.input[2], model_onnx), "value" + ) + initializer = attribute.t + end_d = getDims(initializer)[-1] + if ( + end_d > 2147483647 + ): # The default infinity number in PyTorch INT64 ONNX is 9223372036854775807. + end_d = 2147483647 + additional[f"start_{dim_keys_torch[axes[0]]}"] = start_d + additional[f"end_{dim_keys_torch[axes[0]]}"] = end_d + myGraph[node.output[0]]["additional"] = additional + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "ScatterND": + # The default input order for the ScatterND is data, indices, and updates. 
+ if not is_constant(node.input[1], model_onnx.graph.initializer): + quit("[ERROR] The second input of the ScatterND must be indices.") + # indices + additional = {} + additional["data"] = node + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose) + if len(additional["dims"]) == 5: + # When the tensor format is specified as NCHW4 (or NHWC4) and the value of N is 1, the format is transformed + # to CHW4 (or HWC4). Here, the "4" indicates the position index within a 4-dimensional tensor. + additional["dims"] = additional["dims"][1:] + # transpose CHW4 to HWC4 + if node_annotation[node.input[1]].to_transpose: + tmp = [ + additional["dims"][1], + additional["dims"][2], + additional["dims"][0], + additional["dims"][3], + ] + x = ( + np.frombuffer(additional["raw_data"], dtype=np.int32) + .reshape(additional["dims"]) + .transpose(1, 2, 0, 3) + ) + indices = x.copy() + for i in np.ndindex(indices.shape[:-1]): + indices[i] = [ + indices[i][0], + indices[i][2], + indices[i][3], + indices[i][1], + ] + additional["dims"] = tmp + additional["raw_data"] = indices.flatten().tobytes() + else: + quit("[ERROR] Currently, ScatterND only supports indices of length 5.") + map_onnx_to_myGraph[node.input[1]] = node.input[1] + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.ScatterND + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], # data + map_onnx_to_myGraph[node.input[2]], # updates + map_onnx_to_myGraph[node.input[1]], + ] # indices + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "GridSample": + # Currently, the official TensorFlow does not have an implementation for GridSample. + align_corners = getAttribute(node, "align_corners").i + mode = getAttribute(node, "mode").s.decode("utf-8") + padding_mode = getAttribute(node, "padding_mode").s.decode("utf-8") + + mode_list = ["nearest", "bilinear"] + if mode not in mode_list: + quit("[ERROR] Currently, the mode of GridSample must in", mode_list, node) + else: + mode = mode_list.index(mode) + padding_mode_list = ["border"] + if padding_mode not in padding_mode_list: + quit( + "[ERROR] Currently, the padding_mode of GridSample must in", + padding_mode_list, + node, + ) + else: + padding_mode = padding_mode_list.index(padding_mode) + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.GridSample + myGraph[node.output[0]]["inputs"] = [ + map_onnx_to_myGraph[n0name], + map_onnx_to_myGraph[node.input[1]], + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + myGraph[node.output[0]]["additional"]["align_corners"] = align_corners + myGraph[node.output[0]]["additional"]["mode"] = mode + myGraph[node.output[0]]["additional"]["padding_mode"] = padding_mode + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Resize": + # Between 2 and 4 inputs. The default input order for the Resize is X, roi, scales, and sizes. 
+ input_label = 0 + input_list = [map_onnx_to_myGraph[n0name]] + for input_index, input_name in enumerate(node.input): + if is_constant(input_name, model_onnx.graph.initializer): + additional = {} + additional["data"] = node + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data( + input_name, False, model_onnx.graph, verbose + ) + if ( + additional["raw_data"] == b"" + ): # When tensor data is empty, just ignore it. + continue + + map_onnx_to_myGraph[input_name] = input_name + myGraph[input_name] = {} + myGraph[input_name]["inputs"] = [] + myGraph[input_name]["additional"] = additional + myGraph[input_name]["op_type"] = OPTYPE.Const + input_list.append(map_onnx_to_myGraph[input_name]) + input_label = input_label + (1 << (3 - input_index)) + if input_label != 1 and input_label != 2: + quit( + "[ERROR] Currently, the inputs of Resize have to be X and sizes, or X and scales." + ) + + # attribute (str -> int) + coordinate_transformation_mode = getAttribute( + node, "coordinate_transformation_mode" + ).s.decode("utf-8") + cubic_coeff_a = getAttribute(node, "cubic_coeff_a") + exclude_outside = getAttribute(node, "exclude_outside") + mode = getAttribute(node, "mode").s.decode("utf-8") + nearest_mode = getAttribute(node, "nearest_mode").s.decode("utf-8") + + coordinate_transformation_mode_list = ["half_pixel", "asymmetric"] + if coordinate_transformation_mode not in coordinate_transformation_mode_list: + quit( + "[ERROR] Currently, the coordinate_transformation_mode of Resize must in", + coordinate_transformation_mode_list, + node, + ) + else: + coordinate_transformation_mode = coordinate_transformation_mode_list.index( + coordinate_transformation_mode + ) + if cubic_coeff_a is None: + cubic_coeff_a = -0.75 + else: + cubic_coeff_a = cubic_coeff_a.f + if not cubic_coeff_a == -0.75: + quit( + "[ERROR] Currently, the cubic_coeff_a of Resize must be default -0.75.", + node, + ) + if exclude_outside is None: + exclude_outside = 0 + else: + exclude_outside = exclude_outside.i + if not exclude_outside == 0: + quit( + "[ERROR] Currently, the exclude_outside of Resize must be default 0.", + node, + ) + mode_list = ["linear", "nearest"] + if mode not in mode_list: + quit("[ERROR] Currently, the mode of Resize must in", mode_list, node) + else: + mode = mode_list.index(mode) + nearest_mode_list = ["floor", "round_prefer_ceil"] + if nearest_mode not in nearest_mode_list: + quit( + "[ERROR] Currently, the nearest_mode of Resize must in", + nearest_mode_list, + node, + ) + else: + nearest_mode = nearest_mode_list.index(nearest_mode) + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Resize + myGraph[node.output[0]]["inputs"] = input_list + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + myGraph[node.output[0]]["additional"]["input_label"] = input_label + myGraph[node.output[0]]["additional"][ + "coordinate_transformation_mode" + ] = coordinate_transformation_mode + myGraph[node.output[0]]["additional"]["mode"] = mode + myGraph[node.output[0]]["additional"]["nearest_mode"] = nearest_mode + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Less": + additional = {} + additional["data"] = node + if is_constant(node.input[1], model_onnx.graph.initializer): + n2 = getNodesWithOutput(node.input[1], model_onnx) # constant + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data( + node.input[1], + False, + 
model_onnx.graph, + verbose, + ) + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + map_onnx_to_myGraph[node.input[1]] = node.input[1] + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Compare + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [ + map_onnx_to_myGraph[node.input[1]] + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + myGraph[node.output[0]]["additional"]["mode"] = 0 + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Greater": + additional = {} + additional["data"] = node + if is_constant(node.input[1], model_onnx.graph.initializer): + n2 = getNodesWithOutput(node.input[1], model_onnx) # constant + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data( + node.input[1], + False, + model_onnx.graph, + verbose, + ) + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + map_onnx_to_myGraph[node.input[1]] = node.input[1] + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Compare + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [ + map_onnx_to_myGraph[node.input[1]] + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + myGraph[node.output[0]]["additional"]["mode"] = 1 + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Where": + if is_constant(node.input[1], model_onnx.graph.initializer): + additional = {} + additional["data"] = node + n2 = getNodesWithOutput(node.input[1], model_onnx) + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose) + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + map_onnx_to_myGraph[node.input[1]] = node.input[1] + if is_constant(node.input[2], model_onnx.graph.initializer): + additional = {} + additional["data"] = node + n2 = getNodesWithOutput(node.input[2], model_onnx) + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data(node.input[2], False, model_onnx.graph, verbose) + myGraph[node.input[2]] = {} + myGraph[node.input[2]]["op_type"] = OPTYPE.Const + myGraph[node.input[2]]["inputs"] = [] + myGraph[node.input[2]]["additional"] = additional + map_onnx_to_myGraph[node.input[2]] = node.input[2] + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Where + myGraph[node.output[0]]["inputs"] = ( + [map_onnx_to_myGraph[n0name]] + + [map_onnx_to_myGraph[node.input[1]]] + + [map_onnx_to_myGraph[node.input[2]]] + ) + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + elif node.op_type == "Equal": + additional = {} + additional["data"] = node + if is_constant(node.input[1], model_onnx.graph.initializer): + n2 = getNodesWithOutput(node.input[1], model_onnx) # constant + ( + additional["dims"], + additional["raw_data"], + additional["dtype"], + ) = extract_additional_data( + node.input[1], + False, + model_onnx.graph, + 
verbose, + ) + myGraph[node.input[1]] = {} + myGraph[node.input[1]]["op_type"] = OPTYPE.Const + myGraph[node.input[1]]["inputs"] = [] + myGraph[node.input[1]]["additional"] = additional + map_onnx_to_myGraph[node.input[1]] = node.input[1] + + myGraph[node.output[0]] = {} + myGraph[node.output[0]]["op_type"] = OPTYPE.Compare + myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [ + map_onnx_to_myGraph[node.input[1]] + ] + myGraph[node.output[0]]["additional"] = {} + myGraph[node.output[0]]["additional"]["data"] = node + myGraph[node.output[0]]["additional"]["mode"] = 2 + map_onnx_to_myGraph[node.output[0]] = node.output[0] + + else: + raise Exception("[ERROR] node not supported:\n{})".format(node)) + + if node_annotation[node.name].add_transpose_after: + n0name = add_transpose_after(node, myGraph, map_onnx_to_myGraph) + + +def parse_onnx(model_onnx, node_annotation, verbose=False): + myGraph, map_onnx_to_myGraph = OrderedDict(), {} + + # Inputs + for inp in model_onnx.graph.input: + myGraph[inp.name] = parse_graph_input_node( + inp, map_onnx_to_myGraph, node_annotation[inp.name].to_transpose + ) + + # Nodes removal + for node in model_onnx.graph.node: + if node.name in node_annotation and node_annotation[node.name].to_remove: + curr_key = node.input[0] + while ( + map_onnx_to_myGraph[curr_key] is not None + and map_onnx_to_myGraph[curr_key] != curr_key + ): + next_key = map_onnx_to_myGraph[curr_key] + curr_key = next_key + if curr_key not in map_onnx_to_myGraph: + curr_key = node.input[0] + break + + map_onnx_to_myGraph[node.output[0]] = curr_key + else: + parse_graph_node( + node, model_onnx, myGraph, node_annotation, map_onnx_to_myGraph, verbose + ) + + myInputs = [] + for inp in model_onnx.graph.input: + myInputs.append(inp.name) + + myOutputs = [] + for out in model_onnx.graph.output: + for key, value in map_onnx_to_myGraph.items(): + if key == out.name: + myOutputs.append(value) + + return myGraph, myInputs, myOutputs + + +QUANTIZERS = dict() + + +def dump_onnx(graph, my_inputs, my_outputs, output_filename, verbose=False): + # graph[my_name]={ op_type + # inputs: [] + # dtype: + # onnx : model.graph.node[x] + # } + + # my_input=[my_name, my_name..] + # outputs=[my_name, ...] 
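+# Binary layout written below (format "SADL0004"): magic string, model dtype,
+# number of layers, the input ids, the output ids, and then one record per
+# layer: id, op type, name length, name, input ids and op-specific data.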
+ # print(graph) + input_q = 11 + relu_q = 11 + mul_q = 13 + + map_name_to_idx = dict() + for idx, (key, _value) in enumerate(graph.items()): + map_name_to_idx[key] = idx + + # dbg print(map_name_to_idx) + with open(output_filename, "wb") as f: + f.write(str.encode("SADL0004")) + # output of the network type 0: int32 | 1: float | 2: int16 | default: float(1) + f.write(struct.pack("i", int(DTYPE_SADL.FLOAT))) + + if verbose: + print(f"# Nb layers: {len(graph.keys())}") + f.write(struct.pack("i", int(len(graph.keys())))) + + inputs = [] + for name in my_inputs: + inputs.append(map_name_to_idx[name]) + if verbose: + print(f"# Nb inputs: {len(inputs)}") + f.write(struct.pack("i", int(len(inputs)))) + for i in inputs: + if verbose: + print(f"# input {i}") + f.write(struct.pack("i", int(i))) + + outputs = [] + for name in my_outputs: + outputs.append(map_name_to_idx[name]) + if verbose: + print(f"# Nb outputs: {len(outputs)}") + f.write(struct.pack("i", int(len(outputs)))) + for i in outputs: + if verbose: + print(f"# output {i}") + f.write(struct.pack("i", int(i))) + + for name, node in graph.items(): + quantizer_idx = map_name_to_idx[name] + QUANTIZERS[quantizer_idx] = 0 + + if verbose: + print(f"# Layer id {map_name_to_idx[name]}") + f.write(struct.pack("i", int(map_name_to_idx[name]))) + + if verbose: + print("#\t op " + str(node["op_type"])) + f.write(struct.pack("i", int(node["op_type"].value))) + + # Name size + if verbose: + print(f"#\t name_size {len(name)}") + f.write(struct.pack("i", int(len(name)))) + + # Name + if verbose: + print(f"#\t name {name}") + f.write(str.encode(str(name))) + + # Nb inputs + if verbose: + print(f"#\t nb_inputs {len(node['inputs'])}") + f.write(struct.pack("i", int(len(node["inputs"])))) + + for name_i in node["inputs"]: + idx = map_name_to_idx[name_i] + if verbose: + print(f"#\t\t {idx} ({name_i})") + f.write(struct.pack("i", int(idx))) + + # Additional info depending on OPTYPE + if node["op_type"] == OPTYPE.Const: + if verbose: + print(f"#\t nb_dim {len(node['additional']['dims'])}") + f.write(struct.pack("i", int(len(node["additional"]["dims"])))) + + for dim in node["additional"]["dims"]: + if verbose: + print(f"#\t\t {dim}") + f.write(struct.pack("i", int(dim))) + + if verbose: + print(f"#\t dtype {node['additional']['dtype']}") + f.write(struct.pack("i", int(node["additional"]["dtype"]))) + + if node["additional"]["dtype"] != DTYPE_SADL.FLOAT: # not float + if verbose: + print("#\t quantizer 0") + f.write(struct.pack("i", int(0))) + + f.write(node["additional"]["raw_data"]) + + # update names + if ( + args.orig_params_to_new is not None + and name in map_orig_params_to_new_params + ): + orig_name = map_orig_params_to_new_params[name] + else: + orig_name = name + + if ".multiplier" in orig_name: + QUANTIZERS[quantizer_idx] = mul_q + elif "PRelu" in name: + QUANTIZERS[quantizer_idx] = relu_q + else: + QUANTIZERS[quantizer_idx] = base_quantizers.get(orig_name, 0) + + if verbose: + print(orig_name, QUANTIZERS[quantizer_idx]) + + # ??? 
if "alpha" in layer['additional']: + # f.write(struct.pack('f', float(layer['additional']['alpha']))) + + elif node["op_type"] == OPTYPE.Slice: + dim_keys = ["h", "w", "c"] + for dim in dim_keys: + if verbose: + print( + f"#\t start_depth index for {dim} slicing", + node["additional"][f"start_{dim}"], + ) + print( + f"#\t end_depth index for {dim} slicing", + node["additional"][f"end_{dim}"], + ) + f.write(struct.pack("i", int(node["additional"][f"start_{dim}"]))) + f.write(struct.pack("i", int(node["additional"][f"end_{dim}"]))) + + elif node["op_type"] == OPTYPE.Conv2D: + if verbose: + print("#\t nb_dim_strides", len(node["additional"]["strides"])) + f.write(struct.pack("i", int(len(node["additional"]["strides"])))) + + for stride in node["additional"]["strides"]: + if verbose: + print(f"#\t\t {stride}") + f.write(struct.pack("i", int(stride))) + + if verbose: + print("#\t nb_dim_pads", len(node["additional"]["pads"])) + f.write(struct.pack("i", int(len(node["additional"]["pads"])))) + + for p in node["additional"]["pads"]: + if verbose: + print(f"#\t\t {p}") + f.write(struct.pack("i", int(p))) + + if verbose: + print("#\t nb_group", node["additional"]["group"]) + f.write(struct.pack("i", int(node["additional"]["group"]))) + + elif node["op_type"] == OPTYPE.Conv2DTranspose: + if verbose: + print("#\t nb_dim_strides", len(node["additional"]["strides"])) + f.write(struct.pack("i", int(len(node["additional"]["strides"])))) + + for stride in node["additional"]["strides"]: + if verbose: + print(f"#\t\t {stride}") + f.write(struct.pack("i", int(stride))) + + if verbose: + print("#\t nb_dim_pads", len(node["additional"]["pads"])) + f.write(struct.pack("i", int(len(node["additional"]["pads"])))) + + for p in node["additional"]["pads"]: + if verbose: + print(f"#\t\t {p}") + f.write(struct.pack("i", int(p))) + + if verbose: + print( + "#\t nb_dim_output_padding", + len(node["additional"]["output_padding"]), + ) + f.write( + struct.pack("i", int(len(node["additional"]["output_padding"]))) + ) + + for p in node["additional"]["output_padding"]: + if verbose: + print(f"#\t\t {p}") + f.write(struct.pack("i", int(p))) + + elif node["op_type"] == OPTYPE.Placeholder: + if verbose: + print(f"#\t nb input dimension {len(node['additional']['dims'])}") + f.write(struct.pack("i", int(len(node["additional"]["dims"])))) + + for dim in node["additional"]["dims"]: + if verbose: + print(f"#\t\t {dim}") + f.write(struct.pack("i", int(dim))) + + # output the quantizer of the input default: 0 + if verbose: + print("#\t quantizer_of_input 0") + f.write(struct.pack("i", int(0))) + + if "Equal" not in name: + QUANTIZERS[quantizer_idx] = input_q + else: + QUANTIZERS[quantizer_idx] = 0 + + elif node["op_type"] == OPTYPE.MaxPool: + if verbose: + print("#\t nb_dim_strides", len(node["additional"]["strides"])) + f.write(struct.pack("i", int(len(node["additional"]["strides"])))) + + for stride in node["additional"]["strides"]: + if verbose: + print(f"#\t\t {stride}") + f.write(struct.pack("i", int(stride))) + + if verbose: + print("#\t nb_dim_kernel", len(node["additional"]["kernel_shape"])) + f.write(struct.pack("i", int(len(node["additional"]["kernel_shape"])))) + + for ks in node["additional"]["kernel_shape"]: + if verbose: + print(f"#\t\t {ks}") + f.write(struct.pack("i", int(ks))) + + if verbose: + print("#\t nb_dim_pads", len(node["additional"]["pads"])) + f.write(struct.pack("i", int(len(node["additional"]["pads"])))) + + for p in node["additional"]["pads"]: + if verbose: + print(f"#\t\t {p}") + f.write(struct.pack("i", 
int(p)))
+
+        elif node["op_type"] == OPTYPE.Flatten:
+            if verbose:
+                print("#\t axis", node["additional"]["axis"])
+            f.write(struct.pack("i", int(node["additional"]["axis"])))
+
+        elif node["op_type"] == OPTYPE.GridSample:
+            if verbose:
+                print("#\t align_corners", node["additional"]["align_corners"])
+            f.write(struct.pack("i", int(node["additional"]["align_corners"])))
+
+            if verbose:
+                print("#\t mode", node["additional"]["mode"])
+            f.write(struct.pack("i", int(node["additional"]["mode"])))
+
+            if verbose:
+                print("#\t padding_mode", node["additional"]["padding_mode"])
+            f.write(struct.pack("i", int(node["additional"]["padding_mode"])))
+
+        elif node["op_type"] == OPTYPE.Resize:
+            if verbose:
+                print("#\t input_label", node["additional"]["input_label"])
+            f.write(struct.pack("i", int(node["additional"]["input_label"])))
+
+            if verbose:
+                print(
+                    "#\t coordinate_transformation_mode",
+                    node["additional"]["coordinate_transformation_mode"],
+                )
+            f.write(
+                struct.pack(
+                    "i", int(node["additional"]["coordinate_transformation_mode"])
+                )
+            )
+
+            if verbose:
+                print("#\t mode", node["additional"]["mode"])
+            f.write(struct.pack("i", int(node["additional"]["mode"])))
+
+            if verbose:
+                print("#\t nearest_mode", node["additional"]["nearest_mode"])
+            f.write(struct.pack("i", int(node["additional"]["nearest_mode"])))
+
+        elif node["op_type"] == OPTYPE.Compare:
+            if verbose:
+                print("#\t mode", node["additional"]["mode"])
+            f.write(struct.pack("i", int(node["additional"]["mode"])))
+
+        if (
+            node["op_type"] == OPTYPE.Conv2D
+            or node["op_type"] == OPTYPE.Conv2DTranspose
+            or node["op_type"] == OPTYPE.MatMul
+            or node["op_type"] == OPTYPE.Mul
+        ):
+            # output the internal quantizer default: 0
+            f.write(struct.pack("i", int(0)))
+
+        if verbose:
+            print("")
+
+
+# adapt (remove/add) the current node to the data_layout and
+# recurse into its outputs
+def annotate_node(
+    node, model_onnx, node_annotation, global_data_layout, verbose
+):  # recursive
+    if node.name in node_annotation:
+        return
+    if verbose > 1:
+        print("[INFO] annotate {}".format(node.name))
+
+    if verbose:
+        print("[INFO] annotate_node {} op={}".format(node.name, node.op_type))
+    data_layout = None
+
+    # inherit from input
+    for inp in node.input:
+        n2 = getNodesWithOutputNotConst(inp, model_onnx)
+        if n2 is not None:
+            if n2.name in node_annotation:
+                if data_layout is None:
+                    data_layout = node_annotation[n2.name].layout_onnx
+                elif (
+                    node_annotation[n2.name].layout_onnx is not None
+                    and node_annotation[n2.name].layout_onnx != data_layout
+                ):
+                    quit(
+                        "[ERROR] inputs with different layouts for\n{}Layouts: {}".format(
+                            node, node_annotation
+                        )
+                    )
+            else:  # not ready yet
+                return
+
+    if verbose > 1 and data_layout is None:
+        print(
+            "[WARNING] no data layout constraints for {}\n {}".format(node.name, node)
+        )
+
+    if node.name not in node_annotation:
+        node_annotation[node.name] = Node_Annotation()
+    node_annotation[node.name].layout_onnx = data_layout  # default
+
+    if node.op_type == "Transpose":
+        a = getAttribute(node, "perm")
+        if data_layout == "nhwc":
+            if (
+                a.ints[0] == 0 and a.ints[1] == 3 and a.ints[2] == 1 and a.ints[3] == 2
+            ):  # nhwc -> nchw
+                node_annotation[node.name].to_remove = True  # will be removed
+                node_annotation[node.name].layout_onnx = "nchw"  # new layout at output
+            else:
+                if verbose > 1:
+                    print("[WARNING] transpose not used for NCHW handling in\n", node)
+        elif data_layout == "nchw":
+            if (
+                a.ints[0] == 0 and a.ints[1] == 2 and a.ints[2] == 3 and a.ints[3] == 1
+            ):  # nchw -> nhwc
+                node_annotation[node.name].to_remove = True  # will be removed
+                node_annotation[node.name].layout_onnx = "nhwc"  # new layout at output
+            else:
+                if verbose > 1:
+                    print("[WARNING] transpose not used for NCHW handling in\n", node)
+
+        if node_annotation[node.name].to_remove:
+            # GridSample is usually paired with a Transpose, because the optical flow
+            # is transposed from (N,2,H,W) to (N,H,W,2); that Transpose must not be
+            # removed. Since in PyTorch no other operation consumes the transposed
+            # (N,H,W,2) feature, the code in this IF statement does not affect other
+            # situations.
+            nexts = getNodesWithInput(node.output[0], model_onnx)
+            for n in nexts:
+                if n.op_type == "GridSample":
+                    node_annotation[node.name].to_remove = False
+                    node_annotation[node.name].layout_onnx = "nchw"
+                    break
+
+    elif node.op_type == "Reshape":
+        initializer = getInitializer(node.input[1], model_onnx)
+        # Case: in PyTorch, Reshape is not in model_onnx.graph.initializer but in model_onnx.graph.node
+        if initializer is None:
+            attribute = getAttribute(
+                getNodesWithOutput(node.input[1], model_onnx), "value"
+            )
+            initializer = attribute.t
+        dims = getDims(initializer)
+
+        # detect if this reshape was actually added by onnx to emulate a transpose
+        # (more testing is needed to confirm a reshape is really a transpose...)
+        if len(dims) == 4 and (dims[0] == 1 or dims[0] == -1):
+            if data_layout == "nhwc":
+                if dims[1] == 1:  # or dims2 * dims3 == 1 # nhwc -> nchw
+                    node_annotation[node.name].to_remove = True  # will be removed
+                    node_annotation[
+                        node.name
+                    ].layout_onnx = "nchw"  # new layout at output
+                else:
+                    if verbose > 1:
+                        print("[WARNING] reshape unknown for", node, " dims", dims)
+                    node_annotation[node.name].layout_onnx = None
+            elif data_layout == "nchw":
+                if dims[3] == 1:  # or dims2 * dims3 == 1 # nchw -> nhwc
+                    node_annotation[node.name].to_remove = True  # will be removed
+                    node_annotation[
+                        node.name
+                    ].layout_onnx = "nhwc"  # new layout at output
+                else:
+                    if verbose > 1:
+                        print("[WARNING] reshape unknown for", node, " dims", dims)
+                    node_annotation[node.name].layout_onnx = None
+            elif data_layout is None:
+                node_annotation[
+                    node.name
+                ].layout_onnx = global_data_layout  # back to the original layout
+                if global_data_layout == "nchw":
+                    node_annotation[
+                        node.name
+                    ].add_transpose_after = True  # a bit too aggressive
+        else:
+            node_annotation[node.name].layout_onnx = None
+
+            n2 = getNodesWithOutputNotConst(node.input[0], model_onnx)
+            if (
+                node_annotation[n2.name].layout_onnx == "nchw"
+            ):  # need to go back to the original layout before the reshape
+                node_annotation[node.name].add_transpose_before = True
+
+    elif node.op_type == "Flatten":
+        if (
+            node_annotation[node.name].layout_onnx == "nchw"
+        ):  # need to go back to the original layout before the flatten
+            node_annotation[node.name].add_transpose_before = True
+
+    elif node.op_type == "Concat":
+        if data_layout == "nchw":  # remap the concat axis from nchw to nhwc
+            a = getAttribute(node, "axis")
+            if a.i == 1:
+                a.i = 3
+            elif a.i == 2:
+                a.i = 1
+            elif a.i == 3:
+                a.i = 2
+            elif a.i == -3:
+                a.i = -1
+
+    elif node.op_type == "Unsqueeze":
+        node_annotation[node.name].to_remove = True
+
+    elif node.op_type == "Conv":
+        n2 = getInitializer(node.input[1], model_onnx)
+        node_annotation[n2.name].to_transpose = True
+        node_annotation[n2.name].layout_onnx = "nhwc"
+
+    elif node.op_type == "ConvTranspose":
+        n2 = getInitializer(node.input[1], model_onnx)
+        node_annotation[n2.name].to_transpose = True
+        node_annotation[n2.name].layout_onnx = "nhwc"
+
+    elif node.op_type == "Gemm":
+        n2 = getInitializer(node.input[1], model_onnx)
+        if global_data_layout == "nchw":
+            node_annotation[n2.name].to_transpose = True
+            # node_annotation[n2.name].layout_onnx = 'nhwc'
+
+    elif node.op_type == "ScatterND":
+        n2 = getInitializer(node.input[1], model_onnx)
+        if global_data_layout == "nchw":
+            node_annotation[n2.name].to_transpose = True
+
+    nexts = getNodesWithInput(node.output[0], model_onnx)
+    for n in nexts:
+        annotate_node(
+            n, model_onnx, node_annotation, global_data_layout, verbose
+        )  # recurse
+
+
+def annotate_graph(model_onnx, node_annotation, data_layout, verbose):
+
+    # track the data layout in the graph and remove/add layers if necessary
+    for inp in model_onnx.graph.input:
+        node_annotation[inp.name] = Node_Annotation()
+        if len(inp.type.tensor_type.shape.dim) == 4:
+            node_annotation[inp.name].layout_onnx = data_layout
+            if data_layout == "nchw":
+                node_annotation[inp.name].to_transpose = True
+        else:
+            node_annotation[inp.name].layout_onnx = None
+
+    for inp in model_onnx.graph.initializer:
+        node_annotation[inp.name] = Node_Annotation()
+        node_annotation[inp.name].layout_onnx = None
+
+    for inp in model_onnx.graph.node:
+        if inp.op_type == "Constant":
+            node_annotation[inp.name] = Node_Annotation()
+            node_annotation[inp.name].layout_onnx = None
+
+    for inp in model_onnx.graph.input:
+        nexts = getNodesWithInput(inp.name, model_onnx)
+        for n in nexts:
+            annotate_node(
+                n, model_onnx, node_annotation, data_layout, verbose
+            )  # recursive
+
+    if verbose > 1:
+        for node in model_onnx.graph.node:
+            if node.op_type == "Transpose" and (
+                node.name not in node_annotation
+                or not node_annotation[node.name].to_remove
+            ):
+                print(
+                    "[ERROR] preprocess_onnxGraph: all Transpose nodes should be removed, but this is not the case here: {}\n{}".format(
+                        node.name, node
+                    )
+                )
+
+
+def detectDataType(model):  # more adaptation to do here if tf is using nchw
+    if model.producer_name == "tf2onnx":
+        return "nhwc"
+    elif model.producer_name == "pytorch":
+        return "nchw"
+    else:
+        quit("[ERROR] unable to detect the data layout")
+
+
+def dumpModel(model_onnx, output_filename, data_layout, verbose, user_annotation):
+    """Writes the neural network model in the \"sadl\" format to a binary file.
+
+    Parameters
+    ----------
+    model_onnx : onnx model
+    output_filename : either str or None
+        Path to the binary file to which the neural network model
+        is written.
+    data_layout : None, 'nchw' or 'nhwc'
+    verbose : bool
+        Is additional information printed?
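+    user_annotation : dict
+        Maps node names to Node_Annotation objects whose non-None fields
+        override the automatically derived annotations.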
+ """ + model_onnx_copy = copy.deepcopy(model_onnx) + if data_layout is None: + data_layout = detectDataType(model_onnx_copy) + + if verbose: + print("[INFO] assume data type", data_layout) + + if verbose > 1: + # remove data + gg = copy.deepcopy(model_onnx.graph) + for node in gg.initializer: + node.raw_data = np.array(0.0).tobytes() + print("[INFO] original graph:\n", gg) + del gg + + if data_layout != "nhwc" and data_layout != "nchw": + quit("[ERROR] unsupported layout", data_layout) + + node_annotation = {} + annotate_graph(model_onnx_copy, node_annotation, data_layout, verbose) + + for k, v in user_annotation.items(): + if k in node_annotation: + if v.add_transpose_before is not None: + node_annotation[k].add_transpose_before = v.add_transpose_before + if v.add_transpose_after is not None: + node_annotation[k].add_transpose_after = v.add_transpose_after + if v.to_remove is not None: + node_annotation[k].to_remove = v.to_remove + if v.to_transpose is not None: + node_annotation[k].to_transpose = v.to_transpose + else: + print("[ERROR] unknown node user custom", k) + quit() + + if verbose > 1: + print( + "INFO] annotations:\n{" + + "\n".join("{!r}: {!r},".format(k, v) for k, v in node_annotation.items()) + + "}" + ) # print("[INFO] node annotations:", node_annotation) + my_graph, my_inputs, my_outputs = parse_onnx( + model_onnx_copy, node_annotation, verbose=verbose + ) + dump_onnx(my_graph, my_inputs, my_outputs, output_filename, verbose=verbose) + if data_layout == "nchw": + print( + "[INFO] in SADL, your inputs and outputs has been changed from NCHW to NHWC" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="onnx2sadl conversion", usage="NB: force run on CPU" + ) + parser.add_argument( + "--input_onnx", + action="store", + nargs="?", + type=str, + help="name of the onnx file", + ) + parser.add_argument( + "--output", + action="store", + nargs="?", + type=str, + help="name of model binary file", + ) + parser.add_argument( + "--out_quant", + action="store", + nargs="?", + type=str, + default="", + help="name of quantizer file", + ) + parser.add_argument( + "--base_quantizers", + action="store", + nargs="?", + type=str, + help="path to base quantizers file", + ) + parser.add_argument( + "--orig_params_to_new", + action="store", + nargs="?", + type=str, + help="path to the file that contains the mapping of orig param names to new", + ) + parser.add_argument("--nchw", action="store_true") + parser.add_argument("--nhwc", action="store_true") + parser.add_argument("--verbose", action="count") + parser.add_argument( + "--do_not_add_transpose_before", + action="store", + nargs="+", + default=[], + help="specify a node where add transpose before will be disable", + ) + parser.add_argument( + "--do_not_add_transpose_after", + action="store", + nargs="+", + default=[], + help="specify a node where add transpose after will be disable", + ) + + args = parser.parse_args() + if args.input_onnx is None: + quit("[ERROR] You should specify an onnx file") + if args.output is None: + quit("[ERROR] You should specify an output file") + + print("[INFO] ONNX converter") + if args.verbose is None: + args.verbose = 0 + + model_onnx = onnx.load(args.input_onnx) + + user_annotation = {} + for node in args.do_not_add_transpose_before: + if node not in user_annotation: + user_annotation[node] = Node_Annotation() + user_annotation[node].to_remove = None + user_annotation[node].add_transpose_before = None + user_annotation[node].add_transpose_after = None + 
user_annotation[node].to_transpose = None + user_annotation[node].add_transpose_before = False + + for node in args.do_not_add_transpose_after: + if node not in user_annotation: + user_annotation[node] = Node_Annotation() + user_annotation[node].to_remove = None + user_annotation[node].add_transpose_before = None + user_annotation[node].add_transpose_after = None + user_annotation[node].to_transpose = None + user_annotation[node].add_transpose_after = False + + data_layout = None + if args.nchw: + data_layout = "nchw" + elif args.nhwc: + data_layout = "nhwc" + + base_quantizers = read_json_file(Path(args.base_quantizers)) + + if args.orig_params_to_new is not None: + map_orig_params_to_new_params = read_json_file(Path(args.orig_params_to_new)) + + dumpModel(model_onnx, args.output, data_layout, args.verbose, user_annotation) + + original_stdout = sys.stdout + if args.out_quant != "": + with open(args.out_quant, "w") as f: + sys.stdout = f # Change the standard output to the file we created. + for idx in QUANTIZERS: + print(f"{idx} {QUANTIZERS[idx]}", end=" ") + sys.stdout = ( + original_stdout # Reset the standard output to its original value + ) + + # The modifications: output quantizers when dumping onnx model (in the function dump_onnx()) diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/sadl2torch.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/sadl2torch.py new file mode 100644 index 0000000000..32f58f7004 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/conversion/sadl2torch.py @@ -0,0 +1,539 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import hashlib +import os +import sys +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Dict + +import numpy as np +import torch + +from conversion.onnx2sadl_modified import DTYPE_SADL, OPTYPE +from models import load_torch_model, model_lop2, model_lop2_with_multiplier +from util.file_system import read_json_file, write_json_file + +# used to load tensor in SADL, the max dimension of tensor is 6 +MAX_DIMENSION = 6 +verbose = 0 + + +def md5file(filename): + md5 = hashlib.md5() + # handle content in binary form + f = open(filename, "rb") + while True: + chunk = f.read(4096) + if not chunk: + break + md5.update(chunk) + return md5.hexdigest() + + +def load_prefix(sadl_file): + L = int.from_bytes(sadl_file.read(4), byteorder="little") + maxLength = 2048 + assert L > 0 and L + 1 < maxLength # max name size + layer_name = sadl_file.read(L).decode("utf-8") + if verbose: + print(f" - name: {layer_name}") + + L = int.from_bytes(sadl_file.read(4), byteorder="little") + assert 0 <= L < 8 + inputs_id = [ + int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(L) + ] + if verbose: + print(" - inputs:", inputs_id) + return layer_name + + +def load_internal_const(sadl_file): + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: + quit(f"[ERROR] invalid nb of dimensions: {x}") + + d = [0] * x + for k in range(len(d)): + d[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - tensor: {d}") + + num_elements = np.prod(d) + ele_type = int.from_bytes(sadl_file.read(4), byteorder="little") + + if ele_type == DTYPE_SADL.FLOAT: + ele_dtype = np.float32 + quantizer = 0 + elif ele_type == DTYPE_SADL.INT16: + ele_dtype = np.int16 + quantizer = int.from_bytes(sadl_file.read(4), byteorder="little") + elif ele_type == DTYPE_SADL.INT32: + ele_dtype = np.int32 + quantizer = int.from_bytes(sadl_file.read(4), byteorder="little") + else: + if verbose: + print("[ERROR] unknown internal type") + + element_size = np.dtype(ele_dtype).itemsize + b_data = sadl_file.read(element_size * num_elements) + data = np.frombuffer(b_data, ele_dtype) + # dequantize data + data = data / float(2**quantizer) + + if verbose: + print(" - data:", end=" ") + for k in range(4 if len(data) >= 4 else len(data)): + if verbose: + print(data[k], end=" ") + if verbose: + print("...") + return d, data.copy(), quantizer + + +def load_internal_conv2d(sadl_file): + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: + quit(f"[ERROR] invalid nb of dimensions: {x}") + + strides = [0] * x + for k in range(len(strides)): + strides[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + + if len(strides) == 2: + strides = [1, strides[0], strides[1], 1] + + if len(strides) != 4: + quit(f"[ERROR] invalid strides: {len(strides)}") + + if strides[0] != 1: + quit(f"[ERROR] invalid strides[0]: {strides[0]}") + + if strides[3] != 1: + quit(f"[ERROR] invalid strides[3]: {strides[3]}") + + if strides[1] != 1 and strides[1] != 2: + quit(f"[ERROR] not 1 or 2: to check {strides}") + + if strides[2] != 1 and strides[2] != 2: + quit(f"[ERROR] not 1 or 2: to check {strides}") + + if verbose: + print(f" - strides: {strides}") + + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: + quit(f"[ERROR] invalid nb of dimensions: {x}") + + pads = [0] * x + for k in range(len(pads)): + pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" 
- pads: {pads}") + + groups = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - groups: {groups}") + + quantizer = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - q: {quantizer}") + + +def load_internal_conv2dTranspose(sadl_file): + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: # Dimensions.MaxDim + quit(f"[ERROR] invalid nb of dimensions: {x}") + + strides = [0] * x + for k in range(len(strides)): + strides[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + + if len(strides) == 2: + strides = [1, strides[0], strides[1], 1] + + if len(strides) != 4: + quit(f"[ERROR] invalid strides: {len(strides)}") + + if strides[0] != 1: + quit(f"[ERROR] invalid strides[0]: {strides[0]}") + + if strides[3] != 1: + quit(f"[ERROR] invalid strides[3]: {strides[3]}") + + if strides[1] != 2 or strides[2] != 2: + quit(f"[ERROR] stride not 2: to check {strides}") + + if verbose: + print(f" - strides: {strides}") + + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: # Dimensions.MaxDim + quit(f"[ERROR] invalid nb of dimensions: {x}") + + pads = [0] * x + for k in range(len(pads)): + pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + if pads[k] != 1 and pads[k] != 2: + quit(f"[ERROR] pads values not supported: {pads[k]}") + + if verbose: + print(f" - pads: {pads}") + + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: + quit(f"[ERROR] invalid nb of dimensions: {x}") + + out_pads = [0] * x + for k in range(len(out_pads)): + out_pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + if out_pads[k] != 1: + quit(f"[ERROR] output pads !=1 {out_pads[k]}") + + if verbose: + print(f" - out_pads: {out_pads}") + + quantizer = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - q: {quantizer}") + + +def load_internal_flatten(sadl_file): + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: # Dimensions.MaxDim + quit(f"[ERROR] invalid axis: {x}") + + m_axis = x + if verbose: + print(f" - start axis: {m_axis}") + + +def load_internal_matmul(sadl_file): + quantizer = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - q: {quantizer}") + + +def load_internal_maxpool(sadl_file): + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: + quit(f"[ERROR] invalid nb of dimensions strides: {x}") + + strides = [0] * x + for k in range(len(strides)): + strides[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - strides: {strides}") + + if len(strides) != 4: + quit(f"[ERROR] invalid strides: {len(strides)}") + + if strides[0] != 1: + quit(f"[ERROR] invalid strides[0]: {strides[0]}") + + if strides[3] != 1: + quit(f"[ERROR] invalid strides[3]: {strides[3]}") + + if strides[1] != strides[2]: + quit(f"[ERROR] invalid stride H Vs: {strides}") + + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: + quit(f"[ERROR] invalid nb of dimensions kernel: {x}") + + kernel = [0] * x + for k in range(len(kernel)): + kernel[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - kernel: {kernel}") + + if len(kernel) != 4: + quit(f"[ERROR] invalid kernel: {len(kernel)}") + + if kernel[0] != 1: + quit(f"[ERROR] invalid kernel[0]: {kernel[0]}") + + if kernel[3] != 1: + quit(f"[ERROR] invalid 
kernel[3]: {kernel[3]}") + + if kernel[1] != kernel[2]: + quit(f"[ERROR] invalid kernel H V: {kernel}") + + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: + quit(f"[ERROR] invalid nb of dimensions: {x}") + + pads = [0] * x + for k in range(len(pads)): + pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - pads: {pads}") + + +def load_internal_mul(sadl_file): + quantizer = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - q: {quantizer}") + + +def load_internal_placeholder(sadl_file): + x = int.from_bytes(sadl_file.read(4), byteorder="little") + if x <= 0 or x > MAX_DIMENSION: # Dimensions.MaxDim + quit(f"[ERROR] invalid nb of dimensions: {x}") + + dims = [int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(x)] + if len(dims) == 1: + dims = [1] + dims + + quantizer = int.from_bytes(sadl_file.read(4), byteorder="little") + + if verbose: + print(f" - dim: {dims}") + if verbose: + print(f" - q: {quantizer}") + + +def load_internal_slice(sadl_file): + start_h = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - start_h: {start_h}") + + end_h = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - end_h: {end_h}") + + start_w = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - start_w: {start_w}") + + end_w = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - end_w: {end_w}") + + start_c = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - start_c: {start_c}") + + end_c = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f" - end_c: {end_c}") + + +MAP_OPERATION_TO_LOAD_FUNC = { + OPTYPE.Const: load_internal_const, + OPTYPE.Conv2D: load_internal_conv2d, + OPTYPE.Conv2DTranspose: load_internal_conv2dTranspose, + OPTYPE.Flatten: load_internal_flatten, + OPTYPE.MatMul: load_internal_matmul, + OPTYPE.MaxPool: load_internal_maxpool, + OPTYPE.Mul: load_internal_mul, + OPTYPE.Placeholder: load_internal_placeholder, + OPTYPE.Slice: load_internal_slice, +} + + +def load_internal(sadl_file, operation): + try: + op_func = MAP_OPERATION_TO_LOAD_FUNC[operation] + out = op_func(sadl_file) + except Exception: + # No need to load internal for current layer + return + + return out + + +def read_sadl_file(sadl_model_path): + base_model_quantizers = {} + sadl_file = open(sadl_model_path, "rb") + magic = sadl_file.read(8).decode() + if verbose: + print(f"[INFO] read magic {magic}") + + x = int.from_bytes(sadl_file.read(4), byteorder="little") + model_type = DTYPE_SADL(x).name + if verbose: + print(f"[INFO] Model type: {model_type}") + + nb_layers = int.from_bytes(sadl_file.read(4), byteorder="little") + if verbose: + print(f"[INFO] Num layers: {nb_layers}") + + # get input and output id + nb = int.from_bytes(sadl_file.read(4), byteorder="little") + ids_input = list( + int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(nb) + ) + nb = int.from_bytes(sadl_file.read(4), byteorder="little") + ids_output = list( + int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(nb) + ) + if verbose: + print("[INFO] input id:", ids_input) + if verbose: + print("[INFO] output id:", ids_output) + + # load layers + const_layers = {} + for k in range(nb_layers): + layer_id = int.from_bytes(sadl_file.read(4), byteorder="little") + layer_op = int.from_bytes(sadl_file.read(4), byteorder="little") + if not (0 
< layer_op < OPTYPE.Count): + quit(f"[ERROR] Pb reading model: layer op {layer_op}") + + if verbose: + print(f"[INFO] id: {layer_id} op {OPTYPE(layer_op).name}") + layer_name = load_prefix(sadl_file) + out = load_internal(sadl_file, OPTYPE(layer_op)) + if OPTYPE(layer_op) == OPTYPE.Const: + if ( + ".weight" in layer_name + or ".bias" in layer_name + or "PRelu" in layer_name + ): + const_layers[layer_name] = {} + const_layers[layer_name]["dimension"] = out[0] + const_layers[layer_name]["data"] = out[1] + base_model_quantizers[layer_name] = out[2] + + if verbose: + print("[INFO] == end model loading ==\n") + + return const_layers, base_model_quantizers + + +def map_sadl_param_name_to_torch(layer_data: Dict): + # only names of PRelu need to be modified + param_names = list(layer_data.keys()) + for i, k in enumerate(param_names): + if "PRelu" in k: + name_of_last_param = param_names[i - 1].split(".") + name_of_current_param = ( + ".".join(name_of_last_param[:-2]) + + "." + + str(int(name_of_last_param[-2]) + 1) + + ".weight" + ) + layer_data[name_of_current_param] = layer_data[k] + layer_data.pop(k, None) + + +def create_model(orig_sadl_model, orig_model_config): + layer_data, base_model_quantizers = read_sadl_file(orig_sadl_model) + map_sadl_param_name_to_torch(layer_data) + + torch_model = load_torch_model(orig_model_config, model_lop2) + for name, _ in torch_model.SADL_model.state_dict().items(): + dimension = layer_data[name]["dimension"] + weights = torch.Tensor(layer_data[name]["data"]) + # in SADL, dim format of conv weights is WHC_inC_out, convert to C_outC_intWH here + if len(dimension) == 4: + weights = weights.reshape(dimension) + weights = weights.permute(3, 2, 0, 1) + torch_model.SADL_model.state_dict()[name][:] = weights + + return torch_model, base_model_quantizers + + +def map_params(orig_model_config) -> Dict[str, str]: + orig_model = load_torch_model(orig_model_config, model_lop2) + orig_params = orig_model.state_dict() + + new_model = load_torch_model(orig_model_config, model_lop2_with_multiplier) + new_params = new_model.state_dict() + + param_map = {} + + tmp_params = list() + + for new_param in new_params: + if "multiplier" not in new_param: + tmp_params.append(new_param) + + for (orig_param, new_param) in zip(orig_params.keys(), tmp_params): + tmp_orig_param = orig_param.replace("SADL_model.", "") + tmp_new_param = new_param.replace("SADL_model.", "") + param_map[tmp_new_param] = tmp_orig_param + + return param_map + + +if __name__ == "__main__": + config = read_json_file("resources/config.json") + + if "lop2" != config["architecture"]: + print(f"{config['architecture']} not supported") + sys.exit(-1) + + model_config = read_json_file("models/models.json")[config["architecture"]] + + orig_sadl_model_i = config["sadl2torch"]["base_model_sadl_i"] + output_path = config["sadl2torch"]["base_model_torch"] + + verbose = 0 + + # dequantized model + torch_model_dq, base_quantizers = create_model(orig_sadl_model_i, model_config) + torch.save(torch_model_dq.state_dict(), output_path) + + # save quantizers of base model + base_quantizers_file = config["sadl2torch"]["base_model_quantizers"] + write_json_file(base_quantizers, base_quantizers_file) + + org_to_new_params = map_params(model_config) + write_json_file( + org_to_new_params, Path(config["sadl2torch"]["base_model_params_to_new"]) + ) + + # NOTE: current scripts can be used to convert SADL0004 version to Torch + if verbose: + with NamedTemporaryFile() as tmp_onnx, NamedTemporaryFile() as tmp_sadl_f, NamedTemporaryFile() as 
tmp_sadl_i, NamedTemporaryFile() as tmp_quantizers:
+            torch_model_dq.SADL_model.to_onnx(tmp_onnx.name)
+
+            os.system(
+                f"python conversion/onnx2sadl_modified.py --input_onnx {tmp_onnx.name} --output {tmp_sadl_f.name} --out_quant {tmp_quantizers.name} --base_quantizers {base_quantizers_file}"
+            )
+            os.system(
+                f"tail -n 1 {tmp_quantizers.name} | ./naive_quantization {tmp_sadl_f.name} {tmp_sadl_i.name}"
+            )
+
+            # check the equivalence of the base int model and the requantized dequantized base model
+            src_md5 = md5file(orig_sadl_model_i)
+            dst_md5 = md5file(tmp_sadl_i.name)
+
+            assert (
+                src_md5 == dst_md5
+            ), "Something went wrong in the new custom SADL model reader"
diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/torch2sadl.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/torch2sadl.py
new file mode 100644
index 0000000000..a46251c8f9
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/conversion/torch2sadl.py
@@ -0,0 +1,130 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+""" + +import os +import sys +from pathlib import Path +from typing import Union + +import torch + +from models import load_torch_model, model_lop2_with_multiplier +from util.file_system import create_directory, read_json_file + + +def torch_to_onnx( + torch_model_file: Union[Path, str], + out_onnx_path: Union[Path, str], + model_config: Union[Path, str], + model_module: torch.nn.Module = model_lop2_with_multiplier, +): + # Load model + model = load_torch_model(model_config, model_module) + model.load_state_dict(torch.load(torch_model_file)) + + # Torch to Onnx + model.SADL_model.to_onnx(out_onnx_path) + + # remove identity layer (cause errors in SADL conversion) + exit_code = os.system( + f"python -m onnxoptimizer {out_onnx_path} {out_onnx_path} -p eliminate_identity" + ) + if exit_code: + sys.exit(-1) + + +def onnx_to_sadl( + sadl_models_dir, + onnx_model_path, + quantizers_dir, + sadl_int_models_dir, + base_quantizers, + orig_params_to_new=None, +): + sadl_model_file = sadl_models_dir / f"{onnx_model_path.stem}.sadl" + q_out = quantizers_dir / f"Q_{onnx_model_path.stem}.log" + command = ( + f"python conversion/onnx2sadl_modified.py --input_onnx {onnx_model_path} --output {sadl_model_file} " + f"--out_quant {q_out} --base_quantizers {base_quantizers}" + ) + if orig_params_to_new is not None: + command += f" --orig_params_to_new {orig_params_to_new}" + if os.system(command): + sys.exit(-1) + + # convert float SADL to int16 + output_file = sadl_int_models_dir / f"{onnx_model_path.stem}.sadl" + command_int = ( + f"tail -n 1 {q_out} | ./naive_quantization {sadl_model_file} {output_file}" + ) + if os.system(command_int): + sys.exit(-1) + + +def torch_model_to_sadl(): + input_models_files = (root_dir / "nnr_models").glob("*.pt") + + sadl_float_dir = root_dir / "sadl_float_models" + sadl_int_dir = root_dir / "sadl_int_models" + quantizers_dir = root_dir / "quantizers" + onnx_models_dir = root_dir / "onnx_models" + create_directory(onnx_models_dir) + create_directory(sadl_float_dir) + create_directory(sadl_int_dir) + create_directory(quantizers_dir) + + # Torch to ONNX + for torch_model in input_models_files: + model_name = torch_model.stem + + print("-----------------------------") + print(model_name) + print("-----------------------------") + + # Torch to Onnx + out_onnx_path = onnx_models_dir / f"{model_name}.onnx" + torch_to_onnx(torch_model, out_onnx_path, model_config_file) + + # Onnx to SADL + onnx_to_sadl(sadl_float_dir, out_onnx_path, quantizers_dir, sadl_int_dir) + + +if __name__ == "__main__": + test_cfg = read_json_file(Path("resources/config.json")) + root_dir = Path(test_cfg["training"]["output_path"]) + model_config_file = Path("models/models.json") + + torch_model_to_sadl() + + sys.exit(0) diff --git a/training/training_scripts/NN_Adaptive_Filtering/create_dataset_dirs.py b/training/training_scripts/NN_Adaptive_Filtering/create_dataset_dirs.py new file mode 100644 index 0000000000..c190cc6421 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/create_dataset_dirs.py @@ -0,0 +1,57 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. 
+* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +from pathlib import Path + +from util.file_system import create_directory, read_json_file + + +def create_directories() -> None: + dataset_orig_path = create_directory(config["dataset"]["orig_path"]) + dataset_deco_path = create_directory(config["dataset"]["deco_path"]) + + seq_labels = read_json_file(Path("resources/datasets/jvet_labels.json")) + qps = ["22", "27", "32", "37", "42"] + + for seq, label in seq_labels.items(): + create_directory(dataset_orig_path / label) + + deco_seq_path = create_directory(dataset_deco_path / label) + for qp in qps: + create_directory(deco_seq_path / qp) + + +if __name__ == "__main__": + config = read_json_file(Path("resources/config.json")) + create_directories() diff --git a/training/training_scripts/NN_Adaptive_Filtering/create_env.sh b/training/training_scripts/NN_Adaptive_Filtering/create_env.sh new file mode 100755 index 0000000000..218d0e8581 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/create_env.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +python3 -m venv ${HOME}/lop2_overf +source ${HOME}/lop2_overf/bin/activate + +# Torch dependencies +pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 + +pip install click==8.1.7 openpyxl + +# NCTM dependencies +pip install protobuf==3.20.2 memory_profiler==0.55.0 dcase-util==0.2.10 pandas opencv-python scikit-image scikit-learn tensorflow-gpu==2.8.0 + +# SADL dependencies +pip install tqdm onnxoptimizer + +deactivate diff --git a/training/training_scripts/NN_Adaptive_Filtering/data_preparation.py b/training/training_scripts/NN_Adaptive_Filtering/data_preparation.py new file mode 100644 index 0000000000..caa9571caf --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/data_preparation.py @@ -0,0 +1,121 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. 
This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import os +from pathlib import Path + +from util.file_system import ( + check_directory, + check_file, + list_dirs, + read_json_file, + write_json_file, +) +from util.regex import get_frame_info_from_decoder_log_line + + +def extract_coding_info(root_dir: Path) -> None: + """ + Extracts coding info from the configuration file and the encoding log + :param root_dir: Directory that contains the coding simulation results + """ + seq_dirs = list_dirs(root_dir) + for seq_dir in seq_dirs: + qp_dirs = list_dirs(seq_dir) + for qp_dir in qp_dirs: + coding_info = { + "base_qp": int(qp_dir.name), + "POC": {}, + } + + log_file = check_file(qp_dir / "log_dec.txt") + with open(log_file, "r") as stream: + for line in stream: + if line.startswith("POC"): + ( + poc, + temporal_layer, + slice_type, + slice_qp, + ) = get_frame_info_from_decoder_log_line(line) + + coding_info["POC"][poc] = { + "temporal_layer": temporal_layer, + "slice_type": slice_type, + "slice_qp": slice_qp, + } + + write_json_file(coding_info, qp_dir / "coding_info.json") + + +def decode_nnvc_bitstream( + dataset_dir: Path, decoder_bin: Path, lop_filter_path: Path, intra_filter_dir: Path +) -> None: + """ + Decodes video bitstreams using NNVC + :param dataset_dir: Directory that contains the bitstream (e.g. 
root_dir -> seq_dir -> qp -> bitstream.266) + :param decoder_bin: Video decoder binary + :param lop_filter_path: Default LOP2 filter + :param intra_filter_dir: Directory that contains intra filters + """ + seq_dirs = list_dirs(dataset_dir) + + for seq_dir in seq_dirs: + qp_dirs = list_dirs(seq_dir) + for qp_dir in qp_dirs: + print(seq_dir.name, qp_dir.name) + bitstream = qp_dir / f"{seq_dir.name}_{qp_dir.name}.266" + check_file(bitstream) + log_file = qp_dir / "log_dec.txt" + command = f"{decoder_bin} -b {bitstream} --DumpBasename={qp_dir / seq_dir.name} --NnlfModelName={lop_filter_path} --PrefixAbsolutePathsToGraphsOutput={intra_filter_dir} > {log_file}" + os.system(command) + + +def run_data_preparation() -> None: + dataset_deco_path = check_directory(config["dataset"]["deco_path"]) + decode_nnvc_bitstream( + config["dataset"]["deco_path"], + config["decoder_bin"], + config["sadl2torch"]["base_model_sadl_i"], + config["intra_filters"], + ) + + # Extract POC, frame QP and temporal layer + extract_coding_info(dataset_deco_path) + + +if __name__ == "__main__": + config = read_json_file(Path("resources/config.json")) + seq_cfg = read_json_file(Path("resources/datasets/jvet.json")) + run_data_preparation() diff --git a/training/training_scripts/NN_Adaptive_Filtering/environment.yml b/training/training_scripts/NN_Adaptive_Filtering/environment.yml new file mode 100644 index 0000000000..569d237f7f --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/environment.yml @@ -0,0 +1,163 @@ +name: lop2_overf +channels: + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotli-python=1.0.9=py39h6a678d5_7 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.12.12=h06a4308_0 + - certifi=2023.11.17=py39h06a4308_0 + - cffi=1.16.0=py39h5eee18b_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cryptography=41.0.7=py39hdda0065_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - ffmpeg=4.3=hf484d3e_0 + - freetype=2.12.1=h4a9f257_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - idna=3.4=py39h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.17=h5eee18b_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.4=h5eee18b_0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_0 + - lz4-c=1.9.4=h6a678d5_0 + - mkl=2023.1.0=h213fc3f_46344 + - mkl-service=2.4.0=py39h5eee18b_1 + - mkl_fft=1.3.8=py39h5eee18b_0 + - mkl_random=1.2.4=py39hdb19cb5_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - numpy=1.26.3=py39h5f9d8c6_0 + - numpy-base=1.26.3=py39hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.0.12=h7f8727e_0 + - pillow=10.0.1=py39ha6cbd5a_0 + - pip=23.3.1=py39h06a4308_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.2.0=py39h06a4308_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.18=h955ad1f_0 + - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 + - pytorch-mutex=1.0=cuda + - readline=8.2=h5eee18b_0 + - requests=2.31.0=py39h06a4308_0 + - setuptools=68.2.2=py39h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - 
torchaudio=0.12.1=py39_cu113 + - torchvision=0.13.1=py39_cu113 + - typing_extensions=4.9.0=py39h06a4308_0 + - urllib3=1.26.18=py39h06a4308_0 + - wheel=0.41.2=py39h06a4308_0 + - xz=5.4.5=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.5=hc292b87_0 + - pip: + - absl-py==2.0.0 + - astunparse==1.6.3 + - audioread==3.0.1 + - black==23.3.0 + - cachetools==5.3.2 + - click==8.1.7 + - contourpy==1.2.0 + - cycler==0.12.1 + - dcase-util==0.2.10 + - decorator==5.1.1 + - flake8==6.1.0 + - flatbuffers==23.5.26 + - fonttools==4.47.2 + - future==0.18.3 + - gast==0.5.4 + - google-auth==2.26.2 + - google-auth-oauthlib==0.4.6 + - google-pasta==0.2.0 + - grpcio==1.60.0 + - h5py==3.10.0 + - imageio==2.33.1 + - importlib-metadata==7.0.1 + - importlib-resources==6.1.1 + - joblib==1.3.2 + - keras==2.8.0 + - keras-preprocessing==1.1.2 + - kiwisolver==1.4.5 + - lazy-loader==0.3 + - libclang==16.0.6 + - librosa==0.10.1 + - llvmlite==0.41.1 + - markdown==3.5.2 + - markupsafe==2.1.3 + - matplotlib==3.8.2 + - mccabe==0.7.0 + - memory-profiler==0.55.0 + - msgpack==1.0.7 + - mypy-extensions==1.0.0 + - networkx==3.2.1 + - numba==0.58.1 + - oauthlib==3.2.2 + - onnx==1.15.0 + - onnxoptimizer==0.3.13 + - opencv-python==4.9.0.80 + - opt-einsum==3.3.0 + - packaging==23.2 + - pandas==2.1.4 + - pathspec==0.12.1 + - platformdirs==4.1.0 + - pooch==1.8.0 + - protobuf==3.20.2 + - psutil==5.9.7 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pycodestyle==2.11.1 + - pydot-ng==2.0.0 + - pyflakes==3.1.0 + - pyparsing==3.1.1 + - python-dateutil==2.8.2 + - python-magic==0.4.27 + - pytz==2023.3.post1 + - pyyaml==6.0.1 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scikit-image==0.22.0 + - scikit-learn==1.3.2 + - scipy==1.11.4 + - six==1.16.0 + - soundfile==0.12.1 + - soxr==0.3.7 + - tensorboard==2.8.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - tensorflow-gpu==2.8.0 + - tensorflow-io-gcs-filesystem==0.35.0 + - termcolor==2.4.0 + - tf-estimator-nightly==2.8.0.dev2021122109 + - threadpoolctl==3.2.0 + - tifffile==2023.12.9 + - tomli==2.0.1 + - tqdm==4.66.1 + - tzdata==2023.4 + - validators==0.22.0 + - werkzeug==3.0.1 + - wrapt==1.16.0 + - zipp==3.17.0 +prefix: ~/anaconda3/envs/lop2_overf diff --git a/training/training_scripts/NN_Adaptive_Filtering/launch_pipeline.py b/training/training_scripts/NN_Adaptive_Filtering/launch_pipeline.py new file mode 100644 index 0000000000..4818f71152 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/launch_pipeline.py @@ -0,0 +1,164 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. 
+* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import argparse +import os +import sys +from pathlib import Path + +from util.file_system import create_directory, read_json_file + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Launching whole pipeline of overfittings, including training, nnr compression and conversion to SADL" + ) + parser.add_argument( + "-vs", + "--video_sequences", + nargs="+", + type=str, + required=True, + help="Sequence name/tag", + ) + parser.add_argument("--qps", nargs="+", type=str, required=True, help="Sequence QP") + return parser.parse_args() + + +def launch_pipeline(): + data_cfg = read_json_file(data_cfg_file) + for seq_name in args.video_sequences: + num_frames = data_cfg[seq_name]["frames"] + intra_period = 64 if data_cfg[seq_name]["fps"] > 32 else 32 + num_part = int((num_frames - 1) // intra_period) + 1 + + for qp in args.qps: + num_param = min( + num_layer_config[seq_name][qp], config["training"]["num_layers"] + ) + + for part in range(num_part): + min_poc = part * intra_period + max_poc = min(min_poc + intra_period + 1, num_frames) + + print( + f"\x1b[6;30;42m Overfitting for video: {seq_name}, qp: {qp}, part: {part} \x1b[0m" + ) + + curr_overfitting_dir = overfittings_dir / f"{seq_name}_{qp}_part{part}" + + command = ( + f"python {script_to_execute} --max_epochs={epochs} " + f"--deco_dir_train={dataset_deco_path} --orig_dir_train={dataset_orig_path} --prop_file_train={data_cfg_file} " + f"--seq_qp_train={qp} --seq_name={seq_name} --slice_type=* --base_nn_params={base_model_path} " + f"--output_dir={curr_overfitting_dir} --arch={config['architecture']} " + f"--stop_patience {stop_patience} --lr_patience {lr_patience} --lr={lr} " + f"--min_poc {min_poc} --max_poc {max_poc} --part_idx {part} --select_layer --num_param {num_param} " + f"--nnr_models_dir {nnr_models_dir} --onnx_models_dir {onnx_models_dir} --sadl_float_dir {sadl_float_dir} --sadl_int_dir {sadl_int_dir} --quantizers_dir {quantizers_dir} " + f"--base_quantizers {base_quantizers} --map_orig_params_to_new {map_orig_params_to_new} " + ) + + if start_epoch: + nn_params = ( + curr_overfitting_dir + / f"models/checkpoints/model_{start_epoch:03d}.pt" + ) + optimiser_params = ( + curr_overfitting_dir + / f"models/checkpoints/model_{start_epoch:03d}_optimiser.pt" + ) + lr_scheduler_params = ( + curr_overfitting_dir + / f"models/checkpoints/model_{start_epoch:03d}_lr_scheduler.pt" + ) + command = ( + command + + f"--start_epoch {start_epoch} --nn_params {nn_params} --optimiser_params {optimiser_params} --lr_scheduler_params {lr_scheduler_params}" + ) + + os.system(command) + + +if __name__ == "__main__": + args = parse_arguments() + + config = read_json_file("resources/config.json") + + if "lop2" != config["architecture"]: + 
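[Editor's note] To make the random-access segmentation in launch_pipeline concrete, here is the same arithmetic evaluated for one sequence from resources/datasets/jvet.json (A1_Tango: 294 frames at 60 fps); a small sketch, not part of the pipeline:

    num_frames, fps = 294, 60                              # A1_Tango
    intra_period = 64 if fps > 32 else 32                  # -> 64
    num_part = int((num_frames - 1) // intra_period) + 1   # -> 5 segments
    for part in range(num_part):
        min_poc = part * intra_period
        max_poc = min(min_poc + intra_period + 1, num_frames)
        print(part, min_poc, max_poc)  # (0, 0, 65), (1, 64, 129), ..., (4, 256, 294)

The `+ 1` in max_poc means consecutive segments share the intra frame at the segment boundary.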
print(f"{config['architecture']} architecture not supported") + sys.exit(-1) + + model_config = "models/models.json" + + dataset_orig_path = config["dataset"]["orig_path"] + dataset_deco_path = config["dataset"]["deco_path"] + data_cfg_file = Path("resources/datasets/jvet.json") + output_dir = Path(config["training"]["output_path"]) + + # path to overfitted models + overfittings_dir = output_dir / "overfittings" + + # path to nnr models + nnr_models_dir = output_dir / "nnr_models" + + # path to SADL models + sadl_float_dir = output_dir / "sadl_float_models" + sadl_int_dir = output_dir / "sadl_int_models" + quantizers_dir = output_dir / "quantizers" + onnx_models_dir = output_dir / "onnx_models" + + create_directory(overfittings_dir) + create_directory(nnr_models_dir) + create_directory(nnr_models_dir / "nnr") + create_directory(sadl_float_dir) + create_directory(sadl_int_dir) + create_directory(quantizers_dir) + create_directory(onnx_models_dir) + + base_model_path = config["sadl2torch"]["base_model_torch"] + base_quantizers = config["sadl2torch"]["base_model_quantizers"] + map_orig_params_to_new = config["sadl2torch"]["base_model_params_to_new"] + + num_layer_config = read_json_file("resources/num_layers.json") + + script_to_execute = "./overfitting_pipeline.py" + epochs = 1 + lr = 1e-3 + stop_patience = 50 + lr_patience = 100 + weight_decay = 0 + total_terms = 100 + start_epoch = 0 + + launch_pipeline() diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/models/__init__.py new file mode 100644 index 0000000000..91510001de --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/models/__init__.py @@ -0,0 +1,61 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Union + +from util.file_system import read_json_file + + +def instantiate_from_dict( + namespace, config: Dict[str, Any], *args: Any, **kwargs: Dict[str, Any] +) -> Any: + """Instantiate instance from config dict. Value of 'class' in dict should be a class available in namespace. + Args: + namespace: load class definition from this namespace + config: dict containing kwargs for instantiation + *args: args for instantiation + **kwargs: additional kwargs for instantiation + """ + return getattr(namespace, config.pop("class"))(*args, **config, **kwargs) + + +def load_torch_model(model_config: Union[Path, Dict], namespace): + config = ( + model_config if isinstance(model_config, Dict) else read_json_file(model_config) + ) + copy_config = deepcopy(config) + model = instantiate_from_dict(namespace, copy_config["model"]) + return model diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2.py b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2.py new file mode 100644 index 0000000000..47e781fba8 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2.py @@ -0,0 +1,360 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +from typing import Dict, Iterable, List, Optional, Tuple, Type, Union + +import torch +from torch import nn +from torch.nn import functional as F + + +class Conv(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = None, + is_separable: bool = False, + hidden_separable_channels: Optional[int] = None, + post_activation: Optional[Type] = nn.PReLU, + **kwargs, + ): + """ + Args: + in_channels: the number of input channels + out_channels: the number of output channels + kernel_size: the convolution's kernel size + stride: the convolution's stride(s) + padding: the convolution's padding + is_separable: whether to implement convolution separably + hidden_separable_channels: If is_separable, the number of hidden channels between convolutions. If None, use out_channels + post_activation: activation function to use after convolution. If None, no activation after convolution + **kwargs: additional kwargs to pass to nn.Conv2d + """ + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + ) + self.stride = (stride, stride) if isinstance(stride, int) else stride + if padding is not None: + self.padding = (padding, padding) if isinstance(padding, int) else padding + else: + self.padding = tuple([k // 2 for k in self.kernel_size]) + self.is_separable = is_separable + self.post_activation = post_activation + + if self.is_separable: + self.hidden_separable_channels = hidden_separable_channels or out_channels + modules = [ + nn.Conv2d( + self.in_channels, + self.hidden_separable_channels, + (self.kernel_size[0], 1), + (self.stride[0], 1), + (self.padding[0], 0), + groups=self.hidden_separable_channels, + **kwargs, + ), + nn.Conv2d( + self.hidden_separable_channels, + self.out_channels, + (1, self.kernel_size[1]), + (1, self.stride[1]), + (0, self.padding[1]), + groups=self.hidden_separable_channels, + **kwargs, + ), + nn.Conv2d( + self.out_channels, + self.out_channels, + (1, 1), + (1, 1), + (0, 0), + **kwargs, + ), + ] + else: + modules = [ + nn.Conv2d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + **kwargs, + ) + ] + + if self.post_activation is not None: + modules.append(self.post_activation()) + + super(Conv, self).__init__(*modules) + + +class MultiBranchModule(nn.Module): + """A module representing multple, parallel branches. If the input is a list, each element in the list is fed into the corresponding branch, + otherwise the input is fed into every branch. The outputs of each branch are then merged. 
+ """ + + def __init__(self, *branch_modules, merge_dimension: int = -3): + """ + Args: + branch_modules: modules to run in parallel + merge_dimension: the dimension to merge outputs from each branch + """ + super().__init__() + self.branches = nn.ModuleList(branch_modules) + self.merge_dimension = merge_dimension + + def forward(self, args: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor: + inputs = args if isinstance(args, list) else len(self.branches) * [args] + branch_outputs = [branch(input) for branch, input in zip(self.branches, inputs)] + return torch.cat(branch_outputs, dim=self.merge_dimension) + + +class ResidualBlock(nn.Sequential): + def __init__(self, C: int = 64, C1: int = 160, C21: int = 32): + super(ResidualBlock, self).__init__( + Conv(C, C1, kernel_size=1), + Conv(C1, C, kernel_size=1, post_activation=None), + Conv( + C, + C, + kernel_size=3, + post_activation=None, + is_separable=True, + hidden_separable_channels=C21, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + super(ResidualBlock, self).forward(x) + + +class SplitLumaChromaBlocks(nn.Sequential): + def __init__( + self, + N_Y: int = 12, + N_UV: int = 6, + C: int = 16, + C1_Y: int = 64, + C1_UV: int = 48, + C21: int = 16, + output_channels_y: int = 4, + output_channels_uv: int = 2, + ): + super().__init__() + + self.split_y_path = nn.Sequential( + *[ResidualBlock(C, C1_Y, C21) for _ in range(N_Y)], + Conv(C, C, kernel_size=3, is_separable=True, hidden_separable_channels=C21), + Conv(C, output_channels_y, kernel_size=3, post_activation=None), + ) + + self.split_uv_path = nn.Sequential( + *[ResidualBlock(C, C1_UV, C21) for _ in range(N_UV)], + Conv(C, C, kernel_size=3, is_separable=True, hidden_separable_channels=C21), + Conv(C, output_channels_uv, kernel_size=3, post_activation=None), + ) + + self.Cy = C + self.Cuv = C + + def forward(self, x: torch.Tensor) -> torch.Tensor: + split_y_input = x[:, : self.Cy, :, :] + split_uv_input = x[:, self.Cy : self.Cy + self.Cuv, :, :] + + y_output = self.split_y_path.forward(split_y_input) + uv_output = self.split_uv_path.forward(split_uv_input) + return torch.cat((y_output, uv_output), dim=1) + + +class SADLNet(nn.Sequential): + """The network used during SADL inference""" + + def __init__( + self, + input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1], + input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3], + D1: int = 12, + D2: int = 8, + D3: int = 4, + D4: int = 2, + D5: int = 2, + D6: int = 24, + N_Y: int = 12, + N_UV: int = 6, + C: int = 16, + C1_Y: int = 64, + C1_UV: int = 48, + C21: int = 16, + output_channels_y: int = 4, + output_channels_uv: int = 2, + ): + """ + Args: + input_channels: the number of channels expected for each input + input_kernels: the kernel size for each input convolution + output_channels: the number of output channels + """ + self.input_channels = input_channels + self.input_kernels = input_kernels + self.input_features = [D1, D2, D3, D4, D4, D5] + + super(SADLNet, self).__init__( + MultiBranchModule( + *[ + Conv(c, d, kernel_size=k, post_activation=None) + for c, d, k in zip( + self.input_channels, self.input_features, self.input_kernels + ) + ] + ), + Conv(sum(self.input_features), D6, kernel_size=1), + Conv(D6, C + C, kernel_size=3, stride=2), + SplitLumaChromaBlocks( + N_Y, N_UV, C, C1_Y, C1_UV, C21, output_channels_y, output_channels_uv + ), + ) + + def get_example_inputs( + self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1 + ): + patch_size = ( + (patch_size, patch_size) if 
isinstance(patch_size, int) else patch_size + ) + return [ + torch.rand( + batch_size, conv.in_channels, *patch_size, device=conv[0].weight.device + ) + for conv in self[0].branches + ] + + def to_onnx( + self, filename: str, patch_size: int = 144, batch_size: int = 1, **kwargs + ) -> None: + mode = self.training + self.eval() + torch.onnx.export( + self, self.get_example_inputs(patch_size, batch_size), filename, **kwargs + ) + self.train(mode) + + +class Net(nn.Module): + """Wrapper for SADL model that implements input pre- and post-processing for training.""" + + def __init__( + self, + input_channels: Iterable[Iterable[str]] = [ + ["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V"], + ["pred_Y", "pred_U", "pred_V"], + ["bs_Y", "bs_U", "bs_V"], + ["qp_base"], + ["qp_slice"], + ["ipb_Y"], + ], + input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3], + D1: int = 12, + D2: int = 8, + D3: int = 4, + D4: int = 2, + D5: int = 2, + D6: int = 24, + N_Y: int = 12, + N_UV: int = 6, + C: int = 16, + C1_Y: int = 64, + C1_UV: int = 48, + C21: int = 16, + ): + super(Net, self).__init__() + assert len(input_channels) == len( + input_kernels + ), "[ERROR] input size and kernels size not equal" + self.input_channels = input_channels + sizes = [len(a) for a in input_channels] + self.SADL_model = SADLNet( + sizes, + input_kernels, + D1, + D2, + D3, + D4, + D5, + D6, + N_Y, + N_UV, + C, + C1_Y, + C1_UV, + C21, + 4, + 2, + ) + self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest") + + def preprocess_args( + self, batch: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + for name, data in batch.items(): + if "U" in name or "V" in name: + batch[name] = self.chroma_upsampler(batch[name]) + + return [ + torch.cat([batch[name] for name in input_], dim=1) + for input_ in self.input_channels + ] + + def postprocess_outputs( + self, batch: Dict[str, torch.Tensor], out: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + Y_res, UV_res = out.split([4, 2], dim=1) + return ( + batch["rec_before_dbf_Y"] + F.pixel_shuffle(Y_res, 2), + torch.cat((batch["rec_before_dbf_U"], batch["rec_before_dbf_V"]), dim=1)[ + ..., ::2, ::2 + ] + + UV_res, + ) + + def forward( + self, batch: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + args = self.preprocess_args(batch) + out = self.SADL_model(args) + return self.postprocess_outputs(batch, out) diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_nnr.py b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_nnr.py new file mode 100644 index 0000000000..aafefb2f50 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_nnr.py @@ -0,0 +1,279 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. 
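[Editor's note] The post-processing in `Net.postprocess_outputs` relies on two shape tricks: the four luma output channels are sub-pixel planes that `F.pixel_shuffle(Y_res, 2)` rearranges into one full-resolution plane, and chroma is filtered at half resolution, so the reconstruction that was upsampled in `preprocess_args` is brought back to the 4:2:0 grid with `[..., ::2, ::2]`. A shape sketch with illustrative sizes:

    import torch
    from torch.nn import functional as F

    Y_res = torch.rand(1, 4, 72, 72)      # 4 sub-pixel luma planes at half size
    assert F.pixel_shuffle(Y_res, 2).shape == (1, 1, 144, 144)
    rec_uv = torch.rand(1, 2, 144, 144)   # chroma upsampled in preprocess_args
    assert rec_uv[..., ::2, ::2].shape == (1, 2, 72, 72)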
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import copy +import logging +from collections import OrderedDict +from pathlib import Path +from typing import Dict, Tuple, Union + +import nnc_core +import torch + +import config +from models import load_torch_model, model_lop2_with_multiplier +from util import Colour +from util.metrics import compute_loss_psnr, mse_loss + +LOGGER = logging.getLogger() + + +def shorten_multiplier_name(in_var: str) -> str: + out_var = in_var.replace("SADL_model.", "") + out_var = out_var.replace(".branches.", "") + out_var = out_var.replace("split.", "") + out_var = out_var.replace("split_", "") + out_var = out_var.replace("_path", "") + out_var = out_var.replace("conv", "") + out_var = out_var.replace("nonsep_mul", "") + out_var = out_var.replace("sep_mul", "") + out_var = out_var.replace("..", ".") + out_var = out_var.replace(".multiplier", "") + return out_var + + +class PytorchModel(nnc_core.nnr_model.NNRModel): + def __init__( + self, model_config: Union[Path, Dict], block_size: int, border_size: int + ) -> None: + self.model = load_torch_model(model_config, model_lop2_with_multiplier) + self.__model_info = None + self.__model_parameters = None + self.dataset = None + + self.y_block_size = block_size + self.uv_block_size = block_size // 2 + self.y_border_size = border_size + self.uv_border_size = border_size // 2 + + self.device = None + self.dataloader = None + + def load_model( + self, + model_path, + is_ovef_model: bool, + ): + # Defining device and device's map location + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model_file = torch.load(model_path, map_location=self.device) + + if is_ovef_model: + self.model.load_state_dict(model_file, strict=True) + else: + tmp_params = OrderedDict() + dst_params = OrderedDict() + + for name, value in self.model.state_dict().items(): + if "multiplier" not in name: + tmp_params[name] = value + + for src_param, dst_param in zip(model_file.keys(), tmp_params.keys()): + dst_params[dst_param] = model_file[src_param] + + self.model.load_state_dict(dst_params, strict=False) + self.disable_all_grads() + org_model, self.__model_info = self.pytorch_to_nctm(self.model.state_dict()) + return org_model["parameters"] + + def set_dataset(self, dataloader): + self.dataloader = dataloader + + @property + def model_info(self): + return self.__model_info + + def disable_all_grads(self): + for param in self.model.parameters(): + param.requires_grad = False + + def eval_model( + self, parameters, bn_folding=False, verbose=False + ) -> Tuple[float, float, float]: + # 
torch.set_num_threads(1) + if len(parameters) == 0: + return -100.0, -100.0, 100.0 + + Model = copy.deepcopy(self.model) + base_model_arch = Model.state_dict() + model_dict = OrderedDict() + for module_name in base_model_arch: + var_name = module_name + if ".multiplier" in var_name: + var_name = shorten_multiplier_name(var_name) + + if module_name in parameters or var_name in parameters: + model_dict[module_name] = ( + parameters[var_name] + if torch.is_tensor(parameters[var_name]) + else torch.tensor(parameters[var_name]) + ) + elif ( + config.PUT_SYNTAX() + ): # loading unchanged weights from the feature extractor + model_dict[module_name] = base_model_arch[module_name] + + Model.load_state_dict(model_dict) + Model = Model.to(self.device) + + avg_loss = 0.0 + avg_psnr = 0.0 + avg_psnr_gain = 0.0 + num_elements = 0 + + use_multiplier = torch.ones(1).to(self.device) + cond_target_value = torch.ones(1).to(self.device) + + for input_data, label_data in self.dataloader: + for k, v in input_data.items(): + input_data[k] = v.to(self.device) + for k, v in label_data.items(): + label_data[k] = v.to(self.device) + + reco_Y = input_data["rec_before_dbf_Y"][ + :, + :, + self.y_border_size : self.y_border_size + self.y_block_size, + self.y_border_size : self.y_border_size + self.y_block_size, + ] + reco_U = input_data["rec_before_dbf_U"][ + :, + :, + self.uv_border_size : self.uv_border_size + self.uv_block_size, + self.uv_border_size : self.uv_border_size + self.uv_block_size, + ] + reco_V = input_data["rec_before_dbf_V"][ + :, + :, + self.uv_border_size : self.uv_border_size + self.uv_block_size, + self.uv_border_size : self.uv_border_size + self.uv_block_size, + ] + + prediction = Model(input_data, use_multiplier, cond_target_value) + prediction_Y = prediction[0][ + :, + :, + self.y_border_size : self.y_border_size + self.y_block_size, + self.y_border_size : self.y_border_size + self.y_block_size, + ] + prediction_U, prediction_V = prediction[1][ + :, + :, + self.uv_border_size : self.uv_border_size + self.uv_block_size, + self.uv_border_size : self.uv_border_size + self.uv_block_size, + ].split([1, 1], dim=1) + + _, vtm_psnr, vtm_mask = compute_loss_psnr( + label_data, reco_Y, reco_U, reco_V, mse_loss + ) + pred_loss, pred_psnr, pred_mask = compute_loss_psnr( + label_data, prediction_Y, prediction_U, prediction_V, mse_loss + ) + + mask = vtm_mask * pred_mask + vtm_psnr *= mask + pred_psnr *= mask + delta_psnr_wrt_vtm = pred_psnr - vtm_psnr + + num_elements += torch.sum(mask[Colour.YCbCr]) + avg_loss += torch.sum(pred_loss[Colour.YCbCr]) + avg_psnr += torch.sum(pred_psnr[Colour.YCbCr]) + avg_psnr_gain += torch.sum(delta_psnr_wrt_vtm[Colour.YCbCr]) + + avg_loss /= len(self.dataloader) * self.dataloader.batch_size + avg_psnr /= num_elements + avg_psnr_gain /= num_elements + + del Model + + return avg_psnr_gain, avg_psnr, avg_loss + + def pytorch_to_nctm(self, model_file): + model_dict = model_file + model_data = {"parameters": {}, "reduction_method": "uniform"} + model_info = { + "parameter_type": {}, + "parameter_dimensions": {}, + "parameter_index": {}, + "block_identifier": {}, + "topology_storage_format": nnc_core.nnr_model.TopologyStorageFormat.NNR_TPL_TEF, + "topology_compression_format": nnc_core.nnr_model.TopologyCompressionFormat.NNR_PT_RAW, + } + + for i, module_name in enumerate(model_dict): + var_data = model_dict[module_name].data + var_name = module_name + # shorten parameter names for multipliers + if ".multiplier" in module_name: + var_name = shorten_multiplier_name(var_name) + + 
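[Editor's note] The repeated slicing in `eval_model` crops the border that each patch carries before the loss and PSNR are computed, so the metrics cover only the interior block. With the defaults used elsewhere in this patch (block_size=128 and border_size=8 for luma, half of each for chroma), the crop looks like this (illustrative sketch):

    import torch

    patch = torch.rand(1, 1, 144, 144)         # 128 block + 8-sample border per side
    inner = patch[:, :, 8:8 + 128, 8:8 + 128]  # region the PSNR/loss is computed on
    assert inner.shape == (1, 1, 128, 128)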
model_data["parameters"][var_name] = var_data.cpu().detach().numpy() + model_info["parameter_dimensions"][var_name] = model_data["parameters"][ + var_name + ].shape + model_info["parameter_index"][var_name] = i + + if ".weight" in module_name: + model_info["parameter_type"][var_name] = "conv.weight" + else: + model_info["parameter_type"][var_name] = "conv.bias" + + return model_data, model_info + + def save_state_dict(self, path, model_data): + pass + + def train_model( + self, + parameters, + ): + pass + + def restore_and_save(self, parameters: Dict, output_path: str) -> None: + model_dict = OrderedDict() + Model = copy.deepcopy(self.model) + + for module_name in Model.state_dict(): + var_name = module_name + # shorten parameter names for multipliers + if ".multiplier" in module_name: + var_name = shorten_multiplier_name(var_name) + + var_data = torch.tensor(parameters[var_name]) + + model_dict[module_name] = var_data + + Model.load_state_dict(model_dict) + torch.save(Model.state_dict(), output_path) + del Model diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_with_multiplier.py b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_with_multiplier.py new file mode 100644 index 0000000000..0c6c5414d3 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_with_multiplier.py @@ -0,0 +1,468 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +from typing import Dict, Iterable, List, Optional, Tuple, Type, Union + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + + +class Multiplier(nn.Module): + """ + Multiplier layer. 
It is initialised with ones + """ + + def __init__(self, units=1): + super(Multiplier, self).__init__() + self.units = units + self._multiplier = torch.ones(size=(self.units, 1, 1), dtype=torch.float32) + self.multiplier = torch.nn.Parameter(self._multiplier, requires_grad=True) + + def forward(self, x): + return x * self.multiplier + + +class Conv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = None, + is_separable: bool = False, + hidden_separable_channels: Optional[int] = None, + post_activation: Optional[Type] = nn.PReLU, + mul: bool = False, + **kwargs, + ): + """ + Args: + in_channels: the number of input channels + out_channels: the number of output channels + kernel_size: the convolution's kernel size + stride: the convolution's stride(s) + padding: the convolution's padding + is_separable: whether to implement convolution separably + hidden_separable_channels: If is_separable, the number of hidden channels between convolutions. If None, use out_channels + post_activation: activation function to use after convolution. If None, no activation after convolution + **kwargs: additional kwargs to pass to nn.Conv2d + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + ) + self.stride = (stride, stride) if isinstance(stride, int) else stride + if padding is not None: + self.padding = (padding, padding) if isinstance(padding, int) else padding + else: + self.padding = tuple([k // 2 for k in self.kernel_size]) + self.is_separable = is_separable + self.mul = mul + + if self.is_separable: + self.hidden_separable_channels = hidden_separable_channels or out_channels + self.sep_conv1 = nn.Conv2d( + self.in_channels, + self.hidden_separable_channels, + (self.kernel_size[0], 1), + (self.stride[0], 1), + (self.padding[0], 0), + groups=self.hidden_separable_channels, + **kwargs, + ) + if self.mul: + self.sep_mul1 = Multiplier(self.hidden_separable_channels) + self.sep_conv2 = nn.Conv2d( + self.hidden_separable_channels, + self.out_channels, + (1, self.kernel_size[1]), + (1, self.stride[1]), + (0, self.padding[1]), + groups=self.hidden_separable_channels, + **kwargs, + ) + if self.mul: + self.sep_mul2 = Multiplier(self.out_channels) + self.sep_conv3 = nn.Conv2d( + self.out_channels, + self.out_channels, + (1, 1), + (1, 1), + (0, 0), + **kwargs, + ) + if self.mul: + self.sep_mul3 = Multiplier(self.out_channels) + else: + self.nonsep_conv = nn.Conv2d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + **kwargs, + ) + if self.mul: + self.nonsep_mul = Multiplier(self.out_channels) + + self.post_activation = None if post_activation is None else post_activation() + + def forward( + self, + x: torch.Tensor, + activate_multiplier: torch.Tensor = None, + ) -> torch.Tensor: + if self.is_separable: + y = self.sep_conv1(x) + if self.mul and activate_multiplier is not None: + m = self.sep_mul1(y) + y = torch.where(activate_multiplier, m, y) + + y = self.sep_conv2(y) + if self.mul and activate_multiplier is not None: + m = self.sep_mul2(y) + y = torch.where(activate_multiplier, m, y) + + y = self.sep_conv3(y) + if self.mul and activate_multiplier is not None: + m = self.sep_mul3(y) + y = torch.where(activate_multiplier, m, y) + else: + y = self.nonsep_conv(x) + if 
self.mul and activate_multiplier is not None:
+                m = self.nonsep_mul(y)
+                y = torch.where(activate_multiplier, m, y)
+        if self.post_activation is not None:
+            y = self.post_activation(y)
+
+        return y
+
+
+class MultiBranchModule(nn.Module):
+    """A module representing multiple, parallel branches. If the input is a list, each element in the list is fed into the corresponding branch,
+    otherwise the input is fed into every branch. The outputs of each branch are then merged.
+    """
+
+    def __init__(self, *branch_modules, merge_dimension: int = -3):
+        """
+        Args:
+            branch_modules: modules to run in parallel
+            merge_dimension: the dimension to merge outputs from each branch
+        """
+        super().__init__()
+        self.branches = nn.ModuleList(branch_modules)
+        self.merge_dimension = merge_dimension
+
+    def forward(self, args: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
+        inputs = args if isinstance(args, list) else len(self.branches) * [args]
+        branch_outputs = [branch(input) for branch, input in zip(self.branches, inputs)]
+        return torch.cat(branch_outputs, dim=self.merge_dimension)
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, C: int = 64, C1: int = 160, C21: int = 32, mul: bool = False):
+        super().__init__()
+        self.conv1 = Conv(C, C1, kernel_size=1, mul=mul)
+        self.conv2 = Conv(C1, C, kernel_size=1, post_activation=None, mul=mul)
+        self.conv3 = Conv(
+            C,
+            C,
+            kernel_size=3,
+            post_activation=None,
+            is_separable=True,
+            hidden_separable_channels=C21,
+            mul=mul,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        activate_multiplier: torch.Tensor,
+    ) -> torch.Tensor:
+        y = self.conv1(x, activate_multiplier)
+        y = self.conv2(y, activate_multiplier)
+        y = self.conv3(y, activate_multiplier)
+        return x + y
+
+
+class SplitLumaChromaBlocks(nn.Module):
+    def __init__(
+        self,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
+        output_channels_y: int = 4,
+        output_channels_uv: int = 2,
+    ):
+        super().__init__()
+
+        self.split_y_path = nn.ModuleList(
+            m
+            for m in [
+                *[ResidualBlock(C, C1_Y, C21, True) for _ in range(N_Y)],
+                Conv(
+                    C,
+                    C,
+                    kernel_size=3,
+                    is_separable=True,
+                    hidden_separable_channels=C21,
+                    mul=True,
+                ),
+                Conv(
+                    C, output_channels_y, kernel_size=3, post_activation=None, mul=True
+                ),
+            ]
+        )
+
+        self.split_uv_path = nn.ModuleList(
+            m
+            for m in [
+                *[ResidualBlock(C, C1_UV, C21, True) for _ in range(N_UV)],
+                Conv(
+                    C,
+                    C,
+                    kernel_size=3,
+                    is_separable=True,
+                    hidden_separable_channels=C21,
+                    mul=True,
+                ),
+                Conv(
+                    C, output_channels_uv, kernel_size=3, post_activation=None, mul=True
+                ),
+            ]
+        )
+
+        self.Cy = C
+        self.Cuv = C
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        activate_multiplier: torch.Tensor = None,
+    ) -> torch.Tensor:
+        split_y_input = x[:, : self.Cy, :, :]
+        split_uv_input = x[:, self.Cy : self.Cy + self.Cuv, :, :]
+
+        y_output = split_y_input
+        for m in self.split_y_path:
+            y_output = m(y_output, activate_multiplier)
+
+        uv_output = split_uv_input
+        for m in self.split_uv_path:
+            uv_output = m(uv_output, activate_multiplier)
+
+        return torch.cat((y_output, uv_output), dim=1)
+
+
+class SADLNet(nn.Module):
+    """The network used during SADL inference"""
+
+    def __init__(
+        self,
+        input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1],
+        input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
+        D1: int = 12,
+        D2: int = 8,
+        D3: int = 4,
+        D4: int = 2,
+        D5: int = 2,
+        D6: int = 24,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
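[Editor's note] The `mul`/`activate_multiplier` plumbing above implements a data-driven switch: each `Multiplier` holds a per-channel scale initialised to one, and `torch.where` picks between the scaled and unscaled activations, so the same weights serve both the overfitted and the pass-through behaviour. A minimal sketch of the gating:

    import torch
    from torch import nn

    scale = nn.Parameter(torch.ones(16, 1, 1))  # as in Multiplier(units=16)
    y = torch.rand(1, 16, 8, 8)
    active = torch.tensor([True])               # shape (1,), like torch.eq(a, b) below
    out = torch.where(active, y * scale, y)     # equals y when active is False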
output_channels_y: int = 4, + output_channels_uv: int = 2, + ): + """ + Args: + input_channels: the number of channels expected for each input + input_kernels: the kernel size for each input convolution + output_channels: the number of output channels + """ + super().__init__() + self.input_channels = input_channels + self.input_kernels = input_kernels + self.input_features = [D1, D2, D3, D4, D4, D5] + + self.multi_branch = MultiBranchModule( + *[ + Conv(c, d, kernel_size=k, post_activation=None) + for c, d, k in zip( + self.input_channels, self.input_features, self.input_kernels + ) + ] + ) + self.conv1 = Conv(sum(self.input_features), D6, kernel_size=1) + self.conv2 = Conv(D6, C + C, kernel_size=3, stride=2) + self.split = SplitLumaChromaBlocks( + N_Y, N_UV, C, C1_Y, C1_UV, C21, output_channels_y, output_channels_uv + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + y = self.multi_branch(inputs[:-2]) + y = self.conv1(y) + y = self.conv2(y) + activate_multipliers = torch.eq(inputs[-2], inputs[-1]) + y = self.split(y, activate_multipliers) + return y + + def get_example_inputs( + self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1 + ) -> List[torch.Tensor]: + patch_size = ( + (patch_size, patch_size) if isinstance(patch_size, int) else patch_size + ) + x = [ + torch.rand( + batch_size, + conv.in_channels, + *patch_size, + device=conv.nonsep_conv.weight.device, + ) + for conv in self.multi_branch.branches + ] + x.append(torch.ones(1)) + x.append(torch.ones(1)) + return x + + def to_onnx( + self, filename: str, patch_size: int = 144, batch_size: int = 1, **kwargs + ) -> None: + mode = self.training + self.eval() + torch.onnx.export( + self, + self.get_example_inputs(patch_size, batch_size), + filename, + **kwargs, + ) + self.train(mode) + + +class Net(nn.Module): + """Wrapper for SADL model that implements input pre- and post-processing for training.""" + + def __init__( + self, + input_channels: Iterable[Iterable[str]] = [ + ["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V"], + ["pred_Y", "pred_U", "pred_V"], + ["bs_Y", "bs_U", "bs_V"], + ["qp_base"], + ["qp_slice"], + ["ipb_Y"], + ], + input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3], + D1: int = 12, + D2: int = 8, + D3: int = 4, + D4: int = 2, + D5: int = 2, + D6: int = 24, + N_Y: int = 12, + N_UV: int = 6, + C: int = 16, + C1_Y: int = 64, + C1_UV: int = 48, + C21: int = 16, + ): + super(Net, self).__init__() + assert len(input_channels) == len( + input_kernels + ), "[ERROR] input size and kernels size not equal" + self.input_channels = input_channels + sizes = [len(a) for a in input_channels] + self.SADL_model = SADLNet( + sizes, + input_kernels, + D1, + D2, + D3, + D4, + D5, + D6, + N_Y, + N_UV, + C, + C1_Y, + C1_UV, + C21, + 4, + 2, + ) + self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest") + + def preprocess_args(self, batch: Dict[str, torch.Tensor]) -> List[torch.Tensor]: + for name, data in batch.items(): + if "U" in name or "V" in name: + batch[name] = self.chroma_upsampler(batch[name]) + + return [ + torch.cat([batch[name] for name in input_], dim=1) + for input_ in self.input_channels + ] + + def postprocess_outputs( + self, batch: Dict[str, torch.Tensor], out: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + Y_res, UV_res = out.split([4, 2], dim=1) + return ( + batch["rec_before_dbf_Y"] + F.pixel_shuffle(Y_res, 2), + torch.cat((batch["rec_before_dbf_U"], batch["rec_before_dbf_V"]), dim=1)[ + ..., ::2, ::2 + ] + + UV_res, + ) + + def forward( + self, + batch: 
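[Editor's note] Because the switch is computed from two extra scalar inputs (`torch.eq(inputs[-2], inputs[-1])` in `SADLNet.forward`) rather than Python control flow, it survives ONNX/SADL export as ordinary graph data. A usage sketch of both modes, reusing the module's own example inputs:

    import torch

    model = SADLNet()
    inputs = model.get_example_inputs(patch_size=144)  # ends with two torch.ones(1)
    out_on = model(inputs)                             # equal inputs -> multipliers active
    inputs[-1] = torch.zeros(1)
    out_off = model(inputs)                            # unequal -> multipliers bypassed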
Dict[str, torch.Tensor], + equal_left_side: Tensor, + equal_right_side: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + args = self.preprocess_args(batch) + args.append(equal_left_side) + args.append(equal_right_side) + out = self.SADL_model(args) + return self.postprocess_outputs(batch, out) diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/models.json b/training/training_scripts/NN_Adaptive_Filtering/models/models.json new file mode 100755 index 0000000000..8afb5299ce --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/models/models.json @@ -0,0 +1,51 @@ +{ + "lop2": + { + "model": + { + "class" : "Net", + "input_channels" : + [ + [ + "rec_before_dbf_Y", + "rec_before_dbf_U", + "rec_before_dbf_V" + ], + [ + "pred_Y", + "pred_U", + "pred_V" + ], + [ + "bs_Y", + "bs_U", + "bs_V" + ], + [ "qp_base" ], + [ "qp_slice" ], + [ "ipb_Y" ] + ], + "input_kernels" : + [ + 3, + 3, + 1, + 1, + 1, + 1 + ], + "D1" : 12, + "D2" : 8, + "D3" : 4, + "D4" : 2, + "D5" : 2, + "D6" : 24, + "N_Y" : 14, + "N_UV" : 4, + "C" : 16, + "C1_Y" : 64, + "C1_UV" : 32, + "C21" : 16 + } + } +} diff --git a/training/training_scripts/NN_Adaptive_Filtering/overfitting_pipeline.py b/training/training_scripts/NN_Adaptive_Filtering/overfitting_pipeline.py new file mode 100644 index 0000000000..293dd18380 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/overfitting_pipeline.py @@ -0,0 +1,396 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import sys +from tempfile import NamedTemporaryFile + +import click +import numpy as np +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from trainer.nn_filter import NnFilter +from util.dataset_bin import NnFilterBinDataset +from util.dataset_yuv import NnFilterYuvDataset +from util.file_system import read_json_file +from wu_encoding import compress_one_model_and_2sadl + + +@click.command() +@click.option( + "--deco_dir_train", + default=None, + type=click.Path(), + help="Directory that contains decoded training data", +) +@click.option( + "--orig_dir_train", + default=None, + type=click.Path(), + help="Directory that contains original training data", +) +@click.option( + "--prop_file_train", + default=None, + type=click.Path(), + help="Properties file for the training data", +) +@click.option( + "--seq_qp_train", default=list(), multiple=True, help="List of training QPs" +) +@click.option( + "--slice_type", + default="B", + type=click.Choice(["I", "B", "*"]), + help="Slice type", +) +@click.option("--min_slice_qp", default=0, type=int, help="Minimum slice QP (included)") +@click.option( + "--max_slice_qp", default=63, type=int, help="Maximum slice QP (included)" +) +@click.option( + "--block_size", + default=128, + type=int, + help="Size of the luma block", +) +@click.option( + "--border_size", + default=8, + type=int, + help="Size of the border to be added to the input patch (one side for chroma)", +) +@click.option( + "--random_patches", + default=False, + is_flag=True, + help="Patches are selected randomly", +) +@click.option( + "--num_patches", + default=10, + type=int, + help="Number of patches to extract from each frame", +) +@click.option("--bit_depth", default=10, type=int, help="Sequence bit depth") +@click.option("--seq_name", default=None, type=str, help="Sequence name/tag") +@click.option( + "--num_workers", default=16, type=int, help="Number of threads used to load data" +) +@click.option("--start_epoch", default=0, type=int, help="Starting epoch") +@click.option("--max_epochs", default=100, type=int, help="Max number of epochs") +@click.option("--lr", default=0.001, type=float, help="Learning rate") +@click.option( + "--lr_patience", + default=30, + type=int, + help="Number of epochs to drop the learning rate when the loss stops improving", +) +@click.option( + "--lr_gamma", + default=0.1, + type=float, + help="Learning rate drop factor for the scheduler", +) +@click.option( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay for regularisation", +) +@click.option("--batch_size", default=64, type=int, help="Batch size") +@click.option("--arch", default=None, type=str, help="Model architecture") +@click.option( + "--base_nn_params", + default=None, + type=click.Path(), + help="Abs path to NN parameters of base model", +) +@click.option( + "--base_quantizers", + default=None, + type=click.Path(), + help="Abs path to base model quantizers", +) +@click.option( + "--map_orig_params_to_new", + default=None, + type=click.Path(), + help="Abs path to file that maps original parameters to new parameters", +) +@click.option( + "--nn_params", default=None, type=click.Path(), help="Abs path to NN parameters" +) +@click.option( + "--optimiser_params", + default=None, + type=click.Path(), + help="Abs path to optimiser parameters", +) +@click.option( + "--lr_scheduler_params", + default=None, + type=click.Path(), + help="Abs path to learning rate scheduler parameters", +) +@click.option( + "--loss_op", default="mse", 
type=click.Choice(["mse", "mae"]), help="Loss operation"
+)
+@click.option(
+    "--stop_patience",
+    default=50,
+    type=int,
+    help="Number of epochs to stop training when the metric stops improving",
+)
+@click.option(
+    "--output_dir",
+    default="/tmp/output_dir",
+    type=click.Path(),
+    help="Output directory",
+)
+@click.option(
+    "--min_poc",
+    default=0,
+    type=int,
+    help="Minimum POC index",
+)
+@click.option(
+    "--max_poc",
+    default=1000,
+    type=int,
+    help="Maximum POC index",
+)
+@click.option(
+    "--part_idx",
+    default=0,
+    type=int,
+    help="Random-access segment index",
+)
+@click.option(
+    "--select_layer",
+    is_flag=True,
+    help="Select a subset of layers during training",
+)
+@click.option(
+    "--num_param",
+    default=1000,
+    type=int,
+    help="Number of parameters to select during training",
+)
+@click.option(
+    "--nnr_models_dir", type=click.Path(), default=None, help="Path to NNR models"
+)
+@click.option(
+    "--onnx_models_dir", type=click.Path(), default=None, help="Path to ONNX models"
+)
+@click.option(
+    "--sadl_float_dir",
+    type=click.Path(),
+    default=None,
+    help="Path to SADL float models",
+)
+@click.option(
+    "--sadl_int_dir",
+    type=click.Path(),
+    default=None,
+    help="Path to SADL integer models",
+)
+@click.option(
+    "--quantizers_dir", type=click.Path(), default=None, help="Path to quantizers"
+)
+def run_pipeline(
+    deco_dir_train,
+    orig_dir_train,
+    prop_file_train,
+    seq_qp_train,
+    slice_type,
+    min_slice_qp,
+    max_slice_qp,
+    block_size,
+    border_size,
+    random_patches,
+    num_patches,
+    bit_depth,
+    seq_name,
+    num_workers,
+    start_epoch,
+    max_epochs,
+    lr,
+    lr_patience,
+    lr_gamma,
+    weight_decay,
+    batch_size,
+    arch,
+    base_nn_params,
+    base_quantizers,
+    map_orig_params_to_new,
+    nn_params,
+    optimiser_params,
+    lr_scheduler_params,
+    loss_op,
+    stop_patience,
+    output_dir,
+    min_poc,
+    max_poc,
+    part_idx,
+    select_layer,
+    num_param,
+    nnr_models_dir,
+    onnx_models_dir,
+    sadl_float_dir,
+    sadl_int_dir,
+    quantizers_dir,
+):
+    assert block_size % 2 == 0, "The block size must be a multiple of 2"
+    assert border_size % 2 == 0, "The border size must be a multiple of 2"
+
+    if "lop2" != arch:
+        print(f"{arch} architecture not supported")
+        sys.exit(-1)
+
+    model_config = read_json_file("models/models.json")[arch]
+
+    train_data_yuv = NnFilterYuvDataset(
+        deco_dir_train,
+        orig_dir_train,
+        prop_file_train,
+        seq_name,
+        seq_qp_train,
+        bit_depth,
+        block_size,
+        border_size,
+        slice_type,
+        min_slice_qp,
+        max_slice_qp,
+        random_patches,
+        num_patches,
+        min_poc,
+        max_poc,
+    )
+
+    input_keys = [
+        "rec_before_dbf_Y",
+        "rec_before_dbf_U",
+        "rec_before_dbf_V",
+        "pred_Y",
+        "pred_U",
+        "pred_V",
+        "bs_Y",
+        "bs_U",
+        "bs_V",
+        "ipb_Y",
+    ]
+    label_keys = ["orig_Y", "orig_U", "orig_V", "mask_Y", "mask_U", "mask_V"]
+    num_patches = len(train_data_yuv)
+    # Store pre-processed data as a temporary file to save training time
+    with NamedTemporaryFile(
+        dir=".", suffix=".bin", prefix=f"{seq_name}_{seq_qp_train[0]}_part{part_idx}_"
+    ) as tmp_data:
+        print("[INFO] Dumping {} patches in {}".format(num_patches, tmp_data.name))
+
+        for idx in tqdm(range(num_patches)):
+            input_data, label_data, base_qp, slice_qp = train_data_yuv[idx]
+            for k in input_keys:
+                data = input_data[k].cpu().detach().numpy()
+                data.tofile(tmp_data)
+            for k in label_keys:
+                data = label_data[k].cpu().detach().numpy()
+                data.tofile(tmp_data)
+            np.array([base_qp, slice_qp], dtype=np.float32).tofile(tmp_data)
+
+        train_data = NnFilterBinDataset(
+            tmp_data.name,
+            block_size,
+            border_size,
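[Editor's note] Each record in the temporary .bin file is a fixed-size run of float32 values: the input planes in input_keys order, then the label planes in label_keys order, then the trailing pair [base_qp, slice_qp]. NnFilterBinDataset can therefore seek by record index. The plane sizes below are assumptions (luma at block + 2 * border per side, chroma at half of each), not something this patch states explicitly; a reading sketch:

    import numpy as np

    luma = (128 + 2 * 8) ** 2   # assumed plane size for each *_Y key
    chroma = (64 + 2 * 4) ** 2  # assumed plane size for each *_U / *_V key
    # 4 luma + 6 chroma input planes, 2 luma + 4 chroma label planes, 2 QP floats
    floats_per_record = 4 * luma + 6 * chroma + 2 * luma + 4 * chroma + 2

    def read_record(path, idx):
        with open(path, "rb") as f:
            f.seek(idx * floats_per_record * 4)  # 4 bytes per float32
            rec = np.fromfile(f, dtype=np.float32, count=floats_per_record)
        base_qp, slice_qp = rec[-2:]
        return rec[:-2], base_qp, slice_qp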
input_keys, + label_keys, + num_patches, + np.float32, + ) + + train_loader = DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + drop_last=True, + num_workers=num_workers, + pin_memory=True, + ) + + print("Size of training data: ", num_patches) + nn_filter = NnFilter( + model_config, + lr, + lr_patience, + lr_gamma, + weight_decay, + block_size, + border_size, + base_nn_params, + nn_params, + optimiser_params, + lr_scheduler_params, + loss_op, + stop_patience, + output_dir, + select_layer, + ) + + nn_filter.train_loop(start_epoch, max_epochs, train_loader, num_param) + + # NNR Compression + compress_one_model_and_2sadl( + output_dir, + base_nn_params, + model_config, + base_quantizers, + map_orig_params_to_new, + nnr_models_dir, + onnx_models_dir, + sadl_float_dir, + sadl_int_dir, + quantizers_dir, + train_loader, + block_size, + border_size, + ) + + +def main(*args, **kwargs): + print("call: {}".format(" ".join(sys.argv))) + run_pipeline() + + +if __name__ == "__main__": + main() diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/config.json b/training/training_scripts/NN_Adaptive_Filtering/resources/config.json new file mode 100644 index 0000000000..a65c6bdefb --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/resources/config.json @@ -0,0 +1,36 @@ +{ + "architecture": "lop2", + "sadl2torch": { + "base_model_sadl_i": "/src/models/nnlf_lop2_model_int16.sadl", + "base_model_torch": "resources/nnlf_lop2_model.pt", + "base_model_quantizers": "resources/nnlf_lop2_quantizers.json", + "base_model_params_to_new": "resources/nnlf_lop2_params_to_new_model.json" + }, + "decoder_bin": "/path/to/decoder/bin", + "intra_filters": "/path/to/intra/filters", + "dataset": { + "orig_path": "/path/to/dataset/orig", + "deco_path": "/path/to/dataset/deco", + "bit-depth": 10, + "data_cfg": "resources/datasets/jvet.json" + }, + "nnvc": { + "log": "log_dec.txt", + "bs": "bs.yuv", + "rec": "rec_before_dbf.yuv", + "pred": "pred.yuv", + "ipb": "bpm.yuv" + }, + "training": { + "output_path": "/path/to/output", + "num_models": 1, + "batch-size": 64, + "block-size": 128, + "pad-size": 8, + "min-tid": 0, + "num_layers": 106, + "slice-type": "*", + "input-keys": ["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V", "pred_Y", "pred_U", "pred_V", "bs_Y", "bs_U", "bs_V", "ipb_Y"], + "label-keys": ["orig_Y", "orig_U", "orig_V", "mask_Y", "mask_U", "mask_V"] + } +} diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet.json b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet.json new file mode 100644 index 0000000000..ebc62b38ba --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet.json @@ -0,0 +1,212 @@ +{ + "A1_Tango": { + "width": 3840, + "height": 2160, + "frames": 294, + "fps": 60, + "bit-depth": 10 + }, + "A1_FoodMarket": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 60, + "bit-depth": 10 + }, + "A1_CampfireParty": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 30, + "bit-depth": 10 + }, + "A2_CatRobot": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 60, + "bit-depth": 10 + }, + "A2_DaylightRoad": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 60, + "bit-depth": 10 + }, + "A2_ParkRunning": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 50, + "bit-depth": 10 + }, + "B_MarketPlace": { + "width": 1920, + "height": 1080, + "frames": 600, + "fps": 60, + "bit-depth": 10 + }, + "B_RitualDance": { + 
"width": 1920, + "height": 1080, + "frames": 600, + "fps": 60, + "bit-depth": 10 + }, + "B_Cactus": { + "width": 1920, + "height": 1080, + "frames": 500, + "fps": 50, + "bit-depth": 8 + }, + "B_BasketBallDrive": { + "width": 1920, + "height": 1080, + "frames": 500, + "fps": 50, + "bit-depth": 8 + }, + "B_BQTerrace": { + "width": 1920, + "height": 1080, + "frames": 600, + "fps": 60, + "bit-depth": 8 + }, + "C_BasketballDrill": { + "width": 832, + "height": 480, + "frames": 500, + "fps": 50, + "bit-depth": 8 + }, + "C_BQMall": { + "width": 832, + "height": 480, + "frames": 600, + "fps": 60, + "bit-depth": 8 + }, + "C_PartyScene": { + "width": 832, + "height": 480, + "frames": 500, + "fps": 50, + "bit-depth": 8 + }, + "C_RaceHorses_big": { + "width": 832, + "height": 480, + "frames": 300, + "fps": 30, + "bit-depth": 8 + }, + "D_BasketBallPass": { + "width": 416, + "height": 240, + "frames": 500, + "fps": 50, + "bit-depth": 8 + }, + "D_BQSquare": { + "width": 416, + "height": 240, + "frames": 600, + "fps": 60, + "bit-depth": 8 + }, + "D_BlowingBubbles": { + "width": 416, + "height": 240, + "frames": 500, + "fps": 50, + "bit-depth": 8 + }, + "D_RaceHorses_s": { + "width": 416, + "height": 240, + "frames": 300, + "fps": 30, + "bit-depth": 8 + }, + "E_FourPeople": { + "width": 1280, + "height": 720, + "frames": 600, + "fps": 60, + "bit-depth": 8 + }, + "E_Johnny": { + "width": 1280, + "height": 720, + "frames": 600, + "fps": 60, + "bit-depth": 8 + }, + "E_KristenAndSara": { + "width": 1280, + "height": 720, + "frames": 600, + "fps": 60, + "bit-depth": 8 + }, + "F_BBDrillText": { + "width": 832, + "height": 480, + "frames": 500, + "fps": 50, + "bit-depth": 8 + }, + "F_ArenaOfValor": { + "width": 1920, + "height": 1080, + "frames": 600, + "fps": 60, + "bit-depth": 8 + }, + "F_SlideEditing": { + "width": 1280, + "height": 720, + "frames": 300, + "fps": 30, + "bit-depth": 8 + }, + "F_SlideShow": { + "width": 1280, + "height": 720, + "frames": 500, + "fps": 20, + "bit-depth": 8 + }, + "H2_DayStreet": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 60, + "bit-depth": 10 + }, + "H2_FlyingBirds2": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 60, + "bit-depth": 10 + }, + "H2_PeopleInShop": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 60, + "bit-depth": 10 + }, + "H2_SunsetBeach2": { + "width": 3840, + "height": 2160, + "frames": 300, + "fps": 60, + "bit-depth": 10 + } +} diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet_labels.json b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet_labels.json new file mode 100644 index 0000000000..013994fe65 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet_labels.json @@ -0,0 +1,25 @@ +{ + "Tango2_3840x2160_60fps_10bit_420": "A1_Tango", + "FoodMarket4_3840x2160_60fps_10bit_420": "A1_FoodMarket", + "Campfire_3840x2160_30fps_bt709_420_videoRange": "A1_CampfireParty", + "CatRobot1_3840x2160p_60_10_709_420": "A2_CatRobot", + "DaylightRoad2_3840x2160_60fps_10bit_420": "A2_DaylightRoad", + "ParkRunning3_3840x2160_50fps_10bit_420": "A2_ParkRunning", + "MarketPlace_1920x1080_60fps_10bit_420": "B_MarketPlace", + "RitualDance_1920x1080_60fps_10bit_420": "B_RitualDance", + "Cactus_1920x1080_50": "B_Cactus", + "BasketballDrive_1920x1080_50": "B_BasketBallDrive", + "BQTerrace_1920x1080_60": "B_BQTerrace", + "BasketballDrill_832x480_50": "C_BasketballDrill", + "BQMall_832x480_60": "C_BQMall", + "PartyScene_832x480_50": "C_PartyScene", 
+ "RaceHorses_832x480_30": "C_RaceHorses_big", + "BasketballPass_416x240_50": "D_BasketBallPass", + "BQSquare_416x240_60": "D_BQSquare", + "BlowingBubbles_416x240_50": "D_BlowingBubbles", + "RaceHorses_416x240_30": "D_RaceHorses_s", + "BasketballDrillText_832x480_50": "F_BBDrillText", + "ArenaOfValor_1920x1080_60_8bit_420": "F_ArenaOfValor", + "SlideEditing_1280x720_30": "F_SlideEditing", + "SlideShow_1280x720_20": "F_SlideShow" +} diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/num_layers.json b/training/training_scripts/NN_Adaptive_Filtering/resources/num_layers.json new file mode 100644 index 0000000000..edbdd91661 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/resources/num_layers.json @@ -0,0 +1,163 @@ +{ + "A1_Tango": { + "22": 154, + "27": 64, + "32": 33, + "37": 19, + "42": 10 + }, + "A1_FoodMarket": { + "22": 144, + "27": 72, + "32": 37, + "37": 20, + "42": 10 + }, + "A1_CampfireParty": { + "22": 457, + "27": 120, + "32": 64, + "37": 34, + "42": 16 + }, + "A2_CatRobot": { + "22": 165, + "27": 77, + "32": 39, + "37": 21, + "42": 11 + }, + "A2_DaylightRoad": { + "22": 206, + "27": 79, + "32": 38, + "37": 20, + "42": 10 + }, + "A2_ParkRunning": { + "22": 1166, + "27": 468, + "32": 207, + "37": 95, + "42": 40 + }, + "B_MarketPlace": { + "22": 398, + "27": 161, + "32": 73, + "37": 34, + "42": 15 + }, + "B_RitualDance": { + "22": 291, + "27": 148, + "32": 78, + "37": 42, + "42": 22 + }, + "B_Cactus": { + "22": 362, + "27": 145, + "32": 70, + "37": 35, + "42": 18 + }, + "B_BasketBallDrive": { + "22": 437, + "27": 178, + "32": 86, + "37": 44, + "42": 23 + }, + "B_BQTerrace": { + "22": 678, + "27": 120, + "32": 49, + "37": 24, + "42": 12 + }, + "C_BasketballDrill": { + "22": 49, + "27": 23, + "32": 11, + "37": 6, + "42": 3 + }, + "C_BQMall": { + "22": 41, + "27": 19, + "32": 9, + "37": 5, + "42": 2 + }, + "C_PartyScene": { + "22": 99, + "27": 44, + "32": 20, + "37": 10, + "42": 4 + }, + "C_RaceHorses_big": { + "22": 57, + "27": 24, + "32": 11, + "37": 5, + "42": 2 + }, + "D_BasketBallPass": { + "22": 23, + "27": 11, + "32": 5, + "37": 2, + "42": 1 + }, + "D_BQSquare": { + "22": 18, + "27": 6, + "32": 3, + "37": 1, + "42": 1 + }, + "D_BlowingBubbles": { + "22": 23, + "27": 10, + "32": 5, + "37": 2, + "42": 1 + }, + "D_RaceHorses_s": { + "22": 14, + "27": 7, + "32": 3, + "37": 1, + "42": 1 + }, + "F_BBDrillText": { + "22": 65, + "27": 30, + "32": 15, + "37": 8, + "42": 4 + }, + "F_ArenaOfValor": { + "22": 228, + "27": 99, + "32": 47, + "37": 24, + "42": 12 + }, + "F_SlideEditing": { + "22": 15, + "27": 11, + "32": 8, + "37": 6, + "42": 4 + }, + "F_SlideShow": { + "22": 13, + "27": 7, + "32": 4, + "37": 2, + "42": 1 + } +} \ No newline at end of file diff --git a/training/training_scripts/NN_Adaptive_Filtering/run/TMLogger.py b/training/training_scripts/NN_Adaptive_Filtering/run/TMLogger.py new file mode 100644 index 0000000000..81b1a5d880 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/run/TMLogger.py @@ -0,0 +1,252 @@ +""" +The copyright in this software is being made available under this Software +Copyright License. This software may be subject to other third party and +contributor rights, including patent rights, and no such rights are +granted under this license. +Copyright (c) 1995 - 2021 Fraunhofer-Gesellschaft zur Förderung der +angewandten Forschung e.V. (Fraunhofer) +All rights reserved. 
+Redistribution and use in source and binary forms, with or without +modification, are permitted for purpose of testing the functionalities of +this software provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, +this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. +* Neither the names of the copyright holders nor the names of its +contributors may be used to endorse or promote products derived from this +software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +NO EXPRESS OR IMPLIED LICENSES TO ANY PATENT CLAIMS, INCLUDING +WITHOUT LIMITATION THE PATENTS OF THE COPYRIGHT HOLDERS AND +CONTRIBUTORS, ARE GRANTED BY THIS SOFTWARE LICENSE. THE +COPYRIGHT HOLDERS AND CONTRIBUTORS PROVIDE NO WARRANTY OF PATENT +NON-INFRINGEMENT WITH RESPECT TO THIS SOFTWARE. +""" + + +# Format changed + + +import logging +import os +import sys +from timeit import default_timer as timer + +# custom timing class +from memory_profiler import MemTimer, Pipe, choose_backend + + +class TMLogger: + def __init__(self, logfile, overwrite): + self.timings = [] + self.memory = [] + # LOGGER + self.FORMAT = logging.Formatter("iNCTM - %(levelname)s - %(message)s") + logging.root.setLevel(logging.NOTSET) + self.LOGGER = logging.getLogger() + if overwrite and os.path.exists(logfile): + os.remove(logfile) + + if os.path.isfile(logfile): + print("Logfile {} exists. 
Exiting...".format(logfile)) + sys.exit(1) + + _fileh = logging.FileHandler(logfile) + _fileh.setFormatter(self.FORMAT) + _fileh.setLevel(logging.INFO) + self.LOGGER.addHandler(_fileh) + + +class TimerMemory: + def __init__( + self, + tm_logger, + msg, + scope="", + interval=0.01, + timestamps=False, + include_children=False, + max_usage=False, + backend=None, + ): + self.tm_logger = tm_logger + self.msg = msg + self.scope = msg if scope == "" else scope + self.child_conn, self.parent_conn = Pipe() + self.backend = choose_backend(backend) + self.interval = interval + self.timestamps = timestamps + self.include_children = include_children + self.max_usage = max_usage + + def __enter__(self): + self.t1 = timer() + p = MemTimer( + os.getpid(), + self.interval, + self.child_conn, + self.backend, + timestamps=self.timestamps, + max_usage=self.max_usage, + include_children=self.include_children, + ) + p.start() + self.parent_conn.recv() # wait until we start getting memory + return self + + def __exit__(self, *args): + self.t2 = timer() + res = self.t2 - self.t1 + if self.tm_logger is not None: + self.tm_logger.LOGGER.info(self.msg + ": {}".format(res)) + self.parent_conn.send(0) # finish timing + ret = self.parent_conn.recv() + # n_measurements = self.parent_conn.recv() + # LOGGER.info('Memory usage ({}): min: {}, max: {}, peak increase: {}'.format(self.scope, min(ret), max(ret), max(ret) - min(ret))) + if self.tm_logger is not None: + self.tm_logger.LOGGER.info( + " Memory usage {} (peak increase): {}".format( + self.scope, max(ret) - min(ret) + ) + ) + self.tm_logger.memory.append(max(ret) - min(ret)) + self.tm_logger.timings.append(res) + + +class TMLoggerICNN: + def __init__(self, logfile, overwrite): + self.timings = {} + self.memory = {} + # LOGGER + self.FORMAT = logging.Formatter("ICNN - %(levelname)s - %(message)s") + logging.root.setLevel(logging.NOTSET) + self.LOGGER = logging.getLogger() + if overwrite and os.path.exists(logfile): + os.remove(logfile) + + if os.path.isfile(logfile): + print("Logfile {} exists. 
Exiting...".format(logfile)) + sys.exit(1) + + _fileh = logging.FileHandler(logfile) + _fileh.setFormatter(self.FORMAT) + _fileh.setLevel(logging.INFO) + self.LOGGER.addHandler(_fileh) + + +class TimerMemoryICNN: + def __init__( + self, + tm_logger, + epoch, + client, + tag, + msg, + scope="", + interval=0.01, + timestamps=False, + include_children=False, + max_usage=False, + backend=None, + ): + self.tm_logger = tm_logger + self.msg = msg + self.scope = msg if scope == "" else scope + self.child_conn, self.parent_conn = Pipe() + self.backend = choose_backend(backend) + self.interval = interval + self.timestamps = timestamps + self.include_children = include_children + self.max_usage = max_usage + self.tag = tag + self.epoch = epoch + self.client = client + + def __enter__(self): + self.t1 = timer() + p = MemTimer( + os.getpid(), + self.interval, + self.child_conn, + self.backend, + timestamps=self.timestamps, + max_usage=self.max_usage, + include_children=self.include_children, + ) + p.start() + self.parent_conn.recv() # wait until we start getting memory + return self + # pass + + def __exit__(self, *args): + self.t2 = timer() + res = self.t2 - self.t1 + if self.tm_logger is not None: + self.tm_logger.LOGGER.info(self.msg + ": {}".format(res)) + self.parent_conn.send(0) # finish timing + ret = self.parent_conn.recv() + # n_measurements = self.parent_conn.recv() + # LOGGER.info('Memory usage ({}): min: {}, max: {}, peak increase: {}'.format(self.scope, min(ret), max(ret), max(ret) - min(ret))) + if self.tm_logger is not None: + self.tm_logger.LOGGER.info( + " Memory usage {} (peak increase): {}".format( + self.scope, max(ret) - min(ret) + ) + ) + if ( + self.epoch in self.tm_logger.memory.keys() + and self.epoch in self.tm_logger.timings.keys() + ): + if self.client in self.tm_logger.memory[self.epoch].keys(): + if ( + self.tag + "_m" + in self.tm_logger.memory[self.epoch][self.client].keys() + or self.tag + "_t" + in self.tm_logger.timings[self.epoch][self.client].keys() + ): + assert ( + 0 + ), "Tag already used: {} ! 
Logging requires unique tags!".format( + self.tag + ) + else: + self.tm_logger.memory[self.epoch][self.client][ + self.tag + "_m" + ] = max(ret) - min(ret) + self.tm_logger.timings[self.epoch][self.client][ + self.tag + "_t" + ] = res + else: + self.tm_logger.memory[self.epoch][self.client] = {} + self.tm_logger.timings[self.epoch][self.client] = {} + self.tm_logger.memory[self.epoch][self.client][ + self.tag + "_m" + ] = max(ret) - min(ret) + self.tm_logger.timings[self.epoch][self.client][ + self.tag + "_t" + ] = res + else: + self.tm_logger.memory[self.epoch] = {} + self.tm_logger.timings[self.epoch] = {} + self.tm_logger.memory[self.epoch][self.client] = {} + self.tm_logger.timings[self.epoch][self.client] = {} + self.tm_logger.memory[self.epoch][self.client][self.tag + "_m"] = max( + ret + ) - min(ret) + self.tm_logger.timings[self.epoch][self.client][self.tag + "_t"] = res + + # pass diff --git a/training/training_scripts/NN_Adaptive_Filtering/run/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/run/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/training/training_scripts/NN_Adaptive_Filtering/segment_on_off.py b/training/training_scripts/NN_Adaptive_Filtering/segment_on_off.py new file mode 100644 index 0000000000..e9e7a496cd --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/segment_on_off.py @@ -0,0 +1,325 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
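+
+Usage sketch (illustrative; every path below is a placeholder and the QP list
+is only an example):
+
+    python segment_on_off.py --qps 22 27 32 37 42 \
+        --pass1 /path/to/pass1_sims --pass2 /path/to/pass2_sims \
+        --parcat_bin /path/to/parcat --decoder_bin /path/to/decoder \
+        --intra_models_dir /path/to/intra/models \
+        --lop2_model /path/to/lop2_model --output_dir /path/to/merged
+
+For each (sequence, QP, segment), the pass-1 result is kept unless pass 2
+improves the average luma PSNR by at least 0.02 dB; the selected segment
+bitstreams are then concatenated with parcat and per-sequence decode scripts
+are generated.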
+""" + +import argparse +from pathlib import Path + +import numpy as np +import pandas as pd + +from util.file_system import ( + check_directory, + check_file, + create_directory, + list_dirs, + list_selected_dirs, + read_json_file, +) +from util.regex import get_data_from_encoder_log, ra_segment_rgx + + +def get_encoding_data(root_dir: Path) -> pd.DataFrame: + seq_dirs = list() + sim_dirs = list_selected_dirs(root_dir, "vtm_ra") + for sim_dir in sim_dirs: + seq_dirs.extend(list_dirs(sim_dir)) + + sims_data = list() + for seq_dir in seq_dirs: + seq_name = seq_dir.name + if args.video_sequences is not None and seq_name not in args.video_sequences: + continue + + fps = config[seq_name]["fps"] + for qp in args.qps: + enc_logs = list(seq_dir.glob(f"**/log_enc_{qp}*.txt")) + enc_logs = sorted( + enc_logs, + key=lambda s: int(ra_segment_rgx.search(str(s)).groups()[0]), + ) + for enc_log in enc_logs: + part = int(ra_segment_rgx.search(enc_log.stem).group(1)) + bitstream = seq_dir / f"{seq_dir.name}_{qp}_p{part}.266" + + check_file(bitstream) + total_bits = 0 + + sum_psnr_y = 0 + sum_psnr_u = 0 + sum_psnr_v = 0 + + sum_mssim_y = 0 + sum_mssim_u = 0 + sum_mssim_v = 0 + + num_frames = 0 + + with open(enc_log, "r") as stream: + poc_line = 0 + for line in stream: + if line.startswith("POC"): + poc_line += 1 + if poc_line == 1 and part > 0: + continue + ( + slice_type, + slice_qp, + bits, + psnr_y, + psnr_u, + psnr_v, + mssim_y, + mssim_u, + mssim_v, + ) = get_data_from_encoder_log(line) + + sum_psnr_y += psnr_y + sum_psnr_u += psnr_u + sum_psnr_v += psnr_v + + sum_mssim_y += mssim_y + sum_mssim_u += mssim_u + sum_mssim_v += mssim_v + + total_bits += bits + + num_frames += 1 + + sims_data.append( + [ + seq_name, + qp, + part, + len(enc_logs), + num_frames, + fps, + total_bits, + sum_psnr_y, + sum_psnr_u, + sum_psnr_v, + sum_mssim_y, + sum_mssim_u, + sum_mssim_v, + bitstream, + ] + ) + df = pd.DataFrame( + sims_data, + columns=[ + "sequence", + "qp", + "part", + "num_parts", + "num_frames", + "fps", + "bits", + "sum_psnr_y", + "sum_psnr_u", + "sum_psnr_v", + "sum_mssim_y", + "sum_mssim_u", + "sum_mssim_v", + "bitstream", + ], + ) + + df = df.set_index(["sequence", "qp", "part"]) + return df + + +def create_parcat_script(df: pd.DataFrame) -> None: + with open(output_dir / "run.sh", "w") as main_stream: + main_stream.write("#!/bin/bash\n\n") + main_stream.write(f"PARCAT={parcat_bin}\n\n") + tmp = df[["num_parts", "bitstream"]] + cmd = list() + for it, (index, row) in enumerate(tmp.iterrows()): + curr_seq, curr_qp, curr_part = index + + curr_part = int(curr_part) + bitstream = row["bitstream"] + num_parts = row["num_parts"] + + if curr_part == 0: + cmd = ["$PARCAT"] + + # Adding best RA segment bitstream + cmd.append(bitstream) + + if curr_part + 1 == num_parts: + # Adding merge output bitstream to main script + output_bitstream = output_dir / f"{curr_seq}_{curr_qp}.266" + cmd.append(output_bitstream) + main_stream.write(" ".join(str(s) for s in cmd)) + main_stream.write("\n\n") + + # Script to decode single merged bitstream + seq_script = output_dir / f"{curr_seq}_{curr_qp}.sh" + with open(seq_script, "w") as sequence_stream: + log_file = output_dir / f"{curr_seq}_{curr_qp}_log_dec.txt" + sequence_stream.write("#!/bin/bash\n\n") + sequence_stream.write(f"DECODER={decoder_bin}\n\n") + + output_stem = output_dir / f"{curr_seq}_{curr_qp}" + sequence_stream.write( + f"$DECODER -b {output_bitstream} -o /dev/null --PrefixAbsolutePathsToGraphsOutput={intra_models_dir} --NnlfModelName={lop2_model} 
--NnfuOutputFileStem={output_stem} > {log_file}" + ) + + # Adding cluster params and calling script to decode single merged bitstream + main_stream.write( + f"qsub -q batch.q -terse -cwd -S /bin/bash -V -o {log_file} -j y -N {curr_seq}_{curr_qp} -l vf=8000M {seq_script}\n\n" + ) + + +def segment_on_off() -> None: + pass1_df = get_encoding_data(pass1_dir) + pass2_df = get_encoding_data(pass2_dir) + + df = pass1_df.join(pass2_df, lsuffix="_pass1", rsuffix="_pass2") + + df["pass1"] = ( + df["sum_psnr_y_pass2"] / df["num_frames_pass2"] + - df["sum_psnr_y_pass1"] / df["num_frames_pass1"] + < 0.02 + ) + + best = df.filter(["sequence", "qp", "part"], axis=1) + metrics = [ + "bits", + "sum_psnr_y", + "sum_psnr_u", + "sum_psnr_v", + "sum_mssim_y", + "sum_mssim_u", + "sum_mssim_v", + "bitstream", + "num_parts", + ] + for metric in metrics: + best[metric] = np.where( + df["pass1"], df[f"{metric}_pass1"], df[f"{metric}_pass2"] + ) + + create_parcat_script(best) + + best["fps"] = df["fps_pass1"] + best["num_frames"] = df["num_frames_pass1"] + best = best.drop(columns=["num_parts", "bitstream"]) + + best = best.groupby(["sequence", "qp", "fps"]).sum() + best = best.reset_index() + + metrics = metrics[:-2] + for metric in metrics: + if "bits" in metric: + best["bitrate"] = best.get("bits").astype("float") * ( + best.get("fps").astype("float") + / 1000 + / best.get("num_frames").astype("float") + ) + else: + tokens = metric.split("_") + best[f"avg_{tokens[1]}_{tokens[2]}"] = best.get(metric) / best.get( + "num_frames" + ).astype("float") + + best["seq_qp"] = best["sequence"] + "_" + best["qp"] + best = best.set_index("seq_qp").sort_values( + by=["sequence", "qp"], key=lambda x: x.map(custom_order) + ) + best = best.drop(columns=metrics).drop( + columns=["sequence", "qp", "fps", "num_frames"] + ) + best.to_excel(output_dir / "metrics.xlsx") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-vs", + "--video_sequences", + nargs="+", + type=str, + help="Sequence name/tag", + ) + parser.add_argument("--qps", nargs="+", type=str, required=True, help="NNVC QP") + parser.add_argument("--pass1", type=str, required=True, help="First pass directory") + parser.add_argument( + "--pass2", type=str, required=True, help="Second pass directory" + ) + parser.add_argument( + "--parcat_bin", type=str, required=True, help="Path to parcat binary" + ) + parser.add_argument( + "--decoder_bin", type=str, required=True, help="Path to decoder binary" + ) + parser.add_argument( + "--intra_models_dir", + type=str, + required=True, + help="Path to NN intra pred models dir", + ) + parser.add_argument( + "--lop2_model", type=str, required=True, help="Path to LOP2 model" + ) + parser.add_argument( + "--output_dir", type=str, required=True, help="Output directory" + ) + return parser.parse_args() + + +if __name__ == "__main__": + # If video_sequences arg is not given, all sequences in the directories are processed + # pass1 and pass2 point to the main directory simulation (the one that contains the + # directories vtm_ra* or vtm_rasplitall*) + args = parse_arguments() + + pass1_dir = check_directory(args.pass1) + pass2_dir = check_directory(args.pass2) + + parcat_bin = check_file(args.parcat_bin) + decoder_bin = check_file(args.decoder_bin) + + intra_models_dir = check_directory(args.intra_models_dir) + lop2_model = check_file(args.lop2_model) + + output_dir = create_directory(args.output_dir) + + config = read_json_file("resources/datasets/jvet.json") + + custom_order = {} + for idx, seq in 
enumerate(config.keys()): + custom_order[seq] = idx + 1 + + segment_on_off() diff --git a/training/training_scripts/NN_Adaptive_Filtering/trainer/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/trainer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/training/training_scripts/NN_Adaptive_Filtering/trainer/nn_filter.py b/training/training_scripts/NN_Adaptive_Filtering/trainer/nn_filter.py new file mode 100644 index 0000000000..34f2455bd7 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/trainer/nn_filter.py @@ -0,0 +1,429 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
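+
+Minimal construction sketch (illustrative; the literal values are assumptions,
+not recommended settings — the overfitting pipeline supplies its own
+configuration, and loss_op must be a key of util.metrics.LOSS_OP):
+
+    nn_filter = NnFilter(
+        model_config, lr=1e-4, lr_patience=5, lr_gamma=0.5, weight_decay=0.0,
+        block_size=128, border_size=8,
+        base_nn_params="resources/nnlf_lop2_model.pt", nn_params=None,
+        optimiser_params=None, scheduler_params=None, loss_op=loss_op_name,
+        stop_patience=10, output_dir="out", select_layer=True,
+    )
+    nn_filter.train_loop(0, max_epochs, train_loader, num_trainable_parameters)
+
+Only the ".multiplier" parameters (or a stored subset of them) are trained,
+while a frozen copy of the base model provides the reference PSNR.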
+""" + +from collections import OrderedDict +from typing import Dict, Tuple + +import numpy as np +import torch +from torch import Tensor, nn + +import models.model_lop2_with_multiplier as model_with_multiplier +from models import load_torch_model +from util import DEVICE, Colour +from util.file_system import check_file +from util.logger import Logger +from util.metrics import LOSS_OP, compute_loss_psnr, zero_out_negative + + +class NnFilter: + def __init__( + self, + model_config: Dict, + lr: float, + lr_patience: int, + lr_gamma: float, + weight_decay: float, + block_size: int, + border_size: int, + base_nn_params: str, + nn_params: str, + optimiser_params: str, + scheduler_params: str, + loss_op: str, + stop_patience: int, + output_dir: str, + select_layer: bool, + ): + torch.manual_seed(1234) + np.random.seed(1234) + + self._select_layer = select_layer + self._y_block_size = block_size + self._uv_block_size = block_size // 2 + self._y_border_size = border_size + self._uv_border_size = border_size // 2 + + self._model = nn.DataParallel( + load_torch_model(model_config, model_with_multiplier) + ).to(DEVICE) + self._base_model = nn.DataParallel( + load_torch_model(model_config, model_with_multiplier) + ).to(DEVICE) + + self._logger = Logger(output_dir, stop_patience) + self._trainable_parameter_names = self._get_trainable_parameter_names(nn_params) + self._trainable_parameters = self._initialise_training() + + self._optimiser = torch.optim.Adam( + self._trainable_parameters, lr=lr, weight_decay=weight_decay + ) + + self._lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self._optimiser, patience=lr_patience, verbose=True, factor=lr_gamma + ) + self._loss_op = LOSS_OP[loss_op] + + self._load_models(base_nn_params, nn_params, optimiser_params, scheduler_params) + + def _get_trainable_parameter_names(self, nn_params): + # try to load from txt file, stored after online selection + param_names = self._logger.load_param_names() + + # if not store parameter names before, then param can be overfitted when multipliers are different from initialized values + if len(param_names) == 0 and nn_params: + check_file(nn_params) + parameters = torch.load(nn_params) + for name in parameters: + if ".multiplier" in name and not (parameters[name] == 1.0).all(): + param_names.append(name) + + # if nn_param is None or all multipliers are not overfitted, then return names of all multipliers + if len(param_names) == 0: + for name, _ in self._model.module.named_parameters(): + if ".multiplier" in name: + param_names.append(name) + + return param_names + + def _initialise_training(self): + trainable_parameters = [] + for name, param in self._model.module.named_parameters(): + if name not in self._trainable_parameter_names: + param.requires_grad = False + else: + trainable_parameters.append(param) + + for name, param in self._base_model.named_parameters(): + param.requires_grad = False + self._base_model.eval() + print("num of trainable terms", len(trainable_parameters)) + return trainable_parameters + + def _load_models( + self, + base_nn_params: str, + nn_params: str, + optimiser_params: str, + scheduler_params: str, + ) -> None: + check_file(base_nn_params) + + src_params = torch.load(base_nn_params) + tmp_params = OrderedDict() + dst_params = OrderedDict() + + for name, value in self._base_model.module.state_dict().items(): + if "multiplier" not in name: + tmp_params[name] = value + + for src_param, dst_param in zip(src_params.keys(), tmp_params.keys()): + dst_params[dst_param] = src_params[src_param] 
+
+        self._base_model.module.load_state_dict(dst_params, strict=False)
+        print(f"Base model params loaded from {base_nn_params}")
+
+        if nn_params is None:
+            self._model.module.load_state_dict(dst_params, strict=False)
+            print(f"Model params loaded from {base_nn_params}")
+        else:
+            check_file(nn_params)
+            self._model.module.load_state_dict(torch.load(nn_params))
+            print(f"Model params loaded from {nn_params}")
+
+        if optimiser_params is not None:
+            check_file(optimiser_params)
+            self._optimiser.load_state_dict(torch.load(optimiser_params))
+            print(f"Optimiser params loaded from {optimiser_params}")
+
+        if scheduler_params is not None:
+            check_file(scheduler_params)
+            self._lr_scheduler.load_state_dict(torch.load(scheduler_params))
+            print(f"Scheduler params loaded from {scheduler_params}")
+
+    def _compute_base_metrics(
+        self,
+        input_data: Dict,
+        label_data: Dict,
+        use_multiplier: Tensor,
+        cond_target_value: Tensor,
+        loss_op,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        """
+        Computes the PSNR of the VTM reconstruction and of the frozen base model
+        :param input_data: dictionary of 4D tensors that includes the reconstruction, QPs and boundary strengths
+        :param label_data: Ground truth
+        :param use_multiplier: multiplier flag forwarded to the model
+        :param cond_target_value: conditioning value forwarded to the model
+        :param loss_op: loss function
+        :return: VTM PSNR, VTM mask, base-model PSNR and base-model mask (sample-wise [B])
+        """
+        reco_Y = input_data["rec_before_dbf_Y"][
+            :,
+            :,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+        ]
+        reco_U = input_data["rec_before_dbf_U"][
+            :,
+            :,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+        ]
+        reco_V = input_data["rec_before_dbf_V"][
+            :,
+            :,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+        ]
+
+        prediction = self._base_model(input_data, use_multiplier, cond_target_value)
+        prediction_Y = prediction[0][
+            :,
+            :,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+        ]
+        prediction_U, prediction_V = prediction[1][
+            :,
+            :,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+        ].split([1, 1], dim=1)
+
+        _, vtm_psnr, vtm_mask = compute_loss_psnr(
+            label_data, reco_Y, reco_U, reco_V, loss_op
+        )
+        _, base_psnr, base_mask = compute_loss_psnr(
+            label_data, prediction_Y, prediction_U, prediction_V, loss_op
+        )
+
+        return vtm_psnr, vtm_mask, base_psnr, base_mask
+
+    def _select_parameters_to_be_overfitted(self, top_k_luma, top_k_chroma) -> None:
+        """
+        Selects the parameters to be over-fitted based on the l1-norm of the
+        weight update, normalised by the number of units in the tensor. The
+        highest-ranked parameters are selected.
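+        Luma ("y_path") and chroma ("uv_path") tensors are ranked separately,
+        so a fixed share of the selection budget is spent on each path.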
+
+        :param top_k_luma: The number of luma-path parameters to be selected
+        :param top_k_chroma: The number of chroma-path parameters to be selected
+        """
+        energy_luma = []
+        energy_chroma = []
+        var_names_luma = []
+        var_names_chroma = []
+        for (org_name, org_param), (new_name, new_param) in zip(
+            self._base_model.named_parameters(), self._model.named_parameters()
+        ):
+            assert org_name == new_name, "The models have different variables"
+            num_ele = torch.numel(org_param)
+            if "y_path" in new_name:
+                energy_luma.append(
+                    torch.sum(torch.abs(new_param - org_param)) / num_ele
+                )
+                var_names_luma.append(org_name)
+            elif "uv_path" in new_name:
+                energy_chroma.append(
+                    torch.sum(torch.abs(new_param - org_param)) / num_ele
+                )
+                var_names_chroma.append(org_name)
+
+        energy_luma = torch.Tensor(energy_luma)
+        _, e_idx = torch.topk(energy_luma, k=top_k_luma)
+        var_names_luma = [var_names_luma[idx] for idx in e_idx]
+
+        energy_chroma = torch.Tensor(energy_chroma)
+        _, e_idx = torch.topk(energy_chroma, k=top_k_chroma)
+        var_names_chroma = [var_names_chroma[idx] for idx in e_idx]
+
+        var_names = var_names_luma + var_names_chroma
+
+        # reset model
+        self._model.module.load_state_dict(self._base_model.module.state_dict())
+        self._trainable_parameters = []
+        for name, param in self._model.named_parameters():
+            # keep only the selected (multiplier) parameters trainable
+            if name in var_names:
+                self._trainable_parameters.append(param)
+            else:
+                param.requires_grad = False
+
+        self._optimiser = torch.optim.Adam(
+            self._trainable_parameters,
+            lr=self._optimiser.param_groups[0]["lr"],
+            weight_decay=self._optimiser.param_groups[0]["weight_decay"],
+        )
+
+        # store parameter names
+        self._logger.save_param_names(var_names)
+
+    def train_one_epoch(self, epoch: int, dataloader) -> float:
+        avg_loss = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_psnr = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_delta_psnr_wrt_vtm = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_delta_psnr_wrt_base = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        num_elements = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_positive_delta_psnr_wrt_vtm = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        num_elements_positive_psnr = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        use_multiplier = torch.ones(1).to(DEVICE)
+        cond_target_value = torch.ones(1).to(DEVICE)
+
+        for batch_idx, (input_data, label_data) in enumerate(dataloader):
+            for k, v in input_data.items():
+                input_data[k] = v.to(DEVICE)
+            for k, v in label_data.items():
+                label_data[k] = v.to(DEVICE)
+
+            vtm_psnr, vtm_mask, base_psnr, base_mask = self._compute_base_metrics(
+                input_data, label_data, use_multiplier, cond_target_value, self._loss_op
+            )
+
+            self._optimiser.zero_grad()
+            prediction = self._model(input_data, use_multiplier, cond_target_value)
+            pred_Y = prediction[0][
+                :,
+                :,
+                self._y_border_size : self._y_border_size + self._y_block_size,
+                self._y_border_size : self._y_border_size + self._y_block_size,
+            ]
+            pred_U, pred_V = prediction[1][
+                :,
+                :,
+                self._uv_border_size : self._uv_border_size + self._uv_block_size,
+                self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            ].split([1, 1], dim=1)
+
+            pred_loss, pred_psnr, pred_mask = compute_loss_psnr(
+                label_data, pred_Y, pred_U, pred_V, self._loss_op
+            )
+            loss = torch.mean(pred_loss[Colour.YCbCr])
+            loss.backward()
+            self._optimiser.step()
+
+            mask = vtm_mask * base_mask * pred_mask
+
+            vtm_psnr *= mask
+            base_psnr *= mask
+            pred_psnr *= mask
+            delta_psnr_wrt_vtm = pred_psnr - vtm_psnr
+            delta_psnr_wrt_base = pred_psnr - base_psnr
+
+            positive_delta_psnr_wrt_vtm = [0.0] * (Colour.YCbCr + 1)
+            pp_mask = [0.0] * (Colour.YCbCr + 1)
+
+            for 
color in range(Colour.YCbCr + 1): + positive_delta_psnr_wrt_vtm[color], pp_mask[color] = zero_out_negative( + delta_psnr_wrt_vtm[color] + ) + + positive_delta_psnr_wrt_vtm = torch.stack(positive_delta_psnr_wrt_vtm) + pp_mask = torch.stack(pp_mask) + + num_elements_positive_psnr += torch.sum(pp_mask, dim=1) + num_elements += torch.sum(mask, dim=1) + avg_loss += torch.sum(pred_loss, dim=1) + avg_psnr += torch.sum(pred_psnr, dim=1) + avg_delta_psnr_wrt_vtm += torch.sum(delta_psnr_wrt_vtm, dim=1) + avg_delta_psnr_wrt_base += torch.sum(delta_psnr_wrt_base, dim=1) + avg_positive_delta_psnr_wrt_vtm += torch.sum( + positive_delta_psnr_wrt_vtm, dim=1 + ) + + avg_loss /= len(dataloader) * dataloader.batch_size + avg_psnr /= num_elements + avg_delta_psnr_wrt_vtm /= num_elements + avg_delta_psnr_wrt_base /= num_elements + avg_positive_delta_psnr_wrt_vtm /= num_elements_positive_psnr + portion_positive_delta_psnr = num_elements_positive_psnr / num_elements + + avg_loss = avg_loss.tolist() + avg_psnr = avg_psnr.tolist() + avg_delta_psnr_wrt_vtm = avg_delta_psnr_wrt_vtm.tolist() + avg_delta_psnr_wrt_base = avg_delta_psnr_wrt_base.tolist() + avg_positive_delta_psnr_wrt_vtm = avg_positive_delta_psnr_wrt_vtm.tolist() + portion_positive_delta_psnr = portion_positive_delta_psnr.tolist() + + self._logger.log_metrics( + epoch, + 0, + avg_loss, + avg_psnr, + avg_delta_psnr_wrt_vtm, + avg_delta_psnr_wrt_base, + avg_positive_delta_psnr_wrt_vtm, + portion_positive_delta_psnr, + ) + return avg_loss[Colour.YCbCr] + + def train_loop( + self, + start_epoch: int, + max_epochs: int, + train_dataloader, + num_trainable_parameters: int, + ): + print(f"Starting at epoch {start_epoch}...") + self._model.train(True) + if self._select_layer: + select_epoch = max(1, int(max_epochs * 0.1)) if self._select_layer else -1 + + for epoch in range(start_epoch, max_epochs): + if self._select_layer and epoch == select_epoch: + num_trainable_parameters_luma = max( + 1, int(0.8 * num_trainable_parameters) + ) + num_trainable_parameters_chroma = ( + num_trainable_parameters - num_trainable_parameters_luma + ) + self._select_parameters_to_be_overfitted( + num_trainable_parameters_luma, num_trainable_parameters_chroma + ) + self._logger.reset() + print( + f"After selection, the amount of trainable variables is : {len(self._trainable_parameters)}" + ) + + epoch_loss = self.train_one_epoch(epoch, train_dataloader) + self._lr_scheduler.step(epoch_loss) + + stop_train = self._logger.save_model( + epoch, epoch_loss, self._model, self._optimiser, self._lr_scheduler, 30 + ) + + if stop_train: + print("No improvements... Stopping the training now") + break + + if self._optimiser.param_groups[0]["lr"] < 1e-8: + print("Learning rate too small... Stopping the training now") + break + + self._logger.save_training_time() diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/util/__init__.py new file mode 100644 index 0000000000..e2b214a3c1 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/util/__init__.py @@ -0,0 +1,63 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. 
+* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +from typing import NamedTuple + +import torch + + +class Colour: + Y = 0 + Cb = 1 + Cr = 2 + YCbCr = 3 + + +class Position(NamedTuple): + x: int + y: int + + +COLOUR_LABEL = { + Colour.Y: "Y", + Colour.Cb: "Cb", + Colour.Cr: "Cr", + Colour.YCbCr: "YCbCr", +} + + +COLOUR_WEIGHT = {Colour.Y: 12.0 / 14.0, Colour.Cb: 1.0 / 14.0, Colour.Cr: 1.0 / 14.0} +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +DEVICE_CPU = torch.device("cpu") +TIME_FORMAT: str = "%Y%m%d-%H%M%S" diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/dataset_bin.py b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_bin.py new file mode 100644 index 0000000000..57876746fd --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_bin.py @@ -0,0 +1,153 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. 
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data.dataset import Dataset
+
+
+class NnFilterBinDataset(Dataset):
+    def __init__(
+        self,
+        data_path: str,
+        block_size: int,
+        pad_size: int,
+        input_keys: List[str],
+        label_keys: List[str],
+        num_patches: int,
+        data_type: np.dtype,
+    ):
+        """
+        Constructor
+        :param data_path: Path of the binary file that contains the training patches
+        :param block_size: Block size, without padding
+        :param pad_size: Padding size (# of samples taken for a luma block at each side)
+        :param input_keys: Keys of input data dictionary
+        :param label_keys: Keys of label data dictionary
+        :param num_patches: Number of patches in the dataset
+        :param data_type: Type of data in binary data file
+        """
+
+        self._data_path = data_path
+        self._block_size_y = block_size
+        self._block_size_uv = block_size // 2
+        self._pad_size = pad_size
+        self._num_patches = num_patches
+        self._data_type = data_type
+        self._input_keys = input_keys
+        self._label_keys = label_keys
+        # e.g. block_size=128 and pad_size=8 give 144x144 luma and 72x72
+        # chroma input patches
+        self._patch_size_y = self._block_size_y + self._pad_size * 2
+        self._patch_size_uv = self._patch_size_y // 2
+        num_channels_input_Y = 4
+        num_channels_input_uv = 6
+        num_channels_label_Y = 2
+        num_channels_label_uv = 4
+        self._num_scalar = 2
+        self._input_block_volume = (
+            num_channels_input_Y * self._patch_size_y**2
+            + num_channels_input_uv * self._patch_size_uv**2
+        )
+        self._label_block_volume = (
+            num_channels_label_Y * self._block_size_y**2
+            + num_channels_label_uv * self._block_size_uv**2
+        )
+        self._shape_input_y = (1, self._patch_size_y, self._patch_size_y)
+        self._shape_input_uv = (1, self._patch_size_uv, self._patch_size_uv)
+        self._shape_label_y = (1, self._block_size_y, self._block_size_y)
+        self._shape_label_uv = (1, self._block_size_uv, self._block_size_uv)
+
+    def __len__(self) -> int:
+        """
+        Dataset length
+        """
+        return self._num_patches
+
+    def __getitem__(self, idx: int) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
+        """
+        Returns a single item (input and label)
+        :param idx: Index
+        :return: input and label data dictionaries
+        """
+
+        # each record stores the input planes, the label planes and two QP
+        # scalars back to back, so the byte offset of record idx is the sum
+        # of the three partial offsets below
+        input_offset = (
+            self._input_block_volume * idx * np.dtype(self._data_type).itemsize
+        )
+        label_offset = (
+            self._label_block_volume * idx * np.dtype(self._data_type).itemsize
+        )
+        qp_offset = self._num_scalar * idx * np.dtype(self._data_type).itemsize
+
+        # the patch file is binary, so open it in binary mode for np.fromfile
+        with open(self._data_path, "rb") as f:
+            data = np.fromfile(
+                f,
+                dtype=self._data_type,
+                count=self._input_block_volume
+                + self._label_block_volume
+                + self._num_scalar,
+                offset=input_offset + label_offset + qp_offset,
+            )
+
+        input_data = {}
+        label_data = {}
+        for k in self._input_keys:
+            if "Y" in k:
+                input_data[k] = data[: self._patch_size_y**2]
+                input_data[k] = 
input_data[k].reshape(self._shape_input_y) + data = np.delete(data, range(self._patch_size_y**2)) + else: + input_data[k] = data[: self._patch_size_uv**2] + input_data[k] = input_data[k].reshape(self._shape_input_uv) + data = np.delete(data, range(self._patch_size_uv**2)) + + input_data[k] = torch.from_numpy(input_data[k]) + + for k in self._label_keys: + if "Y" in k: + label_data[k] = data[: self._block_size_y**2] + label_data[k] = label_data[k].reshape(self._shape_label_y) + data = np.delete(data, range(self._block_size_y**2)) + else: + label_data[k] = data[: self._block_size_uv**2] + label_data[k] = label_data[k].reshape(self._shape_label_uv) + data = np.delete(data, range(self._block_size_uv**2)) + + label_data[k] = torch.from_numpy(label_data[k]) + + input_data["qp_base"] = data[-2] * torch.ones(self._shape_input_y) + input_data["qp_slice"] = data[-1] * torch.ones(self._shape_input_y) + + return input_data, label_data diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/dataset_yuv.py b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_yuv.py new file mode 100644 index 0000000000..97c03c6cb4 --- /dev/null +++ b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_yuv.py @@ -0,0 +1,533 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
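+
+Expected directory layout (an inference from _create_patch_lists below, with
+placeholder names): <deco_dir>/<sequence>/<qp_dir> holds coding_info.json and
+the decoder dumps <sequence>_rec_before_dbf.yuv, <sequence>_bs.yuv,
+<sequence>_pred.yuv and <sequence>_bpm.yuv, while <orig_dir>/<sequence>
+contains the original .yuv file.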
+""" + +from typing import List, NamedTuple, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torch.utils.data.dataset import Dataset +from torchvision.transforms.functional import pad + +from util import Position +from util.file_system import ( + check_directory, + check_file, + list_dirs, + list_selected_dirs, + read_json_file, +) +from util.image_ops import pad_image, read_yuv_frame + + +class PatchInfo(NamedTuple): + poc: int + seq_width: int + seq_height: int + reco_yuv_path: str + boun_yuv_path: str + pred_yuv_path: str + ipb_yuv_path: str + orig_yuv_path: str + slice_qp: int + base_qp: int + temporal_layer: int + slice_type: str + bit_depth: int + pos: Position + + +class NnFilterYuvDataset(Dataset): + def __init__( + self, + deco_dir: str, + orig_dir: str, + prop_file: str, + seq_name: str, + seq_qps: Union[Tuple[str], List[str]], + bit_depth: int, + block_size: int, + pad_size: int, + slice_type: str, + min_slice_qp: int, + max_slice_qp: int, + random_patches: bool, + num_patches: int, + min_poc: int, + max_poc: int, + ): + """ + Constructor + :param deco_dir: Directory that contains the decoded video data + :param orig_dir: Directory that contains the original video data + :param prop_file: Properties file for the training data + :param seq_name: Sequence name + :param seq_qps: list of sequence QPs + :param bit_depth: Bit-depth for the data + :param block_size: Block size, without padding + :param pad_size: Padding size (# of samples taken for a luma block at each side) + :param slice_type: Slice type (i.e. I, B, *) + :param min_slice_qp: Minimum frame QP (inclusive) + :param max_slice_qp: Maximum frame QP (inclusive) + :param random_patches: Extract patches at random position + :param num_patches: Number of patches to extract from each frame + :param min_poc: Minimum index of picture + :param max_poc: maximum index of picture + """ + self._bit_depth = bit_depth + self._block_size = block_size + self._pad_size = pad_size + self._random_patches = random_patches + self._num_patches = num_patches + prop_file = check_file(prop_file) + self._seq_name = seq_name + self._base_qp = None + self._seq_info = read_json_file(prop_file) + self._create_patch_lists( + deco_dir, + orig_dir, + seq_name, + seq_qps, + slice_type, + min_slice_qp, + max_slice_qp, + min_poc, + max_poc, + ) + + def _create_patch_lists( + self, + root_deco_dir: str, + root_orig_dir: str, + global_seq_name: str, + seq_qps: Union[Tuple[str], List[str]], + slice_type: str, + min_slice_qp: int, + max_slice_qp: int, + min_poc: int, + max_poc: int, + ) -> None: + """ + Gets the filenames of the images to be processed (original, reconstruction) and the frame QP + :param root_deco_dir: Directory that contains the decoded video data + :param root_orig_dir: Directory that contains the original video data + :param global_seq_name: Sequence name + :param seq_qps: list of sequence QPs + :param slice_type: Slice type (i.e. 
I, P, B, *) + :param min_slice_qp: Minimum frame QP (inclusive) + :param max_slice_qp: Maximum frame QP (inclusive) + """ + root_deco_dir = check_directory(root_deco_dir) + root_orig_dir = check_directory(root_orig_dir) + + if global_seq_name is None: + deco_seq_dirs = list_dirs(root_deco_dir) + else: + deco_seq_dirs = [root_deco_dir / global_seq_name] + + self._patches = [] + + for deco_seq_dir in deco_seq_dirs: + seq_name = deco_seq_dir.name + seq_width = self._seq_info[seq_name]["width"] + seq_height = self._seq_info[seq_name]["height"] + num_frames = self._seq_info[seq_name]["frames"] + bit_depth = self._seq_info[seq_name]["bit-depth"] + + if len(seq_qps) > 0: + qp_dirs = list_selected_dirs(deco_seq_dir, seq_qps) + else: + qp_dirs = list_dirs(deco_seq_dir) + + for qp_dir in qp_dirs: + frames_info = read_json_file(qp_dir / "coding_info.json") + base_qp = frames_info["base_qp"] + reco_yuv_path = qp_dir / f"{seq_name}_rec_before_dbf.yuv" + boun_yuv_path = qp_dir / f"{seq_name}_bs.yuv" + pred_yuv_path = qp_dir / f"{seq_name}_pred.yuv" + ipb_yuv_path = qp_dir / f"{seq_name}_bpm.yuv" + orig_yuv_path = list((root_orig_dir / seq_name).glob("*.yuv"))[0] + + for curr_poc in range(min_poc, min(num_frames, max_poc)): + curr_slice_type = frames_info["POC"][str(curr_poc)]["slice_type"] + curr_slice_qp = frames_info["POC"][str(curr_poc)]["slice_qp"] + curr_tid = frames_info["POC"][str(curr_poc)]["temporal_layer"] + + if ( + (slice_type != "*" and slice_type != curr_slice_type) + or (curr_slice_qp < min_slice_qp) + or (curr_slice_qp > max_slice_qp) + ): + continue + + # compute position + if self._random_patches: + positions = self._compute_random_positions( + seq_width, seq_height, self._num_patches + ) + else: + positions = self._compute_predetermined_positions( + seq_width, seq_height + ) + + for pos in positions: + patch = PatchInfo( + curr_poc, + seq_width, + seq_height, + reco_yuv_path, + boun_yuv_path, + pred_yuv_path, + ipb_yuv_path, + orig_yuv_path, + curr_slice_qp, + base_qp, + curr_tid, + curr_slice_type, + bit_depth, + pos, + ) + self._patches.append(patch) + + def _compute_random_positions( + self, width: int, height: int, num_patches: int + ) -> List[Position]: + luma_bs = self._block_size + max_x = ( + luma_bs * (width // luma_bs) + if width % luma_bs > 0 + else luma_bs * (width // luma_bs - 1) + ) + max_y = ( + luma_bs * (height // luma_bs) + if height % luma_bs > 0 + else luma_bs * (height // luma_bs - 1) + ) + + x = np.random.randint(self._pad_size, max_x + self._pad_size + 1, num_patches) + y = np.random.randint(self._pad_size, max_y + self._pad_size + 1, num_patches) + + x = (x // 2) * 2 + y = (y // 2) * 2 + positions = [] + for idx in range(num_patches): + positions.append(Position(x[idx], y[idx])) + return positions + + def _compute_predetermined_positions( + self, width: int, height: int + ) -> List[Position]: + luma_bs = self._block_size + max_x = ( + luma_bs * (width // luma_bs) + if width % luma_bs > 0 + else luma_bs * (width // luma_bs - 1) + ) + max_y = ( + luma_bs * (height // luma_bs) + if height % luma_bs > 0 + else luma_bs * (height // luma_bs - 1) + ) + positions = [] + for x in range(self._pad_size, max_x + self._pad_size + 1, luma_bs): + for y in range(self._pad_size, max_y + self._pad_size + 1, luma_bs): + positions.append(Position(x, y)) + return positions + + def _create_input_data( + self, + reco_y: Tensor, + reco_u: Tensor, + reco_v: Tensor, + boun_y: Tensor, + boun_u: Tensor, + boun_v: Tensor, + pred_y: Tensor, + pred_u: Tensor, + pred_v: Tensor, + ipb_y: 
Tensor,
+        ipb_u: Tensor,
+        ipb_v: Tensor,
+        slice_qp: Tensor,
+        base_qp: Tensor,
+        pos_x: int,
+        pos_y: int,
+    ) -> dict:
+        """
+        Creates input data
+        :param reco_y: luma reconstruction image
+        :param reco_u: cb reconstruction image
+        :param reco_v: cr reconstruction image
+        :param boun_y: luma boundary strength image
+        :param boun_u: cb boundary strength image
+        :param boun_v: cr boundary strength image
+        :param pred_y: luma prediction image
+        :param pred_u: cb prediction image
+        :param pred_v: cr prediction image
+        :param ipb_y: luma prediction-mode (IPB) image
+        :param ipb_u: cb prediction-mode image (unused)
+        :param ipb_v: cr prediction-mode image (unused)
+        :param slice_qp: normalised slice QP
+        :param base_qp: normalised base QP
+        :param pos_x: left corner position
+        :param pos_y: top corner position
+        :return: dictionary of input patches
+        """
+        if self._pad_size > 0:
+            reco_y = pad_image(reco_y, self._block_size, self._pad_size)
+            reco_u = pad_image(reco_u, self._block_size // 2, self._pad_size // 2)
+            reco_v = pad_image(reco_v, self._block_size // 2, self._pad_size // 2)
+
+            boun_y = pad_image(boun_y, self._block_size, self._pad_size)
+            boun_u = pad_image(boun_u, self._block_size // 2, self._pad_size // 2)
+            boun_v = pad_image(boun_v, self._block_size // 2, self._pad_size // 2)
+
+            pred_y = pad_image(pred_y, self._block_size, self._pad_size)
+            pred_u = pad_image(pred_u, self._block_size // 2, self._pad_size // 2)
+            pred_v = pad_image(pred_v, self._block_size // 2, self._pad_size // 2)
+
+            ipb_y = pad_image(ipb_y, self._block_size, self._pad_size)
+
+        patch_size = self._block_size + self._pad_size * 2
+
+        actual_y_pos_x = pos_x - self._pad_size
+        actual_y_pos_y = pos_y - self._pad_size
+
+        actual_uv_pos_x = (pos_x - self._pad_size) // 2
+        actual_uv_pos_y = (pos_y - self._pad_size) // 2
+
+        input_data = {}
+        input_data["rec_before_dbf_Y"] = reco_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+        input_data["rec_before_dbf_U"] = reco_u[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["rec_before_dbf_V"] = reco_v[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["pred_Y"] = pred_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+        input_data["pred_U"] = pred_u[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["pred_V"] = pred_v[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["bs_Y"] = boun_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+        input_data["bs_U"] = boun_u[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["bs_V"] = boun_v[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["qp_base"] = base_qp * torch.ones(
+            input_data["rec_before_dbf_Y"].shape
+        )
+        input_data["qp_slice"] = slice_qp * torch.ones(
+            input_data["rec_before_dbf_Y"].shape
+        )
+
+        input_data["ipb_Y"] = ipb_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+
+        return input_data
+
+    def _create_label_data(
+        self, y: Tensor, u: Tensor, v: Tensor, pos_x: int, pos_y: int
+    ) -> dict:
+        """
+        Creates label data
+        :param y: luma image
+        :param u: cb image
+        :param v: cr image
+        :param pos_x: left corner position
+        :param pos_y: top corner position
+        :return: dictionary of label patches and sample masks
+        """
+
+        if self._pad_size > 0:
+            _, height, width = y.shape
+            mod = 
+
+    def _create_label_data(
+        self, y: Tensor, u: Tensor, v: Tensor, pos_x: int, pos_y: int
+    ) -> Tensor:
+        """
+        Creates the label data for a single patch
+        :param y: luma image
+        :param u: cb image
+        :param v: cr image
+        :param pos_x: left corner position
+        :param pos_y: top corner position
+        :return: Dictionary with the label patches and validity masks
+        """
+        # the mask must also exist when no padding is applied
+        mask = torch.ones(y.shape)
+
+        if self._pad_size > 0:
+            _, height, width = y.shape
+            mod = width % self._block_size
+            right_padding = self._block_size - mod if mod > 0 else 0
+            mod = height % self._block_size
+            bottom_padding = self._block_size - mod if mod > 0 else 0
+
+            mask = pad(mask, [0, 0, right_padding, bottom_padding])
+
+            y = pad(y, [0, 0, right_padding, bottom_padding])
+            u = pad(u, [0, 0, right_padding // 2, bottom_padding // 2])
+            v = pad(v, [0, 0, right_padding // 2, bottom_padding // 2])
+
+        actual_pos_x = pos_x - self._pad_size
+        actual_pos_y = pos_y - self._pad_size
+
+        label_data = {}
+        label_data["orig_Y"] = y[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size,
+            actual_pos_x : actual_pos_x + self._block_size,
+        ]
+        label_data["mask_Y"] = mask[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size,
+            actual_pos_x : actual_pos_x + self._block_size,
+        ]
+        actual_pos_x //= 2
+        actual_pos_y //= 2
+        label_data["orig_U"] = u[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size // 2,
+            actual_pos_x : actual_pos_x + self._block_size // 2,
+        ]
+        label_data["orig_V"] = v[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size // 2,
+            actual_pos_x : actual_pos_x + self._block_size // 2,
+        ]
+        mask_UV = torch.nn.functional.pixel_unshuffle(label_data["mask_Y"], 2)
+
+        label_data["mask_U"] = mask_UV[0:1, :, :]
+        label_data["mask_V"] = mask_UV[0:1, :, :]
+
+        return label_data
+
+    def __len__(self) -> int:
+        """
+        Dataset length
+        """
+        return len(self._patches)
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
+        """
+        Returns a single item
+        :param idx: Index
+        :return: Input data, label data, base QP and slice QP
+        """
+        patch = self._patches[idx]
+
+        orig_y, orig_u, orig_v = read_yuv_frame(
+            patch.orig_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            patch.bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        pos_x, pos_y = patch.pos
+
+        reco_y, reco_u, reco_v = read_yuv_frame(
+            patch.reco_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        boun_y, boun_u, boun_v = read_yuv_frame(
+            patch.boun_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        pred_y, pred_u, pred_v = read_yuv_frame(
+            patch.pred_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        ipb_y, ipb_u, ipb_v = read_yuv_frame(
+            patch.ipb_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            2.0,
+            True,
+            patch.poc,
+        )
+
+        slice_qp = float(patch.slice_qp) / 64.0
+        base_qp = float(patch.base_qp) / 64.0
+
+        input_data = self._create_input_data(
+            reco_y,
+            reco_u,
+            reco_v,
+            boun_y,
+            boun_u,
+            boun_v,
+            pred_y,
+            pred_u,
+            pred_v,
+            ipb_y,
+            ipb_u,
+            ipb_v,
+            slice_qp,
+            base_qp,
+            pos_x,
+            pos_y,
+        )
+        label_data = self._create_label_data(orig_y, orig_u, orig_v, pos_x, pos_y)
+
+        return input_data, label_data, base_qp, slice_qp
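A hedged sketch of how this dataset is consumed, assuming `dataset` is an instance of the dataset class defined earlier in this file and an arbitrary batch size:

    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    for input_data, label_data, base_qp, slice_qp in loader:
        # input_data and label_data are dictionaries of tensors keyed by
        # plane, e.g. input_data["rec_before_dbf_Y"] and label_data["orig_Y"];
        # the default collate function batches each entry along dim 0.
        pass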
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/file_system.py b/training/training_scripts/NN_Adaptive_Filtering/util/file_system.py
new file mode 100644
index 0000000000..3323f2fec1
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/file_system.py
@@ -0,0 +1,292 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import json
+import sys
+from pathlib import Path
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+
+
+def check_directory(input_path: Union[Path, str]) -> Path:
+    """
+    Checks whether the given path exists and corresponds to a directory
+    :param input_path: directory path
+    :return: Path obj
+    """
+    dir_path = Path(input_path) if isinstance(input_path, str) else input_path
+    if dir_path.exists() and dir_path.is_dir():
+        return dir_path
+
+    print(f"{input_path} is not a directory")
+    sys.exit(-1)
+
+
+def check_file(input_path: Union[Path, str]) -> Path:
+    """
+    Checks whether the given path exists and corresponds to a file
+    :param input_path: file path
+    :return: Path obj
+    """
+    file_path = Path(input_path) if isinstance(input_path, str) else input_path
+    if file_path.exists() and file_path.is_file():
+        return file_path
+
+    print(f"{input_path} is not a file")
+    sys.exit(-1)
+
+
+def list_dirs(input_dir: Path) -> List[Path]:
+    """
+    Lists the subdirectories of the given directory
+    :param input_dir: Input directory
+    :return: list of directories
+    """
+    return [f for f in input_dir.iterdir() if f.is_dir()]
+
+
+def create_directory(input_path: Union[Path, str]) -> Path:
+    """
+    Creates a directory
+    :param input_path: Input path
+    :return: Path obj
+    """
+    dir_path = Path(input_path) if isinstance(input_path, str) else input_path
+    dir_path.mkdir(parents=True, exist_ok=True)
+    return dir_path
+
+
+def list_selected_dirs(
+    input_dir: Path, pattern: Union[str, Tuple[str], List[str]]
+) -> List[Path]:
+    """
+    Lists the subdirectories that contain a given text in their names
+    :param input_dir: Input directory
+    :param pattern: pattern to match
+    :return: list of directories
+    """
+    if isinstance(pattern, str):
+        return [f for f in input_dir.iterdir() if f.is_dir() and pattern in f.name]
+
+    return [f for f in input_dir.iterdir() if f.is_dir() and f.name in pattern]
+
+
+def read_json_file(json_path: Union[Path, str]) -> Dict:
+    """
+    Reads JSON file
+    :param json_path: Absolute path to the JSON file
:return: Dictionary containing JSON file data + """ + file_path = check_file(json_path) + with open(file_path, "r") as stream: + config = json.load(stream) + return config + + +def write_json_file(content: Dict, json_path: Union[Path, str]) -> None: + """ + Writes a dictionary to a JSON file + :param content: Dictionary to be saved + :param json_path: Absolute path to the JSON file + """ + file_path = Path(json_path) if isinstance(json_path, str) else json_path + assert ( + file_path.parent.exists() and file_path.parent.is_dir() + ), f"The parent directory {file_path.parent} does not exist" + with open(file_path, "w") as stream: + json.dump(content, stream, sort_keys=True, indent=4) + + +def create_vtm_config_file( + cfg_file: Path, filename: Path, width: int, height: int, fps: int, num_frames: int +) -> None: + """ + Creates the sequence config file for VTM encoding + :param cfg_file: Output file name + :param filename: YUV file name + :param width: Width of the YUV + :param height: Height of the YUV + :param fps: Frame rate of the YUV + :param num_frames: Number of frames to be encoded + """ + with open(cfg_file, "w") as stream: + stream.write(f"InputFile: {filename}\n") + stream.write(f"SourceWidth: {width}\n") + stream.write(f"SourceHeight: {height}\n") + stream.write("InputBitDepth: 10\n") + stream.write("InputChromaFormat: 420\n") + stream.write(f"FrameRate: {fps}\n") + stream.write("FrameSkip: 0\n") + stream.write(f"FramesToBeEncoded: {num_frames}\n") + stream.write("Level: 5.1\n") + + +def read_yuv_frame( + file_path: Path, + width: int, + height: int, + bit_depth: int, + frame_idx: int, + frame_skip: int = 0, + temporal_subsample_ratio: int = 1, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Reads YUV frame + :param file_path: Absolute file path + :param width: Luma width + :param height: Luma height + :param bit_depth: Bit-depth + :param frame_idx: Frame index + :param frame_skip: Number of frames to skip + :param temporal_subsample_ratio: Temporal subsample ratio (see VTM config) + :return: Y, U, V frames + """ + word = 1 if bit_depth == 8 else 2 + pix_type = np.uint8 if bit_depth == 8 else np.uint16 + + frame_pos = (frame_idx + frame_skip) * temporal_subsample_ratio + frame_size = width * height * word * 3 // 2 + + uv_width = width // 2 + uv_height = height // 2 + + with open(file_path, "rb") as stream: + stream.seek(frame_pos * frame_size) + + y = np.frombuffer(stream.read(width * height * word), pix_type) + u = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type) + v = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type) + + y = np.reshape(y, (height, width)) + u = np.reshape(u, (height // 2, width // 2)) + v = np.reshape(v, (height // 2, width // 2)) + + if bit_depth == 8: + pix_type = np.uint16 + y = y.astype(pix_type) * 4 + u = u.astype(pix_type) * 4 + v = v.astype(pix_type) * 4 + + return y, u, v + + +def read_yuv_patch( + file_path: Path, + width: int, + height: int, + bit_depth: int, + frame_idx: int, + luma_pos: int, + luma_patch_size: int, + uv_pos: int, + uv_patch_size: int, + frame_skip: int = 0, + temporal_subsample_ratio: int = 1, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Reads YUV patch + :param file_path: Absolute file path + :param width: Luma frame width + :param height: Luma frame height + :param bit_depth: Bit-depth + :param frame_idx: Frame index + :param luma_pos: Top-left corner position of the luma patch + :param luma_patch_size: Luma patch size (one side) + :param uv_pos: Top-left corner position of the 
chroma patches (an object with x and y attributes, like luma_pos)
+    :param uv_patch_size: Chroma patch size (one side)
+    :param frame_skip: Number of frames to skip
+    :param temporal_subsample_ratio: Temporal subsample ratio (see VTM config)
+    :return: Y, U, V patches
+    """
+    word = 1 if bit_depth == 8 else 2
+    pix_type = np.uint8 if bit_depth == 8 else np.uint16
+
+    y_frame_size = width * height
+    uv_frame_size = width // 2 * height // 2
+    full_frame_size = y_frame_size + 2 * uv_frame_size
+
+    frame_pos = (frame_idx + frame_skip) * temporal_subsample_ratio
+
+    uv_width = width // 2
+
+    y = np.zeros((luma_patch_size, luma_patch_size), dtype=np.uint16)
+    u = np.zeros((uv_patch_size, uv_patch_size), dtype=np.uint16)
+    v = np.zeros((uv_patch_size, uv_patch_size), dtype=np.uint16)
+
+    with open(file_path, "rb") as stream:
+        patch_start = word * (
+            frame_pos * full_frame_size + luma_pos.y * width + luma_pos.x
+        )
+        stream.seek(patch_start, 0)
+
+        for i in range(luma_patch_size):
+            row = np.frombuffer(stream.read(luma_patch_size * word), pix_type)
+            y[i, :] = row.astype(np.uint16) * 4 if bit_depth == 8 else row
+            stream.seek(word * (width - luma_patch_size), 1)
+
+        patch_start = word * (
+            frame_pos * full_frame_size + y_frame_size + uv_pos.y * uv_width + uv_pos.x
+        )
+        stream.seek(patch_start, 0)
+
+        for i in range(uv_patch_size):
+            row = np.frombuffer(stream.read(uv_patch_size * word), pix_type)
+            u[i, :] = row.astype(np.uint16) * 4 if bit_depth == 8 else row
+            stream.seek(word * (uv_width - uv_patch_size), 1)
+
+        patch_start = word * (
+            frame_pos * full_frame_size
+            + y_frame_size
+            + uv_frame_size
+            + uv_pos.y * uv_width
+            + uv_pos.x
+        )
+        stream.seek(patch_start, 0)
+
+        for i in range(uv_patch_size):
+            row = np.frombuffer(stream.read(uv_patch_size * word), pix_type)
+            v[i, :] = row.astype(np.uint16) * 4 if bit_depth == 8 else row
+            stream.seek(word * (uv_width - uv_patch_size), 1)
+
+    return y, u, v
+
+
+def get_models(root_dir: Path) -> List[Path]:
+    sub_dirs = list_dirs(root_dir)
+    return [
+        v / "models/best_model.pt"
+        for v in sub_dirs
+        if (v / "models/best_model.pt").is_file()
+    ]
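Although luma_pos and uv_pos are annotated as int, the reads above access their x and y attributes, so any small coordinate object works. A hedged sketch (the Pos namedtuple and the file name are stand-ins, not part of this patch):

    from collections import namedtuple
    from pathlib import Path

    Pos = namedtuple("Pos", ["x", "y"])
    y, u, v = read_yuv_patch(
        Path("reco.yuv"), width=1920, height=1080, bit_depth=10, frame_idx=0,
        luma_pos=Pos(128, 64), luma_patch_size=144,
        uv_pos=Pos(64, 32), uv_patch_size=72,
    )
    # y is a 144x144 uint16 array; u and v are 72x72 (4:2:0 chroma).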
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/image_ops.py b/training/training_scripts/NN_Adaptive_Filtering/util/image_ops.py
new file mode 100644
index 0000000000..595167e209
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/image_ops.py
@@ -0,0 +1,149 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from util import DEVICE
+
+
+def normalise_image(in_img: np.ndarray, normalize_factor: float) -> Tensor:
+    """
+    Normalises an image
+    :param in_img: Input image
+    :param normalize_factor: Factor for normalisation
+    :return: Normalised image
+    """
+    out_img = torch.from_numpy(in_img.astype(np.float32)).to(DEVICE)
+    out_img = out_img.div(normalize_factor)
+    out_img = out_img.unsqueeze(0)
+    return out_img
+
+
+def read_yuv_frame(
+    file_path: Path,
+    width: int,
+    height: int,
+    bit_depth: int,
+    normalize_factor: float,
+    ipb: bool,
+    frame_idx: int,
+    frame_skip: int = 0,
+    temporal_subsample_ratio: int = 1,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Reads YUV frame
+    :param file_path: Absolute file path
+    :param width: Luma width
+    :param height: Luma height
+    :param bit_depth: Bit-depth
+    :param normalize_factor: Factor for normalisation
+    :param ipb: Whether the file stores prediction-mode (IPB) data
+    :param frame_idx: Frame index
+    :param frame_skip: Number of frames to skip
+    :param temporal_subsample_ratio: Temporal subsample ratio (see VTM config)
+    :return: Y, U, V frames
+    """
+    word = 1 if bit_depth == 8 else 2
+    pix_type = np.uint8 if bit_depth == 8 else np.uint16
+
+    frame_pos = (frame_idx + frame_skip) * temporal_subsample_ratio
+    frame_size = width * height * word * 3 // 2
+
+    uv_width = width // 2
+    uv_height = height // 2
+
+    with open(file_path, "rb") as stream:
+        stream.seek(frame_pos * frame_size)
+
+        y = np.frombuffer(stream.read(width * height * word), pix_type)
+        u = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type)
+        v = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type)
+
+    y = np.reshape(y, (height, width))
+    u = np.reshape(u, (height // 2, width // 2))
+    v = np.reshape(v, (height // 2, width // 2))
+
+    if bit_depth == 8:
+        pix_type = np.uint16
+        y = y.astype(pix_type) * 4
+        u = u.astype(pix_type) * 4
+        v = v.astype(pix_type) * 4
+
+    if ipb:
+        # after the right shift: Ipred=0, IBC=1, uni-pred=2, bi-pred=3
+        y = np.right_shift(y, 1)
+        u = np.right_shift(u, 1)
+        v = np.right_shift(v, 1)
+        # Ipred/IBC=0, uni-pred=1, bi-pred=2
+        y = (y == 0) + y - 1
+        u = (u == 0) + u - 1
+        v = (v == 0) + v - 1
+
+    y = normalise_image(y, normalize_factor)
+    u = normalise_image(u, normalize_factor)
+    v = normalise_image(v, normalize_factor)
+
+    return y, u, v
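A quick worked check of the remapping above (illustrative values only):

    import numpy as np

    v = np.array([0, 1, 2, 3])   # after the shift: intra, IBC, uni-, bi-pred
    print((v == 0) + v - 1)      # -> [0 0 1 2]: intra and IBC collapse to 0

With normalize_factor 2.0, as the dataset uses for the IPB plane, the three classes land on 0.0, 0.5 and 1.0.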
+
+
+def pad_image(in_image: Tensor, block_size: int, border_size: int) -> Tensor:
+    """
+    Applies padding to the input image
+    :param in_image: Input image
+    :param block_size: Size of the actual block (final output size)
+    :param border_size: Number of samples added to each side of the block size
+    :return: Padded image
+    """
+    _, height, width = in_image.shape
+
+    mod = width % block_size
+    if mod > 0:
+        right_border = border_size + block_size - mod
+    else:
+        right_border = border_size
+
+    mod = height % block_size
+    if mod > 0:
+        bottom_border = border_size + block_size - mod
+    else:
+        bottom_border = border_size
+
+    transform = torch.nn.ConstantPad2d(
+        (border_size, right_border, border_size, bottom_border), 0.0
+    )
+    out_image = transform(in_image)
+    return out_image
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/logger.py b/training/training_scripts/NN_Adaptive_Filtering/util/logger.py
new file mode 100644
index 0000000000..aab5c5e9dd
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/logger.py
@@ -0,0 +1,202 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+""" + +import sys +import time +from datetime import datetime, timedelta +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from util import COLOUR_LABEL, TIME_FORMAT, Colour + + +def compute_elapsed_time(start_time: str) -> Tuple[str, timedelta]: + end_time = time.strftime(TIME_FORMAT) + elapsed_time = datetime.strptime(end_time, TIME_FORMAT) - datetime.strptime( + start_time, TIME_FORMAT + ) + return end_time, elapsed_time + + +class Logger: + def __init__(self, log_dir: str, stop_patience: int): + log_dir = Path(log_dir) + log_dir.mkdir(exist_ok=True, parents=True) + + self._models_dir = log_dir / "models" + self._models_dir.mkdir(exist_ok=True, parents=True) + self._checkpoint_dir = self._models_dir / "checkpoints" + self._checkpoint_dir.mkdir(exist_ok=True, parents=True) + + self._time_file = log_dir / "time.csv" + self._write_idx = 0 + self._train_writer = SummaryWriter(str(log_dir / "train")) + + self._start_time = time.strftime(TIME_FORMAT) + self._last_time = self._start_time + + self._best_loss = sys.float_info.max + self._num_no_improvements = 0 + self._stop_patience = stop_patience + + def log_metrics( + self, + epoch: int, + batch: int, + loss: np.ndarray, + psnr: np.ndarray, + delta_psnr_wrt_vtm: np.ndarray, + delta_psnr_wrt_base: np.ndarray, + positive_delta_psnr_wrt_base: np.ndarray, + portion_positive_delta_psnr: np.ndarray, + ) -> None: + _, elapsed_time = compute_elapsed_time(self._last_time) + + metrics_str = "" + for c in range(Colour.YCbCr + 1): + label = COLOUR_LABEL[c] + + if loss is not None: + self._train_writer.add_scalar(f"Loss_{label}", loss[c], epoch) + metrics_str += "Loss_{}: {:.6f}; ".format(label, loss[c]) + if psnr is not None: + self._train_writer.add_scalar(f"PSNR_{label}", psnr[c], epoch) + metrics_str += "PSNR_{}: {:.6f}; ".format(label, psnr[c]) + if delta_psnr_wrt_vtm is not None: + self._train_writer.add_scalar( + f"dPSNR_{label}_wrt_vtm", delta_psnr_wrt_vtm[c], epoch + ) + metrics_str += "dPSNR_{}_wrt_vtm: {:.6f}; ".format( + label, delta_psnr_wrt_vtm[c] + ) + if delta_psnr_wrt_base is not None: + self._train_writer.add_scalar( + f"dPSNR_{label}_wrt_base", delta_psnr_wrt_base[c], epoch + ) + metrics_str += "dPSNR_{}_wrt_base: {:.6f}; ".format( + label, delta_psnr_wrt_base[c] + ) + if positive_delta_psnr_wrt_base is not None: + self._train_writer.add_scalar( + f"dPSNR_positive_{label}", positive_delta_psnr_wrt_base[c], epoch + ) + metrics_str += "dPSNR_positive_{}: {:.6f}; ".format( + label, positive_delta_psnr_wrt_base[c] + ) + if portion_positive_delta_psnr is not None: + self._train_writer.add_scalar( + f"percentage_positive_dpsnr_{label}", + portion_positive_delta_psnr[c], + epoch, + ) + metrics_str += "percentage_positive_dpsnr_{}: {:.6f}; ".format( + label, portion_positive_delta_psnr[c] + ) + + # print( + # "Epoch: {:03d}; Batch: {:06d}; {}Time: {}".format( + # epoch, batch, metrics_str, elapsed_time + # ) + # ) + self._last_time = time.strftime(TIME_FORMAT) + + def save_model( + self, + epoch: int, + loss: float, + model: torch.nn.Module, + optimiser: torch.optim.Optimizer, + lr_scheduler, + save_interval: int = 1, + ) -> bool: + # x == x to check nan + if loss == loss and loss < self._best_loss: + self._best_loss = loss + self._num_no_improvements = 0 + print(f"Saving the best model at epoch {epoch}...") + torch.save( + model.module.state_dict(), str(self._models_dir / "best_model.pt") + ) + else: + self._num_no_improvements += 1 + + 
stop_train = self._num_no_improvements > self._stop_patience
+
+        if (epoch + 1) % save_interval != 0:
+            return stop_train
+
+        print(f"Saving checkpoint at epoch {epoch}...")
+        torch.save(
+            model.module.state_dict(),
+            str(self._checkpoint_dir / "model_{:03d}.pt".format(epoch)),
+        )
+        torch.save(
+            optimiser.state_dict(),
+            str(self._checkpoint_dir / "model_{:03d}_optimiser.pt".format(epoch)),
+        )
+        torch.save(
+            lr_scheduler.state_dict(),
+            str(self._checkpoint_dir / "model_{:03d}_lr_scheduler.pt".format(epoch)),
+        )
+
+        return stop_train
+
+    def reset(self) -> None:
+        self._best_loss = sys.float_info.max
+        self._num_no_improvements = 0
+
+    def save_training_time(self) -> None:
+        end_time, elapsed_time = compute_elapsed_time(self._start_time)
+        with open(self._time_file, "w") as stream:
+            stream.write("start,end,duration\n")
+            stream.write(f"{self._start_time},{end_time},{elapsed_time}\n")
+
+    def save_param_names(self, param_names: List[str]) -> None:
+        with open(self._checkpoint_dir / "param_names.txt", "w") as file:
+            names_to_write = "\n".join(param_names)
+            file.write(names_to_write)
+
+    def load_param_names(self) -> List[str]:
+        param_names = []
+        try:
+            with open(self._checkpoint_dir / "param_names.txt", "r") as file:
+                # splitlines drops the newlines that readlines would keep
+                param_names = file.read().splitlines()
+        except Exception:
+            print("No file with the names of the trainable parameters")
+
+        return param_names
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/metrics.py b/training/training_scripts/NN_Adaptive_Filtering/util/metrics.py
new file mode 100644
index 0000000000..8cecabb3cf
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/metrics.py
@@ -0,0 +1,164 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+""" + +import math +from typing import Dict, Tuple + +import torch +from torch import Tensor + +from util import COLOUR_WEIGHT, Colour + + +def compute_sample_loss_psnr( + ground_truth: Tensor, prediction: Tensor, loss_func +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes the MSE and PSNR per sample + :param ground_truth: Ground-truth + :param prediction: Prediction + :param loss_func: Loss function + :return: MSE, PSNR and mask that indicates what PSNR values are valid (not inf nor nan) + """ + + label, loss = loss_func(ground_truth, prediction) + + if label == "MSE": + psnr = 10.0 * (torch.log(1.0 / loss) / math.log(10.0)) + else: + mse = torch.mean((ground_truth - prediction) ** 2, dim=(1, 2, 3)) + psnr = 10.0 * (torch.log(1.0 / mse) / math.log(10.0)) + + psnr, mask = zero_out_inf_nan(psnr) + return loss, psnr, mask + + +def zero_out_negative(in_tensor: Tensor) -> Tuple[Tensor, Tensor]: + """ + Zeroes out values in the input tensor that are negative + :param in_tensor: Input tensor + :return: Tensor with zero on those positions where there was previously a negative value. Also a mask that + indicates what are the valid values in the tensor + """ + out_tensor = torch.where(in_tensor < 0.0, 0.0, in_tensor) + mask = torch.where(in_tensor > 0.0, 1.0, 0.0) + return out_tensor, mask + + +def zero_out_inf_nan(in_tensor: Tensor) -> Tuple[Tensor, Tensor]: + """ + Zeroes out values in the input tensor that are either inf or nan + :param in_tensor: Input tensor + :return: Tensor with zero on those positions where there was previously an inf or nan value. Also a mask that + indicates what are the valid values in the tensor + """ + condition = torch.logical_or(torch.isinf(in_tensor), torch.isnan(in_tensor)) + out_tensor = torch.where(condition, 0.0, in_tensor) + mask = torch.where(condition, 0.0, 1.0) + return out_tensor, mask + + +def compute_loss_psnr( + ground_truth: Dict, + prediction_Y: Tensor, + prediction_U: Tensor, + prediction_V: Tensor, + loss_func, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes both the Loss and PSNR colour- and sample-wise + :param ground_truth: Ground-truth + :param prediction: Prediction + :param loss_func: Loss function + :return: MSE, PSNR and a mask. 
+
+
+def compute_loss_psnr(
+    ground_truth: Dict,
+    prediction_Y: Tensor,
+    prediction_U: Tensor,
+    prediction_V: Tensor,
+    loss_func,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Computes both the loss and PSNR colour- and sample-wise
+    :param ground_truth: Ground-truth patches and validity masks
+    :param prediction_Y: Luma prediction
+    :param prediction_U: Cb prediction
+    :param prediction_V: Cr prediction
+    :param loss_func: Loss function
+    :return: Loss, PSNR and a mask. The mask indicates for which samples the PSNR is valid (not inf nor nan)
+    """
+
+    y_orig = ground_truth["orig_Y"]
+    cb_orig = ground_truth["orig_U"]
+    cr_orig = ground_truth["orig_V"]
+
+    y_pred = torch.multiply(prediction_Y, ground_truth["mask_Y"])
+    cb_pred = torch.multiply(prediction_U, ground_truth["mask_U"])
+    cr_pred = torch.multiply(prediction_V, ground_truth["mask_V"])
+
+    loss = [0.0] * (Colour.YCbCr + 1)
+    psnr = [0.0] * (Colour.YCbCr + 1)
+    mask = [0.0] * (Colour.YCbCr + 1)
+
+    loss[Colour.Y], psnr[Colour.Y], mask[Colour.Y] = compute_sample_loss_psnr(
+        y_pred, y_orig, loss_func
+    )
+    loss[Colour.Cb], psnr[Colour.Cb], mask[Colour.Cb] = compute_sample_loss_psnr(
+        cb_pred, cb_orig, loss_func
+    )
+    loss[Colour.Cr], psnr[Colour.Cr], mask[Colour.Cr] = compute_sample_loss_psnr(
+        cr_pred, cr_orig, loss_func
+    )
+
+    y_weight = COLOUR_WEIGHT[Colour.Y]
+    cb_weight = COLOUR_WEIGHT[Colour.Cb]
+    cr_weight = COLOUR_WEIGHT[Colour.Cr]
+
+    loss[Colour.YCbCr] = (
+        y_weight * loss[Colour.Y]
+        + cb_weight * loss[Colour.Cb]
+        + cr_weight * loss[Colour.Cr]
+    )
+    psnr[Colour.YCbCr] = (
+        y_weight * psnr[Colour.Y]
+        + cb_weight * psnr[Colour.Cb]
+        + cr_weight * psnr[Colour.Cr]
+    )
+
+    loss = torch.stack(loss)
+    psnr = torch.stack(psnr)
+
+    mask[Colour.YCbCr] = mask[Colour.Y] * mask[Colour.Cb] * mask[Colour.Cr]
+    mask = torch.stack(mask)
+
+    return loss, psnr, mask
+
+
+def mse_loss(orig, pred):
+    return "MSE", torch.mean((orig - pred) ** 2, dim=(1, 2, 3))
+
+
+def mae_loss(orig, pred):
+    return "MAE", torch.mean(torch.abs(orig - pred), dim=(1, 2, 3))
+
+
+LOSS_OP = {
+    "mse": mse_loss,
+    "mae": mae_loss,
+}
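A hedged sketch of how these pieces combine in a training step; the loss key "mse" and the prediction tensors are assumptions, not taken from this patch:

    loss_func = LOSS_OP["mse"]
    loss, psnr, mask = compute_loss_psnr(
        label_data, pred_y, pred_u, pred_v, loss_func
    )
    # loss[Colour.YCbCr] is the COLOUR_WEIGHT-weighted sum of the per-plane
    # losses; mask[Colour.YCbCr] is 1 only where all three plane PSNRs are valid.
    batch_loss = loss[Colour.YCbCr].mean()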
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/regex.py b/training/training_scripts/NN_Adaptive_Filtering/util/regex.py
new file mode 100644
index 0000000000..2fdeeddab0
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/regex.py
@@ -0,0 +1,127 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import re
+from typing import Tuple
+
+# regular expression to identify the segments
+ra_segment_rgx = re.compile(r"p(\d+)")
+
+
+def get_regex_groups_for_pattern(
+    text: str, pattern: str, expected_number_of_groups: int
+) -> re.Match:
+    """
+    Searches the given text for a regular expression
+    :param text: Input text
+    :param pattern: Regular expression
+    :param expected_number_of_groups: number of groups to extract
+    :return: match
+    """
+    match = re.search(pattern, text)
+    if match is None or len(match.groups()) != expected_number_of_groups:
+        raise ValueError(f"Couldn't find pattern {pattern} in '{text}'")
+    return match
+
+
+def look_for_string_pattern_in_text(text: str, pattern: str) -> str:
+    """
+    Looks for a pattern in a text
+    :param text: Input text
+    :param pattern: Pattern
+    :return: First string in the text that matches the pattern, or "0" if there is no match
+    """
+    matches = re.findall(pattern, text)
+    if len(matches) > 0:
+        return matches[0]
+    return "0"
+
+
+def get_frame_info_from_decoder_log_line(text: str) -> Tuple[int, int, str, int]:
+    """
+    Extracts video frame information from the given text
+    :param text: Input text
+    :return: POC, temporal layer, slice type and QP
+    """
+    pattern = r"POC\s+(\d+)\s+LId:\s+\d+\s+TId:\s+(\d+)\s+\(\s+.*,\s+([I|P|B])-\w+,\s+QP\s+(\d+)"
+    match = get_regex_groups_for_pattern(text, pattern, 4)
+    return int(match.group(1)), int(match.group(2)), match.group(3), int(match.group(4))
+
+
+def get_data_from_encoder_log(
+    text: str,
+) -> Tuple[str, int, int, float, float, float, float, float, float]:
+    """
+    Extracts the following data from a line of the encoder log:
+    slice type, qp, bits, psnr yuv (x3), mssim yuv (x3)
+    :param text: file line
+    :return: Slice type, slice QP, bits, PSNR Y, PSNR U, PSNR V, MSSIM Y, MSSIM U, MSSIM V
+    """
+    pattern = r"([I|B])-SLICE\W\s+QP\s+(\d+)\s+\)\s+(\d+)\s+bits\s+\[Y\s+(\d+.\d+)\s+dB\s+U\s+(\d+.\d+)\s+dB\s+V\s+(\d+.\d+)\s+dB\]\s+\[\w+\s+\w+\s+\w+\s+\w+\s+\w+\s+\w+\]\s+\[MS-SSIM\s+Y\s+(\d\.\d+)\s+U\s+(\d+\.\d+)\s+V\s+(\d+\.\d+)"
+    result = get_regex_groups_for_pattern(text, pattern, 9)
+    return (
+        result.group(1),
+        int(result.group(2)),
+        int(result.group(3)),
+        float(result.group(4)),
+        float(result.group(5)),
+        float(result.group(6)),
+        float(result.group(7)),
+        float(result.group(8)),
+        float(result.group(9)),
+    )
+
+
+def get_regex_groups_from_org_zip_file(file_stem: str) -> re.Match:
+    file_pattern = r"((\w+_\w+_*\w*)_(\d+)_part(\d+))"
+    return get_regex_groups_for_pattern(file_stem, file_pattern, 4)
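An illustrative, hand-written decoder-log line of the shape the POC pattern above expects (abbreviated; not copied from an actual VTM log):

    line = "POC    8 LId:  0 TId: 2 ( CRA, B-SLICE, QP 33 )"
    poc, tid, slice_type, qp = get_frame_info_from_decoder_log_line(line)
    # -> (8, 2, "B", 33)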
+
+
+def get_bitrate_and_mse_from_encoder_log(
+    text: str,
+) -> Tuple[float, float, float, float, float]:
+    """
+    Extracts the bitrate and MSEs from the summary line in the encoder log
+    :param text: file line
+    :return: Bit-rate followed by four MSE values
+    """
+    pattern = r"(\d+\.\d+)\s+\d+\.\d+\s+\d+\.\d+\s+\d+\.\d+\s+\d+\.\d+\s+\w+\s+\w+\s+\w+\s+\d+\.\d+\s+\d+\.\d+\s+\d+\.\d+\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)"
+    result = get_regex_groups_for_pattern(text, pattern, 5)
+    return (
+        float(result.group(1)),
+        float(result.group(2)),
+        float(result.group(3)),
+        float(result.group(4)),
+        float(result.group(5)),
+    )
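The part-file stems, presumably produced by the data-preparation step, are parsed with get_regex_groups_from_org_zip_file; an illustrative stem (the sequence name is made up):

    m = get_regex_groups_from_org_zip_file("BQTerrace_1920x1080_60_part2")
    # m.group(1) -> "BQTerrace_1920x1080_60_part2"   (full stem)
    # m.group(2) -> "BQTerrace_1920x1080", m.group(3) -> "60", m.group(4) -> "2"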
+""" + +import argparse +import os +import sys +from pathlib import Path +from tempfile import NamedTemporaryFile + +import nnc_core +import torch +from framework.mpeg_applications.utils import icnn_tools + +from conversion.sadl2torch import create_model, map_params +from conversion.torch2sadl import torch_to_onnx +from util.file_system import read_json_file, write_json_file +from wu_encoding import get_model_data_info + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser("Decoding weight-update") + parser.add_argument("--arch", type=str, required=True, help="Model architecture") + parser.add_argument( + "--base_model", type=str, required=True, help="Path to base model" + ) + parser.add_argument( + "--model_arch", type=str, default="lop2", help="Name of the model architecture" + ) + parser.add_argument( + "--nnr_bitstream", type=str, required=True, help="Path to NNR bitstream" + ) + parser.add_argument( + "--reco_model", type=str, required=True, help="Path to reconstructed model" + ) + return parser.parse_args() + + +def decode_weight_udpate_and_2sadl(): + if "lop2" != args.arch: + print(f"{args.arch} not supported") + + base_model_config = read_json_file("models/models.json")["lop2"] + nnr_bitstream_file = Path(args.nnr_bitstream) + + with NamedTemporaryFile() as dq_torch_file, NamedTemporaryFile() as reco_torch_file, NamedTemporaryFile() as onnx_file, NamedTemporaryFile() as sadl_f_file, NamedTemporaryFile() as quantiser_file, NamedTemporaryFile() as base_quantizers_file, NamedTemporaryFile() as map_orig_params_to_new_file: + dq_model, base_model_quantizers = create_model( + args.base_model, base_model_config + ) + torch.save(dq_model.state_dict(), dq_torch_file.name) + write_json_file(base_model_quantizers, base_quantizers_file.name) + + base_model, base_model_params = get_model_data_info( + dq_torch_file.name, base_model_config, False + ) + + map_orig_params_to_new = map_params(base_model_config) + write_json_file(map_orig_params_to_new, map_orig_params_to_new_file.name) + + # Decode + diff_dec_approx_params = {"parameters": {}, "put_node_depth": {}} + diff_rec_approx_data, bs_size, dec_model_info, res = nnc_core.dec_and_rec( + nnr_bitstream_file, + True, + tml=None, + epoch=1, + client=0, + parameter_index=base_model.model_info["parameter_index"], + parameter_dimensions=base_model.model_info["parameter_dimensions"], + dec_approx_param_base=diff_dec_approx_params, + update_base_param=True, + ) + + # Restore torch model + restored_params = icnn_tools.add( + base_model_params, diff_rec_approx_data["parameters"] + ) + base_model.restore_and_save(restored_params, reco_torch_file.name) + + # Torch to Onnx + torch_to_onnx(reco_torch_file.name, onnx_file.name, base_model_config) + + # onnx to SADL float + command = ( + f"python conversion/onnx2sadl_modified.py --input_onnx {onnx_file.name} --output {sadl_f_file.name} " + f"--out_quant {quantiser_file.name} --base_quantizers {base_quantizers_file.name} " + f"--orig_params_to_new {map_orig_params_to_new_file.name}" + ) + if os.system(command): + sys.exit(-1) + + # convert float SADL to int16 + command_int = f"tail -n 1 {quantiser_file.name} | ./naive_quantization {sadl_f_file.name} {args.reco_model}" + if os.system(command_int): + sys.exit(-1) + + +if __name__ == "__main__": + args = parse_arguments() + decode_weight_udpate_and_2sadl() + sys.exit(0) diff --git a/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py b/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py new file mode 100644 index 
diff --git a/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py b/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py
new file mode 100644
index 0000000000..b107ef1695
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py
@@ -0,0 +1,253 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+""" + +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Tuple, Union + +import nnc_core +import numpy as np +from framework.mpeg_applications.utils import icnn_tools + +import config +from conversion.torch2sadl import onnx_to_sadl, torch_to_onnx +from models.model_lop2_nnr import PytorchModel +from util.regex import get_regex_groups_from_org_zip_file + + +def create_enc_info() -> Dict: + param_opt = True + temporal_context = False + + info = { + "cabac_unary_length_minus1": 10, + "param_opt_flag": param_opt, + "partial_data_counter": 0, + } + + if config.PUT_SYNTAX(): + info["node_id_present_flag"] = 1 + info["device_id"] = 0 + info["parent_node_id_present_flag"] = 1 + info["parent_node_id_type"] = nnc_core.hls.ParentNodeIdType.ICNN_NDU_ID + info["parent_device_id"] = 0 + + if config.TEMPORAL_CONTEXT(): + info["temporal_context_modeling_flag"] = 1 if temporal_context else 0 + + return info + + +def get_model_data_info( + model_path: Union[Path, str], + model_config: Dict, + is_overf_model: bool, + block_size: int = 0, + pad_size: int = 0, +) -> Tuple[PytorchModel, Dict]: + model = PytorchModel( + model_config, + block_size, + pad_size, + ) + path = model_path if isinstance(model_path, str) else str(model_path) + model_params = model.load_model(path, is_overf_model) + return model, model_params + + +def compress_one_model_and_2sadl( + overfitted_model_dir: Path, + base_model_file: Path, + model_config: Dict, + base_quantizers: Path, + map_orig_params_to_new, + nnr_models_dir: Path, + onnx_models_dir: Path, + sadl_float_dir: Path, + sadl_int_dir: Path, + quantizers_dir: Path, + train_loader: Any, + block_size: int, + pad_size: int, +): + enc_info = create_enc_info() + + qp_density = 2 + + approx_method = "uniform" + nnr_qp = -40 + opt_qp = True + disable_dq = True + lambda_scale = 0.0 + cb_size_ratio = 5000 + q_mse = 0.00001 + inc_bn_folding = False + + qp_step = 1 + bias = 0.005 + + overfitted_model_dir = Path(overfitted_model_dir) + nnr_models_dir = Path(nnr_models_dir) + onnx_models_dir = Path(onnx_models_dir) + sadl_float_dir = Path(sadl_float_dir) + sadl_int_dir = Path(sadl_int_dir) + quantizers_dir = Path(quantizers_dir) + + overfitted_model_file = overfitted_model_dir / "models" / "best_model.pt" + + matched_groups = get_regex_groups_from_org_zip_file(overfitted_model_dir.name) + output_stem = matched_groups.group(1) + + print("-----------------------------") + print(output_stem) + print("-----------------------------") + + _, overfitted_model_params = get_model_data_info( + overfitted_model_file, model_config, True, block_size, pad_size + ) + base_model, base_model_params = get_model_data_info( + base_model_file, model_config, False, block_size, pad_size + ) + + approx_data = base_model.init_approx_data( + base_model_params, qp_density, scan_order=0 + ) + + approx_info = nnc_core.nnr_model.ApproxInfo( + approx_data, + base_model.model_info, + approx_method, + nnr_qp, + opt_qp, + disable_dq, + lambda_scale, + cb_size_ratio, + q_mse, + ) + + nnr_qp = np.int32(nnr_qp) + diff_qp = nnr_qp + + model_param_diff = icnn_tools.model_diff(overfitted_model_params, base_model_params) + + diff_dec_approx_params = {"parameters": {}, "put_node_depth": {}} + + # Iterative QP + if config.OPT_QP() and opt_qp: + base_model.set_dataset(train_loader) + ref_perf, _, _ = base_model.eval_model(base_model_params) + with NamedTemporaryFile( + dir="/tmp", suffix=".nnr", prefix="bitstream_", delete=False + ) as tmp_nnr_bs: + diff_qp = icnn_tools.opt_qp( + diff_qp, 
+ model_param_diff, + base_model_params, + diff_dec_approx_params, + base_model, + base_model.model_info, + ref_perf, + approx_data, + approx_info, + enc_info, + inc_bn_folding, + tmp_nnr_bs.name, + 1, + save_bitstreams=False, + tml=None, + sbt_args=None, + bias=bias, + qp_step=qp_step, + crit="acc", + ) + + approx_data["parameters"] = model_param_diff + + approx_info.apply_qp(approx_data, base_model.model_info, diff_qp) + + bitstream_path = nnr_models_dir / "nnr" / f"{output_stem}.nnr" + + # Approximate and encode + len_of_bs, enc_diff_rec_approx_data = nnc_core.approx_and_enc( + base_model.model_info, + approx_data, + diff_dec_approx_params, + approx_info, + enc_info, + num_workers=4, + bs_filename=bitstream_path, + tml=None, + n_epochs=1, + epoch=1, + client=0, + sbt_args=None, + ) + assert ( + len(enc_diff_rec_approx_data["parameters"]) > 0 + ), "The weight update was fully sparsified" + + # Decode + diff_rec_approx_data, bs_size, dec_model_info, res = nnc_core.dec_and_rec( + bitstream_path, + True, + tml=None, + epoch=1, + client=0, + parameter_index=base_model.model_info["parameter_index"], + parameter_dimensions=base_model.model_info["parameter_dimensions"], + dec_approx_param_base=diff_dec_approx_params, + update_base_param=True, + ) + + # Restoration + restored_params = icnn_tools.add( + base_model_params, diff_rec_approx_data["parameters"] + ) + + saved_model_path = nnr_models_dir / f"nnr_{output_stem}.pt" + base_model.restore_and_save(restored_params, str(saved_model_path)) + + # Torch to Onnx + out_onnx_path = onnx_models_dir / saved_model_path.stem + torch_to_onnx(saved_model_path, out_onnx_path, model_config) + + # Onnx to SADL and quantized by using base model's quantizers + onnx_to_sadl( + sadl_float_dir, + out_onnx_path, + quantizers_dir, + sadl_int_dir, + base_quantizers, + map_orig_params_to_new, + ) -- GitLab