diff --git a/.gitattributes b/.gitattributes
index acf3ec92dd9bf76897b6449376a2c0530f28df04..b6c4a19bc7f25cf79ef8c72cd7393825824ab736 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -10,3 +10,4 @@
 *.index filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.data-* filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
diff --git a/training/training_scripts/NN_Adaptive_Filtering/README.md b/training/training_scripts/NN_Adaptive_Filtering/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..685419a81fc1480cfacd9e747eb016de064ae300
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/README.md
@@ -0,0 +1,312 @@
+# Torch Over-fitting LOP2
+
+## Instructions
+
+Content-adaptive LOP2 via NN overfitting on each intra-period
+
+This document explains how to conduct the training and generate the models for inference.
+The steps are:
+
+0. Requirements
+1. Set up
+2. Data Preparation
+3. Convert base models from SADL to Torch
+4. Overfitting pipeline
+5. Inference
+
+## 0. Requirements
+
+* Torch 1.12.1
+* [MPEG NCTM repository](https://git.mpeg.expert/MPEG/Video/NNCoding/NCTM). This is used to code the weight-updates that result from the overfitting process.
+* Executable file ```naive_quantization```. This is used to convert the float SADL model to the int SADL model.
+
+**Note:** The executable file ```naive_quantization``` can be obtained by building the SADL project. Then copy it to this workspace.
+
+To get access to the NCTM repository, please follow these steps:
+
+1. Request an account in [https://git.mpeg.expert/](https://git.mpeg.expert/)
+2. Inform the AG/WG Convenor to request the account approval
+3. Once the account has been approved, send a request to join the NCTM project to Werner Bailer (werner.bailer@joanneum.at)
+
+The environment can be created with venv or anaconda:
+
+### venv setup
+`create_env.sh` can be used to create the virtual environment, called `lop2_overf`.
+
+Activate it with:
+
+```shell
+source ${HOME}/lop2_overf/bin/activate
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/lop2_overf/lib
+export PYTHONPATH=${PYTHONPATH}:${PWD}
+```
+
+### anaconda setup
+
+```shell
+conda env create -f environment.yml
+conda activate lop2_overf
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/anaconda3/envs/lop2_overf/lib
+export PYTHONPATH=${PYTHONPATH}:${PWD}
+```
+
+Clone the NCTM repository and apply the patch provided:
+
+```shell
+git clone https://git.mpeg.expert/MPEG/Video/NNCoding/NCTM.git
+cd NCTM
+git checkout v1-v2-harmonization
+```
+
+Then install NCTM as a library on top of the active environment:
+
+```shell
+python setup.py build
+python setup.py install
+```
+
+The NCTM directory can be removed now.
+
+## 1. Set up
+
+All the parameters for training are taken from the [resources/config.json](resources/config.json) config file.
+Mandatory updates based on your system (a quick path check is sketched after this list):
+
+* `orig_path`: abs path to original data (will be created)
+* `deco_path`: abs path to decoded data (will be created)
+* `output_path`: abs path to save the output files from the overfitting (will be created)
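+
+A minimal sketch (assuming these keys sit at the top level of the JSON; everything else is illustrative) for checking the three paths before launching anything:
+
+```python
+# Sketch: read resources/config.json and check the mandatory paths.
+# Assumes orig_path, deco_path and output_path are top-level keys.
+import json
+from pathlib import Path
+
+with open("resources/config.json") as f:
+    cfg = json.load(f)
+
+for key in ("orig_path", "deco_path", "output_path"):
+    path = Path(cfg[key])
+    if not path.is_absolute():
+        raise ValueError(f"{key} must be an absolute path, got {path}")
+    print(f"{key}: {path}")
+```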
+
+The working directory is `<repo_path>/training/training_scripts/NN_Adaptive_Filtering`.
+
+The overfitting data is organised in a specific way. Create the directory structure with the following command:
+
+```shell
+python create_dataset_dirs.py
+```
+
+As a result, the original data file structure will look as shown below. There is one directory for each video
+sequence; the names can be seen in [resources/datasets/jvet_labels.json](resources/datasets/jvet_labels.json).
+
+```shell
+orig
+├── A1_CampfireParty
+└── ...
+```
+
+The decoded data file structure will look like this:
+
+```shell
+deco
+├── A1_CampfireParty
+│   ├── 22
+│   ├── 27
+│   ├── 32
+│   ├── 37
+│   └── 42
+└── ...
+```
+
+## 2. Data preparation
+
+The video data is in yuv420p and yuv420p10le formats.
+
+The following script saves the slice QP of each slice in a video sequence to a JSON file:
+
+```shell
+python data_preparation.py
+```
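+
+For reference, the sketch below illustrates what the slice-QP extraction amounts to, assuming VTM/NNVC-style decoder log lines such as `POC   16 TId: 0 ( B-SLICE, QP 27 )`; the actual logic lives in `data_preparation.py` and the file names here are hypothetical:
+
+```python
+# Illustrative sketch only: collect the slice QP per POC from a decoder log
+# and dump it to JSON. Log format and file names are assumptions.
+import json
+import re
+
+slice_qps = {}
+with open("decoder_D_BlowingBubbles_22.log") as log:      # hypothetical log name
+    for line in log:
+        m = re.search(r"POC\s+(\d+).*?QP\s+(-?\d+)", line)
+        if m:
+            slice_qps[int(m.group(1))] = int(m.group(2))
+
+with open("slice_qps.json", "w") as out:                   # hypothetical output name
+    json.dump(slice_qps, out, indent=2)
+```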
+
+### 2.1. Original data 
+
+Copy the JVET RA NNVC CTC mandatory sequences into the corresponding directory under `orig_path`.
+
+### 2.2. Video encoding and decoding
+
+Use [NNVC-8.0rc3](https://vcgit.hhi.fraunhofer.de/jvet-ahg-nnvc/VVCSoftware_VTM/-/tree/NNVC-8.0rc3/) and enable both 
+the NN intra tool and the NN LOP2.0 loop-filter.
+
+**NOTE**: The decoded data should follow the file structure shown earlier. The output files can be re-organised
+after the decoding process takes place. 
+
+Save the following files:
+  * Decoder log 
+  * Reconstruction before the deblocking filter
+  * Boundary strength
+  * Prediction
+  * Block Prediction Mode
+
+These files have the same names across QPs, with the sequence name as prefix.
+Check the file names in [resources/config.json](resources/config.json) under `nnvc`.
+
+## 3. SADL model to Torch model
+
+The torch model for training is generated from the original LOP2 in Int16 SADL format by using the following command:
+
+```shell
+python conversion/sadl2torch.py
+```
+
+The same command will also create two JSON files: one with the model quantisers and one that maps the parameters of
+the original LOP2 model to the model that contains the multiplier parameters.
+
+The three output files will be located under the `resources` directory as follows:
+
+
+```shell
+resources
+├── nnlf_lop2_model.pt
+├── nnlf_lop2_quantizers.json
+└── nnlf_lop2_params_to_new_model.json
+```
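+
+A quick way to sanity-check these outputs is sketched below (depending on how the model was saved, `torch.load` returns either a state dict or a full module; the checks are illustrative):
+
+```python
+# Sketch: load the conversion outputs listed above and print a short summary.
+import json
+import torch
+
+model = torch.load("resources/nnlf_lop2_model.pt", map_location="cpu")
+with open("resources/nnlf_lop2_quantizers.json") as f:
+    quantizers = json.load(f)
+with open("resources/nnlf_lop2_params_to_new_model.json") as f:
+    params_map = json.load(f)
+
+print(type(model))
+print("quantiser entries:", len(quantizers))
+print("mapped parameters:", len(params_map))
+```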
+
+## 4. Overfitting pipeline
+
+The overfitting pipeline includes the following steps:
+1. Overfitting
+2. Weight-update compression 
+3. Conversion from Torch model to SADL model
+4. Quantization of SADL float model
+
+### 4.1 Overfitting
+For each video sequence, the overfitting is done per intra-period, using layer-wise optimisation.
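+
+A minimal sketch of the layer-wise idea (model, data loader, loss and step counts are placeholders, not the actual training configuration):
+
+```python
+# Sketch of layer-wise overfitting: optimise one parameter group per stage
+# while all other weights stay frozen.
+import torch
+
+def overfit_layerwise(model, dataloader, layer_names, steps_per_layer=100, lr=1e-4):
+    criterion = torch.nn.MSELoss()
+    for name in layer_names:                       # one layer (group) per stage
+        for p_name, param in model.named_parameters():
+            param.requires_grad = p_name.startswith(name)
+        trainable = [p for p in model.parameters() if p.requires_grad]
+        optimizer = torch.optim.Adam(trainable, lr=lr)
+        for _, (inputs, target) in zip(range(steps_per_layer), dataloader):
+            optimizer.zero_grad()
+            loss = criterion(model(inputs), target)
+            loss.backward()
+            optimizer.step()
+    return model
+```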
+
+### 4.2 Weight-update compression
+Here the weight-update is computed and compressed, generating an NNR/NNC bitstream. The reconstructed Torch model is saved as well.
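+
+Conceptually, the weight-update is the element-wise difference between the overfitted and base parameters; the NNR/NNC coding of that difference is done with NCTM. A minimal sketch, assuming both checkpoints are plain state dicts (the overfitted file name is hypothetical):
+
+```python
+# Sketch: compute the weight-update (delta) between overfitted and base models.
+import torch
+
+base = torch.load("resources/nnlf_lop2_model.pt", map_location="cpu")
+overfitted = torch.load("overfittings/D_BlowingBubbles_42_part0.pt", map_location="cpu")
+
+weight_update = {name: overfitted[name] - base[name] for name in base}
+
+# The decoder side reconstructs the overfitted model as base + weight-update.
+reconstructed = {name: base[name] + weight_update[name] for name in base}
+```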
+
+### 4.3 Convert Torch models to SADL
+The Torch model is first converted to ONNX, then to a SADL float32 model, and finally quantised.
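+
+A minimal sketch of the Torch-to-ONNX step (the model, input shape and file name below are placeholders; the real pipeline feeds the actual reconstructed model and inputs):
+
+```python
+# Sketch of the Torch -> ONNX export used before the SADL conversion.
+import torch
+
+# Placeholder standing in for the reconstructed overfitted Torch model.
+model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
+model.eval()
+dummy = torch.randn(1, 3, 64, 64)   # hypothetical input shape
+torch.onnx.export(
+    model, dummy, "overfitted_part0.onnx",
+    opset_version=13, input_names=["input"], output_names=["output"],
+)
+```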
+
+Each command in the sample script below launches the overfitting pipeline for the sequences of one class in a single Python process:
+
+```shell
+CUDA_VISIBLE_DEVICES=2 python launch_pipeline.py -vs A1_CampfireParty A1_FoodMarket A1_Tango --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=3 python launch_pipeline.py -vs A2_CatRobot A2_DaylightRoad A2_ParkRunning --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=4 python launch_pipeline.py -vs B_BasketBallDrive B_BQTerrace B_Cactus B_MarketPlace B_RitualDance --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=0 python launch_pipeline.py -vs C_BasketballDrill C_PartyScene C_BQMall C_RaceHorses_big --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=1 python launch_pipeline.py -vs D_BasketBallPass D_BQSquare D_BlowingBubbles D_RaceHorses_s --qps 22 27 32 37 42
+CUDA_VISIBLE_DEVICES=5 python launch_pipeline.py -vs F_ArenaOfValor F_BBDrillText F_SlideEditing F_SlideShow --qps 22 27 32 37 42
+```
+
+The overfitting results will be in `<output_path>/overfittings`.
+The weight-update compression results will be in `<output_path>/nnr_models`.
+The float SADL models will be in `<output_path>/sadl_float_models`.
+The int SADL models will be in `<output_path>/sadl_int_models`.
+
+The directory structure will look as shown below (`<class_seq>_<QP>_part<partIdx>`):
+
+```shell
+<output_path>
+├── overfittings
+│   ├── D_BlowingBubbles_42_part0
+│   ├── D_BlowingBubbles_42_part1
+│   ├── D_BlowingBubbles_42_part2
+│   └── ...
+├── nnr_models
+│   ├── nnr
+│   │   ├── A1_Tango_22_part0.nnr
+│   │   └── ...
+│   ├── nnr_A1_Tango_22_part0.pt
+│   └── ...
+├── sadl_float_models
+│   ├── nnr_A1_Tango_22_part0.sadl
+│   └── ...
+├── quantizers
+│   ├── Q_nnr_A1_Tango_22_part0.log
+│   └── ...
+└── sadl_int_models
+    ├── nnr_A1_Tango_22_part0.sadl
+    └── ...
+```
+
+## 5. Inference
+
+The content-adaptive LOP2 is run only on the RA config, and **each intra period is encoded separately**. Then the separate
+bitstreams are merged and the resulting bitstream is decoded.
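+
+One possible way to derive the per-segment encoder arguments, assuming each segment spans one intra period plus the following intra frame (matching the encoder example in 5.1) and using the encoder's `FrameSkip`/`FramesToBeEncoded` options (frame counts below are illustrative):
+
+```python
+# Sketch: split a sequence into per-intra-period encoding jobs.
+def segment_ranges(total_frames, intra_period):
+    parts, start = [], 0
+    while start < total_frames - 1:
+        frames = min(intra_period + 1, total_frames - start)
+        parts.append({"FrameSkip": start, "FramesToBeEncoded": frames})
+        start += intra_period
+    return parts
+
+for i, p in enumerate(segment_ranges(total_frames=257, intra_period=64)):
+    print(f"part{i}: --FrameSkip={p['FrameSkip']} --FramesToBeEncoded={p['FramesToBeEncoded']}")
+```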
+
+The NNR bitstream is signaled within a Neural Network Filter Update (NNFU) APS. One APS is signalled within the first
+B-Frame of an intra-period.
+
+The following example shows the NNFU encoder parameters for sequence D_BlowingBubbles, QP 22 and intra period 0.
+
+```shell
+NnfuEnabled                  : 1
+NumNnfus                     : 1
+NnfuPayloadFileName0         : nnr/D_BlowingBubbles_22_part0.nnr
+NnfuModelFileName0           : overfitted_models/nnr_D_BlowingBubbles_22_part0.sadl
+```
+
+### 5.1. Encoder
+Use the default `encoder_randomaccess_nnvc.cfg` together with the NNFU APS parameters.
+
+The first intra period of D_BlowingBubbles at QP 22 can be encoded as follows:
+
+```shell
+./EncoderAppStatic -c encoder_randomaccess_nnvc.cfg --IntraPeriod=64 --FramesToBeEncoded=65 --NnfuEnabled=1 --NumNnfus=1 --NnfuPayloadFileName0=nnr/D_BlowingBubbles_22_part0.nnr --NnfuModelFileName0=overfitted_models/nnr_D_BlowingBubbles_22_part0.sadl
+```
+
+The final bitstream is obtained by merging the best RA segment bitstreams. For that, the script
+`segment_on_off.py` is used. This script takes the first-pass (data generation with the original LOP2 model)
+and second-pass (overfitted model) encoding results and creates
+scripts to (1) merge the final segment bitstreams and (2) call the decoder.
+The script is tailored to the proponents' environment, but it can be adapted to work in others.
+
+Sample script:
+
+```shell
+python3 segment_on_off.py \
+-vs D_BasketBallPass D_BQSquare D_BlowingBubbles D_RaceHorses_s \
+--qps 22 27 32 37 42 \
+--pass1 /pass1_results \
+--pass2 /pass2_results \
+--output_dir /segment_on_off \
+--intra_models_dir /src/models/intra \
+--lop2_model /src/models/nnlf_lop2_model_int16.sadl \
+--parcat_bin /src/bin/parcatStatic \
+--decoder_bin /src/bin/DecoderAppStatic
+```
+
+### 5.2. Decoder
+
+The reconstruction of the overfitted SADL model is done within the NNVC decoder, which calls the `wu_decoding.py`
+script when an NNFU APS is decoded. The parameter `NnfuOutputFileStem` is required; it is used as the stem for
+the NNR bitstream file names and the reconstructed overfitted SADL models.
+
+The following setup is required to run the decoder correctly:
+
+1. Create an environment variable that points to this directory (where the `wu_decoding.py` script is located)
+and add it to the Python path:
+
+```shell
+export WU_CODE=$PWD
+export PYTHONPATH=${PYTHONPATH}:${PWD}
+```
+
+2. Activate the Python environment with
+
+```shell
+source ${HOME}/lop2_overf/bin/activate
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/lop2_overf/lib
+```
+
+or
+
+```shell
+conda activate lop2_overf
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${HOME}/anaconda3/envs/lop2_overf/lib
+```
+
+3. Disable Python warnings:
+
+```shell
+export TF_CPP_MIN_LOG_LEVEL=2
+export PYTHONWARNINGS="ignore"
+```
+
+4. Build the SADL sample and copy the `naive_quantization` binary to the `$WU_CODE` directory.
+
+Sample decoder script:
+
+```shell
+./DecoderAppStatic -b D_BlowingBubbles_22.bin --NnfuOutputFileStem=D_BlowingBubbles_22
+```
diff --git a/training/training_scripts/NN_Adaptive_Filtering/config.py b/training/training_scripts/NN_Adaptive_Filtering/config.py
new file mode 100755
index 0000000000000000000000000000000000000000..d093bb555366a9658f1e9a5941630a14c320c0d7
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/config.py
@@ -0,0 +1,101 @@
+"""
+The copyright in this software is being made available under this Software
+Copyright License. This software may be subject to other third party and
+contributor rights, including patent rights, and no such rights are
+granted under this license.
+Copyright (c) 1995 - 2021 Fraunhofer-Gesellschaft zur Förderung der
+angewandten Forschung e.V. (Fraunhofer)
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted for purpose of testing the functionalities of
+this software provided that the following conditions are met:
+*     Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+*     Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+*     Neither the names of the copyright holders nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
+NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+NO EXPRESS OR IMPLIED LICENSES TO ANY PATENT CLAIMS, INCLUDING
+WITHOUT LIMITATION THE PATENTS OF THE COPYRIGHT HOLDERS AND
+CONTRIBUTORS, ARE GRANTED BY THIS SOFTWARE LICENSE. THE
+COPYRIGHT HOLDERS AND CONTRIBUTORS PROVIDE NO WARRANTY OF PATENT
+NON-INFRINGEMENT WITH RESPECT TO THIS SOFTWARE.
+"""
+
+# Format changed
+
+import sys
+
+assert sys.version_info >= (3, 6)
+
+
+# def PUT_SYNTAX(): return False
+def PUT_SYNTAX():
+    return True
+
+
+# def ROW_SKIP(): return False
+def ROW_SKIP():
+    return True
+
+
+# def TEMPORAL_CONTEXT(): return False
+def TEMPORAL_CONTEXT():
+    return True
+
+
+# def OPT_QP(): return False
+def OPT_QP():
+    return True
+
+
+# def SPARSE(): return False
+def SPARSE():
+    return True
+
+
+# Use center PUT for temporal contexts in UC14A
+def UC14A_CENTER_PUT():
+    return False
+
+
+# def UC14A_CENTER_PUT(): return True
+
+
+# stochastic binary / ternary quantization
+# def SBT(): return False
+def SBT():
+    return True
+
+
+print("Config:")
+print("  PUT_SYNTAX:           ", str(PUT_SYNTAX()))
+print("  ROW_SKIP:             ", str(ROW_SKIP()))
+print("  TEMPORAL_CONTEXT:     ", str(TEMPORAL_CONTEXT()))
+print("  OPT_QP:               ", str(OPT_QP()))
+print("  SPARSE:               ", str(SPARSE()))
+print("  UC14A_CENTER_PUT:     ", str(UC14A_CENTER_PUT()))
+print("  SBT:                  ", str(SBT()))
+
+# check tool requirements
+if TEMPORAL_CONTEXT():
+    assert PUT_SYNTAX()
+
+if UC14A_CENTER_PUT():
+    assert PUT_SYNTAX()
+    assert TEMPORAL_CONTEXT()
diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/onnx2sadl_modified.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/onnx2sadl_modified.py
new file mode 100644
index 0000000000000000000000000000000000000000..f023336cf27074882cf7df6c1ed042c11c4593e9
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/conversion/onnx2sadl_modified.py
@@ -0,0 +1,2096 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from __future__ import print_function
+
+import argparse
+import copy
+import struct
+import sys
+from collections import OrderedDict
+from enum import IntEnum
+from pathlib import Path
+
+import numpy as np
+import onnx
+
+from util.file_system import read_json_file
+
+# file format:
+# MAGIC: SADL0004 [char[8]]
+# type_model [int32_t] 0:int32, 1:float, 2:int16
+# nb_layers [int32_t]
+# nb_inputs [int32_t]
+# inputs_id [int32_t[nb_inputs]]
+# nb_outputs [int32_t]
+# outputs_id [int32_t[nb_outputs]]
+# (for all layers:)
+#  layer_id [int32_t]
+#  op_id    [int32_t]
+#  name_size [int32_t]
+#  name [char[name_size]]
+#  nb_inputs [int32_t]
+#  intput_ids [int32_t[nb_inputs]]
+#
+# (additional information)
+#  Const_layer:
+#   length_dim [int32_t]
+#   dim [int32_t[length_dim]]
+#   type [int32_t] 0:int32, 1:float32 2:int16
+#   [if integer: quantizer [int32])
+#   data [type[prod(dim)]]
+#
+#  Conv2DTranspose
+#    nb_dim_strides [int32_t]
+#    strides [int32_t[nb_dim_strides]]
+#    quantizer [int32_t]
+#
+#  Conv2D
+#    nb_dim_strides [int32_t]
+#    strides [int32_t[nb_dim_strides]]
+#    quantizer [int32_t]
+#
+#  MatMul
+#    quantizer [int32_t]
+#
+#  Mul
+#    quantizer [int32_t]
+#
+#  PlaceHolder
+#   length_dim [int32_t]
+#   dim [int32_t[length_dim]]
+#   quantizer [int32_t]
+#
+#  MaxPool
+#    nb_dim_strides [int32_t]
+#    strides [int32_t[nb_dim_strides]]
+#    nb_dim_kernel [int32_t]
+#    kernel_dim [int32_t[nb_dim_kernel]]
+
+
+class OPTYPE(IntEnum):
+    Const = (1,)
+    Placeholder = (2,)
+    Identity = (3,)
+    BiasAdd = (4,)
+    MaxPool = (5,)
+    MatMul = (6,)
+    Reshape = (7,)
+    Relu = (8,)
+    Conv2D = (9,)
+    Add = (10,)
+    ConcatV2 = (11,)
+    Mul = (12,)
+    Maximum = (13,)
+    LeakyReLU = (14,)
+    Transpose = (15,)
+    Flatten = (16,)
+    Shape = (17,)
+    Expand = (18,)
+    Conv2DTranspose = (19,)
+    Slice = (
+        20,
+    )  # Currently slicing across depth is supported with default step size of 1
+    PReLU = (21,)
+    # In "tf2cpp", the same layer performs the matrix multiplication
+    # and the matrix multiplication by batches.
+    BatchMatMul = (6,)
+    ScatterND = (22,)
+    GridSample = (23,)
+    Resize = (24,)
+    Compare = (25,)
+    Where = (26,)
+    Minimum = (27,)
+
+    # "BatchMatMulV2" did not exist in Tensorflow 1.9. It exists in
+    # Tensorflow 1.15.
+    BatchMatMulV2 = 6
+    Count = 28
+
+    def __repr__(self):
+        return self.name
+
+    def __str__(self):
+        return self.name
+
+
+class DTYPE_SADL(IntEnum):
+    FLOAT = (1,)  # float
+    INT8 = (3,)  # int8_t
+    INT16 = (2,)  # int16_t
+    INT32 = 0  # int32_t
+
+    def __repr__(self):
+        return self.name
+
+    def __str__(self):
+        return self.name
+
+
+class DTYPE_ONNX(IntEnum):
+    # https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L483-L485
+    FLOAT = (1,)  # float
+    INT8 = (3,)  # int8_t
+    INT16 = (4,)  # int16_t
+    INT32 = (6,)  # int32_t
+    INT64 = 7  # int64_t
+
+    def __repr__(self):
+        return self.name
+
+    def __str__(self):
+        return self.name
+
+
+class Node_Annotation:
+    to_remove = False
+    add_transpose_before = False
+    add_transpose_after = False
+    to_transpose = False
+    layout_onnx = None
+
+    def __repr__(self):
+        return "to_remove={}, to_transpose={}, layout_onnx={}, add_transpose_before={} add_transpose_after={}".format(
+            self.to_remove,
+            self.to_transpose,
+            self.layout_onnx,
+            self.add_transpose_before,
+            self.add_transpose_after,
+        )
+
+
+# get attribute name in node
+def getAttribute(node, attr):
+    for a in node.attribute:
+        if a.name == attr:
+            return a
+    return None
+
+
+def transpose_tensor(raw_data, dims):
+    """
+    When converting TF2 to ONNX, the ONNX weights are not represented in the same way as the TF2 weights.
+    """
+    # print(dims)
+    tmp = []
+    tmp.append(dims[2])
+    tmp.append(dims[3])
+    tmp.append(dims[1])
+    tmp.append(dims[0])
+
+    x = np.frombuffer(raw_data, dtype=np.float32)
+    x = x.reshape(tmp[3], tmp[2], tmp[0] * tmp[1]).transpose().flatten()
+    return x.tobytes(), tmp
+
+
+def transpose_matrix(raw_data, dims):
+    x = np.frombuffer(raw_data, dtype=np.float32)
+    tmp = []
+    tmp.append(dims[1])
+    tmp.append(dims[0])
+    x = x.reshape(dims[0], dims[1])
+    x = np.transpose(x)  # moveaxis(x, -2, -1)
+    return x.flatten().tobytes(), tmp
+
+
+def toList(ii):
+    d = []
+    for i in ii:
+        d.append(i)
+    return d
+
+
+def is_constant(name, onnx_initializer):
+    for n in onnx_initializer:
+        if n.name == name:
+            return True
+    return False
+
+
+def is_output(name, onnx_output):
+    for out in onnx_output:
+        if out.name == name:
+            return True
+    return False
+
+
+def parse_graph_input_node(input_node, map_onnx_to_myGraph, to_transpose):
+    map_onnx_to_myGraph[input_node.name] = input_node.name
+    struct = {}
+    struct["inputs"] = []
+    struct["additional"] = {}
+    if (
+        to_transpose
+    ):  # data_layout == 'nchw' and len(input_node.type.tensor_type.shape.dim)==4:
+        struct["additional"]["dims"] = [
+            input_node.type.tensor_type.shape.dim[0].dim_value,
+            input_node.type.tensor_type.shape.dim[2].dim_value,
+            input_node.type.tensor_type.shape.dim[3].dim_value,
+            input_node.type.tensor_type.shape.dim[1].dim_value,
+        ]
+    else:
+        struct["additional"]["dims"] = [
+            d.dim_value for d in input_node.type.tensor_type.shape.dim
+        ]
+    struct["op_type"] = OPTYPE.Placeholder
+    return struct
+
+
+def extract_additional_data_from_node(data, to_transpose):
+    tmp = {}
+    if data.dims == []:
+        tmp["dims"] = [1]
+    else:
+        tmp["dims"] = [dim for dim in data.dims]
+
+    tmp["raw_data"] = data.raw_data
+
+    if data.data_type == DTYPE_ONNX.FLOAT:
+        tmp["dtype"] = DTYPE_SADL.FLOAT
+    elif data.data_type == DTYPE_ONNX.INT8:
+        tmp["dtype"] = DTYPE_SADL.INT8
+    elif data.data_type == DTYPE_ONNX.INT16:
+        tmp["dtype"] = DTYPE_SADL.INT16
+    elif data.data_type == DTYPE_ONNX.INT32:
+        tmp["dtype"] = DTYPE_SADL.INT32
+    elif data.data_type == DTYPE_ONNX.INT64:
+
+        def convert_int64_to_int32(binary_data):
+            x = np.frombuffer(binary_data, dtype=np.int64)
+            x = x.astype(np.int32)
+            return x.tobytes()
+
+        tmp["dtype"] = DTYPE_SADL.INT32
+        tmp["raw_data"] = convert_int64_to_int32(tmp["raw_data"])
+    else:
+        raise ValueError("extract_additional_data: Unknown dtype")
+
+    if to_transpose:
+        if len(tmp["dims"]) == 4:
+            tmp["raw_data"], tmp["dims"] = transpose_tensor(
+                tmp["raw_data"], tmp["dims"]
+            )
+        elif len(tmp["dims"]) == 2:  # and data_layout == "nchw":
+            tmp["raw_data"], tmp["dims"] = transpose_matrix(
+                tmp["raw_data"], tmp["dims"]
+            )
+
+    return tmp["dims"], tmp["raw_data"], tmp["dtype"]
+
+
+def extract_additional_data(name, to_transpose, onnx_graph, verbose):
+    if verbose:
+        print("[INFO] {} transpose={}".format(name, to_transpose))
+
+    for init in onnx_graph.initializer:
+        if name == init.name:
+            return extract_additional_data_from_node(init, to_transpose)
+    for node in onnx_graph.node:  # not found in initializer, search in Constant
+        if name == node.output[0]:
+            return extract_additional_data_from_node(node.attribute[0].t, to_transpose)
+    quit("[ERROR] unable to extract data in {}".format(name))
+
+
+def extract_dims(name, onnx_graph):
+    for init in onnx_graph.initializer:
+        if name == init.name:
+            return init.dims
+    for node in onnx_graph.node:  # not found in initializer, search in Constant
+        if name == node.output[0]:
+            a = getAttribute(node, "value")
+            if a is not None:
+                return a.t.dims
+            else:
+                return None
+    for node in onnx_graph.input:  # not found in initializer or Constant, search in graph inputs
+        if name == node.name:
+            return node.type.tensor_type.shape.dim
+    quit("[ERROR] unable to extract dims in {}".format(name))
+
+
+# get the nodes with name as input
+def getNodesWithInput(name, model):
+    L = []
+    for node in model.graph.node:
+        for inp in node.input:
+            if inp == name:
+                L.append(node)
+    return L
+
+
+# get the nodes with name as output
+def getNodesWithOutput(name, model):
+    for node in model.graph.node:
+        for out in node.output:
+            if out == name:
+                return node
+    for node in model.graph.initializer:
+        if node.name == name:
+            return node
+    for node in model.graph.input:
+        if node.name == name:
+            return node
+    quit("[ERROR] not found: {}".format(name))
+
+
+# get the nodes with name as output
+def getNodesWithOutputNotConst(name, model):
+    for node in model.graph.node:
+        for out in node.output:
+            if out == name:
+                return node
+    for node in model.graph.input:
+        if node.name == name:
+            return node
+    return None
+
+
+# get dims from data
+def getDims(node):
+    if node.data_type != DTYPE_ONNX.INT64:
+        quit("[ERROR] bad node type fpr getDims {}".format(node))
+
+    x = np.frombuffer(node.raw_data, dtype=np.int64)
+    dims = x.tolist()
+    return dims
+
+
+def getInitializer(name, model_onnx):
+    for node in model_onnx.graph.initializer:
+        if node.name == name:
+            return node
+    return None
+
+
+def add_transpose(node, myGraph, map_onnx_to_myGraph):
+    # Transpose inserted
+    # Const
+    reshape_coef_name = node.input[0] + "_COEF_TRANSPOSE_NOT_IN_GRAPH"
+    myGraph[reshape_coef_name] = {}
+    myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const
+    myGraph[reshape_coef_name]["inputs"] = []
+    additional = {}
+    additional["dims"] = [4]
+    additional["raw_data"] = np.array(
+        [0, 3, 1, 2], dtype=np.int32
+    ).tobytes()  # nhwc -> nchw
+    additional["dtype"] = DTYPE_SADL.INT32
+    additional["data"] = node
+    myGraph[reshape_coef_name]["additional"] = additional
+    map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name
+
+    nname = node.input[0] + "_TRANSPOSE_NOT_IN_GRAPH"
+    myGraph[nname] = {}
+    myGraph[nname]["op_type"] = OPTYPE.Transpose
+    myGraph[nname]["inputs"] = [map_onnx_to_myGraph[node.input[0]], reshape_coef_name]
+    map_onnx_to_myGraph[nname] = nname
+    return nname
+
+
+def add_transpose_after(node, myGraph, map_onnx_to_myGraph):
+    # Transpose inserted
+    # Const
+    reshape_coef_name = node.output[0] + "_COEF_TRANSPOSE_AFTER_NOT_IN_GRAPH"
+    myGraph[reshape_coef_name] = {}
+    myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const
+    myGraph[reshape_coef_name]["inputs"] = []
+    additional = {}
+    additional["dims"] = [4]
+    additional["raw_data"] = np.array(
+        [0, 2, 3, 1], dtype=np.int32
+    ).tobytes()  # nchw -> nhwc
+    additional["dtype"] = DTYPE_SADL.INT32
+    additional["data"] = node
+    myGraph[reshape_coef_name]["additional"] = additional
+    map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name
+
+    nname = node.output[0] + "_TRANSPOSE_AFTER_NOT_IN_GRAPH"
+    myGraph[nname] = {}
+    myGraph[nname]["op_type"] = OPTYPE.Transpose
+    myGraph[nname]["inputs"] = [map_onnx_to_myGraph[node.output[0]], reshape_coef_name]
+    map_onnx_to_myGraph[nname] = nname
+    map_onnx_to_myGraph[node.output[0]] = nname
+    return nname
+
+
+def parse_graph_node(
+    node, model_onnx, myGraph, node_annotation, map_onnx_to_myGraph, verbose
+):
+    if verbose > 1:
+        print("parse node", node.name)
+
+    if node_annotation[
+        node.name
+    ].add_transpose_before:  # layout_onnx == 'nchw' : # need to go back to original layout before reshape
+        n0name = add_transpose(node, myGraph, map_onnx_to_myGraph)
+    else:
+        if len(node.input) >= 1:
+            n0name = node.input[0]
+        else:
+            n0name = None
+
+    if (
+        node.op_type == "Conv"
+        or node.op_type == "Gemm"
+        or node.op_type == "ConvTranspose"
+    ):
+        nb_inputs = len(node.input)
+        if (nb_inputs != 3) and (nb_inputs != 2):
+            raise Exception("parse_graph_node: Error on node type")
+        additional = {}
+        # Const: weight
+        additional["data"] = node
+        n2 = getNodesWithOutput(node.input[1], model_onnx)
+        (
+            additional["dims"],
+            additional["raw_data"],
+            additional["dtype"],
+        ) = extract_additional_data(
+            node.input[1],
+            node_annotation[n2.name].to_transpose,
+            model_onnx.graph,
+            verbose,
+        )
+        map_onnx_to_myGraph[node.input[1]] = node.input[1]
+
+        myGraph[node.input[1]] = {}
+        myGraph[node.input[1]]["inputs"] = []
+        myGraph[node.input[1]]["additional"] = additional
+        myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+
+        # Conv2d
+        inputs, additional = [], {}
+        inputs = [map_onnx_to_myGraph[n0name]] + [map_onnx_to_myGraph[node.input[1]]]
+
+        additional["data"] = node
+        if node.op_type == "Conv" or node.op_type == "ConvTranspose":
+            a = getAttribute(node, "strides")
+            additional["strides"] = a.ints
+            if node.op_type == "Conv":
+                a = getAttribute(node, "group")
+                additional["group"] = a.i
+            a = getAttribute(node, "pads")
+            # if pads is unavailable then no padding
+            if a:
+                additional["pads"] = a.ints
+            else:
+                additional["pads"] = [0, 0, 0, 0]
+        if node.op_type == "ConvTranspose":
+            a = getAttribute(node, "output_padding")
+            if a:
+                additional["output_padding"] = a.ints
+            else:
+                additional["output_padding"] = [0, 0]
+
+        if nb_inputs == 2:
+            map_onnx_to_myGraph[node.output[0]] = node.output[0]
+        elif nb_inputs == 3:
+            map_onnx_to_myGraph[node.output[0]] = node.output[0] + "_NOT_IN_GRAPH"
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["inputs"] = inputs
+        myGraph[node.output[0]]["additional"] = additional
+        if node.op_type == "Conv":
+            myGraph[node.output[0]]["op_type"] = OPTYPE.Conv2D
+        elif node.op_type == "ConvTranspose":
+            myGraph[node.output[0]]["op_type"] = OPTYPE.Conv2DTranspose
+        elif node.op_type == "Gemm":
+            myGraph[node.output[0]]["op_type"] = OPTYPE.MatMul
+
+        if nb_inputs == 3:
+            additional = {}
+            # Const: bias
+            additional["data"] = node
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(node.input[2], False, model_onnx.graph, verbose)
+            map_onnx_to_myGraph[node.input[2]] = node.input[2]
+            myGraph[node.input[2]] = {}
+            myGraph[node.input[2]]["inputs"] = []
+            myGraph[node.input[2]]["additional"] = additional
+            myGraph[node.input[2]]["op_type"] = OPTYPE.Const
+            # BiasAdd
+            inputs, additional = [], {}
+            inputs = [node.output[0]] + [map_onnx_to_myGraph[node.input[2]]]
+            additional["data"] = node
+            map_onnx_to_myGraph[node.output[0] + "_NOT_IN_GRAPH"] = None
+            myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {}
+            myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = inputs
+            myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional
+            myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.BiasAdd
+
+    elif node.op_type == "Relu":
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Relu
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    # Constant node value has to be skipped for slicing
+    # This node is not used by any other onnx model in utests/models directory
+    # Const value of slice is taken care inside "Slice" condition
+    elif node.op_type == "Constant":  # ~ like an initializer
+        pass
+
+    elif node.op_type == "Add":
+        swap_inputs = False
+        if is_constant(n0name, model_onnx.graph.initializer):
+            additional = {}
+            additional["data"] = node
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(n0name, False, model_onnx.graph, verbose)
+            map_onnx_to_myGraph[n0name] = n0name
+            myGraph[n0name] = {}
+            myGraph[n0name]["inputs"] = []
+            myGraph[n0name]["additional"] = additional
+            myGraph[n0name]["op_type"] = OPTYPE.Const
+            swap_inputs = True
+        if is_constant(node.input[1], model_onnx.graph.initializer):
+            additional = {}
+            additional["data"] = node
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose)
+            map_onnx_to_myGraph[node.input[1]] = node.input[1]
+            myGraph[node.input[1]] = {}
+            myGraph[node.input[1]]["inputs"] = []
+            myGraph[node.input[1]]["additional"] = additional
+            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Add
+        if not swap_inputs:
+            D1 = extract_dims(n0name, model_onnx.graph)
+            D2 = extract_dims(node.input[1], model_onnx.graph)
+            if D1 is not None and D2 is not None and len(D1) < len(D2):
+                swap_inputs = True
+
+        if swap_inputs:
+            myGraph[node.output[0]]["inputs"] = [
+                map_onnx_to_myGraph[node.input[1]],
+                map_onnx_to_myGraph[n0name],
+            ]
+        else:
+            myGraph[node.output[0]]["inputs"] = [
+                map_onnx_to_myGraph[n0name],
+                map_onnx_to_myGraph[node.input[1]],
+            ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "MaxPool":
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.MaxPool
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
+        myGraph[node.output[0]]["additional"] = {}
+        a = getAttribute(node, "strides")
+        myGraph[node.output[0]]["additional"]["strides"] = [1, a.ints[0], a.ints[1], 1]
+        a = getAttribute(node, "pads")
+        if a is None:
+            pp = [0, 0, 0, 0]
+        else:
+            pp = a.ints
+        myGraph[node.output[0]]["additional"]["pads"] = pp
+        a = getAttribute(node, "kernel_shape")
+        myGraph[node.output[0]]["additional"]["kernel_shape"] = [
+            1,
+            a.ints[0],
+            a.ints[1],
+            1,
+        ]
+        myGraph[node.output[0]]["additional"]["data"] = node
+        # todo: check pads?
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Mul":
+        # check the inputs
+        if is_constant(n0name, model_onnx.graph.initializer) and is_constant(
+            node.input[1], model_onnx.graph.initializer
+        ):
+            quit("[ERROR] unsupported double constants Mul", node)
+        swap_inputs = False
+        if is_constant(n0name, model_onnx.graph.initializer):
+            additional = {}
+            additional["data"] = node
+            n2 = getNodesWithOutput(n0name, model_onnx)
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(
+                n0name, node_annotation[n2.name].to_transpose, model_onnx.graph, verbose
+            )
+            map_onnx_to_myGraph[n0name] = n0name
+            myGraph[n0name] = {}
+            myGraph[n0name]["inputs"] = []
+            myGraph[n0name]["additional"] = additional
+            myGraph[n0name]["op_type"] = OPTYPE.Const
+            swap_inputs = True
+        if is_constant(node.input[1], model_onnx.graph.initializer):
+            additional = {}
+            additional["data"] = node
+            n2 = getNodesWithOutput(node.input[1], model_onnx)
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(
+                node.input[1],
+                node_annotation[n2.name].to_transpose,
+                model_onnx.graph,
+                verbose,
+            )
+            map_onnx_to_myGraph[node.input[1]] = node.input[1]
+            myGraph[node.input[1]] = {}
+            myGraph[node.input[1]]["inputs"] = []
+            myGraph[node.input[1]]["additional"] = additional
+            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Mul
+        if swap_inputs:
+            myGraph[node.output[0]]["inputs"] = [
+                map_onnx_to_myGraph[node.input[1]],
+                map_onnx_to_myGraph[n0name],
+            ]
+        else:
+            myGraph[node.output[0]]["inputs"] = [
+                map_onnx_to_myGraph[n0name],
+                map_onnx_to_myGraph[node.input[1]],
+            ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Identity" or node.op_type == "Cast":
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Identity
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "LeakyRelu":
+        # leaky coef
+        additional = {}
+        additional["data"] = node
+        additional["dims"] = [1]
+        additional["raw_data"] = np.array(
+            float(node.attribute[0].f), dtype=np.float32
+        ).tobytes()
+        additional["dtype"] = DTYPE_SADL.FLOAT
+        map_onnx_to_myGraph[node.output[0] + "_COEF_NOT_IN_GRAPH"] = None
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {}
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = []
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.Const
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.LeakyReLU
+        myGraph[node.output[0]]["inputs"] = [
+            map_onnx_to_myGraph[n0name],
+            node.output[0] + "_NOT_IN_GRAPH",
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "PRelu":
+        additional = {}
+        additional["data"] = node
+        n2 = getNodesWithOutput(node.input[1], model_onnx)
+        (
+            additional["dims"],
+            additional["raw_data"],
+            additional["dtype"],
+        ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose)
+        myGraph[node.input[1]] = {}
+        myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+        myGraph[node.input[1]]["inputs"] = []
+        myGraph[node.input[1]]["additional"] = additional
+        map_onnx_to_myGraph[node.input[1]] = node.input[1]
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.PReLU
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [
+            map_onnx_to_myGraph[node.input[1]]
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Flatten":
+        inputs, additional = [], {}
+        inputs = [map_onnx_to_myGraph[n0name]]
+        additional["data"] = node
+        a = getAttribute(node, "axis")
+        additional["axis"] = a.i
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["inputs"] = inputs
+        myGraph[node.output[0]]["additional"] = additional
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Flatten
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Shape":
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Shape
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Expand":
+        inputs, additional = [], {}
+        inputs = [map_onnx_to_myGraph[n0name], map_onnx_to_myGraph[node.input[1]]]
+        additional["data"] = node
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["inputs"] = inputs
+        myGraph[node.output[0]]["additional"] = additional
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Expand
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Reshape" or node.op_type == "MatMul":
+        # Const
+        myGraph[node.input[1]] = {}
+        myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+        myGraph[node.input[1]]["inputs"] = []
+        additional = {}
+        (
+            additional["dims"],
+            additional["raw_data"],
+            additional["dtype"],
+        ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose)
+        additional["data"] = node
+        myGraph[node.input[1]]["additional"] = additional
+        map_onnx_to_myGraph[node.input[1]] = node.input[1]
+        n2 = getNodesWithOutput(node.input[0], model_onnx)
+        # Reshape
+        inputs, additional = [], {}
+        inputs = [map_onnx_to_myGraph[n0name], node.input[1]]
+        additional["data"] = node
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["inputs"] = inputs
+        myGraph[node.output[0]]["additional"] = additional
+
+        if node.op_type == "Reshape":
+            myGraph[node.output[0]]["op_type"] = OPTYPE.Reshape
+        elif node.op_type == "MatMul":
+            myGraph[node.output[0]]["op_type"] = OPTYPE.MatMul
+
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Concat":
+        # Const
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Const
+        myGraph[node.output[0]]["inputs"] = []
+        additional = {}
+        additional["dims"] = [1]
+        additional["raw_data"] = np.array(node.attribute[0].i, dtype=np.int32).tobytes()
+        additional["dtype"] = DTYPE_SADL.INT32
+        additional["data"] = node
+        myGraph[node.output[0]]["additional"] = additional
+        map_onnx_to_myGraph[node.output[0] + "_NOT_IN_GRAPH"] = None
+
+        # Concatenate
+        inputs, additional = [], {}
+        for inp in node.input:
+            inputs.append(map_onnx_to_myGraph[inp])
+        inputs.append(node.output[0])
+        additional["data"] = node
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {}
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = inputs
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional
+        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.ConcatV2
+        map_onnx_to_myGraph[node.output[0]] = node.output[0] + "_NOT_IN_GRAPH"
+
+    elif node.op_type == "Max":
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Maximum
+        myGraph[node.output[0]]["inputs"] = [
+            map_onnx_to_myGraph[n0name],
+            map_onnx_to_myGraph[node.input[1]],
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Min":
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Minimum
+        myGraph[node.output[0]]["inputs"] = [
+            map_onnx_to_myGraph[n0name],
+            map_onnx_to_myGraph[node.input[1]],
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Unsqueeze":
+        # No need to parse Unsqueeze as SADL can handle it.
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Transpose":
+        # Const
+        reshape_coef_name = node.output[0] + "_COEF_TRANSPOSE"
+        myGraph[reshape_coef_name] = {}
+        myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const
+        myGraph[reshape_coef_name]["inputs"] = []
+        additional = {}
+        d = toList(getAttribute(node, "perm").ints)
+        additional["dims"] = [len(d)]
+        additional["raw_data"] = np.array(d, dtype=np.int32).tobytes()
+        additional["dtype"] = DTYPE_SADL.INT32
+        additional["data"] = node
+        myGraph[reshape_coef_name]["additional"] = additional
+        map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Transpose
+        myGraph[node.output[0]]["inputs"] = [
+            map_onnx_to_myGraph[n0name],
+            reshape_coef_name,
+        ]
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Slice":
+        # Slice
+        if len(node.input) == 5:  # PyTorch
+            initializer = getInitializer(node.input[3], model_onnx)
+            # Case: In pytorch, Slice is not in model_onnx.graph.initializer but in model_onnx.graph.node
+            if initializer is None:
+                attribute = getAttribute(
+                    getNodesWithOutput(node.input[3], model_onnx), "value"
+                )
+                initializer = attribute.t
+            axes = getDims(initializer)
+
+            initializer = getInitializer(node.input[4], model_onnx)
+            # Case: In pytorch, Slice is not in model_onnx.graph.initializer but in model_onnx.graph.node
+            if initializer is None:
+                attribute = getAttribute(
+                    getNodesWithOutput(node.input[4], model_onnx), "value"
+                )
+                initializer = attribute.t
+            steps = getDims(initializer)
+
+            if len(axes) != 1:
+                quit(
+                    "[ERROR] currently sadl slicing support lenght of axes equal to one"
+                )
+            if axes[0] == 0:
+                quit("[ERROR] currently slicing not supported for first dimension")
+            if not (len(steps) == 1 and steps[0] == 1):
+                quit("[ERROR] currently step has to be default one")
+
+        # Currently slicing support only across width is added
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Slice
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
+        # assume depth is the last one, assume axes are always 0, 1, 2, etc.
+
+        initializer = getInitializer(node.input[1], model_onnx)
+        if initializer is None:
+            attribute = getAttribute(
+                getNodesWithOutput(node.input[1], model_onnx), "value"
+            )
+            initializer = attribute.t
+        start = getDims(initializer)
+
+        initializer = getInitializer(node.input[2], model_onnx)
+        if initializer is None:
+            attribute = getAttribute(
+                getNodesWithOutput(node.input[2], model_onnx), "value"
+            )
+            initializer = attribute.t
+        end = getDims(initializer)
+        additional = {}
+        dim_keys = ["b", "h", "w", "c"]
+        for i in range(1, len(dim_keys)):
+            additional[f"start_{dim_keys[i]}"] = 0
+            additional[f"end_{dim_keys[i]}"] = 2147483647
+
+        # A model_onnx exported from TensorFlow has start and end of length 4.
+        # A model_onnx exported from PyTorch has start and end of length 1; the slicing dimension
+        # (i.e., whether slicing is done across C, H or W) is controlled by axes.
+        if len(start) > 1:  # TensorFlow to onnx models
+            if start[0] != 0:
+                quit("[ERROR] currently slicing not supported for first dimension")
+            if end[0] != 2147483647:
+                quit("[ERROR] currently slicing not supported for first dimension")
+            for i in range(1, len(start)):
+                initializer = getInitializer(node.input[1], model_onnx)
+                if initializer is None:
+                    attribute = getAttribute(
+                        getNodesWithOutput(node.input[1], model_onnx), "value"
+                    )
+                    initializer = attribute.t
+                start_d = getDims(initializer)[i]
+
+                initializer = getInitializer(node.input[2], model_onnx)
+                if initializer is None:
+                    attribute = getAttribute(
+                        getNodesWithOutput(node.input[2], model_onnx), "value"
+                    )
+                    initializer = attribute.t
+                end_d = getDims(initializer)[i]
+                if (
+                    end_d > 2147483647
+                ):  # The default infinity number in PyTorch INT64 ONNX is 9223372036854775807.
+                    end_d = 2147483647
+                additional[f"start_{dim_keys[i]}"] = start_d
+                additional[f"end_{dim_keys[i]}"] = end_d
+        else:  # PyTorch to onnx models
+            dim_keys_torch = ["b", "c", "h", "w"]
+            for i in range(len(end) - 1):
+                if start[i] != 0:
+                    quit("[ERROR] currently slicing not supported for first dimension")
+                if end[i] < 2147483647:
+                    quit("[ERROR] currently slicing only supported for last channel")
+
+            initializer = getInitializer(node.input[1], model_onnx)
+            if initializer is None:
+                attribute = getAttribute(
+                    getNodesWithOutput(node.input[1], model_onnx), "value"
+                )
+                initializer = attribute.t
+            start_d = getDims(initializer)[-1]
+
+            initializer = getInitializer(node.input[2], model_onnx)
+            if initializer is None:
+                attribute = getAttribute(
+                    getNodesWithOutput(node.input[2], model_onnx), "value"
+                )
+                initializer = attribute.t
+            end_d = getDims(initializer)[-1]
+            if (
+                end_d > 2147483647
+            ):  # The default infinity number in PyTorch INT64 ONNX is 9223372036854775807.
+                end_d = 2147483647
+            additional[f"start_{dim_keys_torch[axes[0]]}"] = start_d
+            additional[f"end_{dim_keys_torch[axes[0]]}"] = end_d
+        myGraph[node.output[0]]["additional"] = additional
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "ScatterND":
+        # The default input order for the ScatterND is data, indices, and updates.
+        if not is_constant(node.input[1], model_onnx.graph.initializer):
+            quit("[ERROR] The second input of the ScatterND must be indices.")
+        # indices
+        additional = {}
+        additional["data"] = node
+        (
+            additional["dims"],
+            additional["raw_data"],
+            additional["dtype"],
+        ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose)
+        if len(additional["dims"]) == 5:
+            # When the tensor format is specified as NCHW4 (or NHWC4) and the value of N is 1, the format is transformed
+            # to CHW4 (or HWC4). Here, the "4" indicates the position index within a 4-dimensional tensor.
+            additional["dims"] = additional["dims"][1:]
+            # transpose CHW4 to HWC4
+            if node_annotation[node.input[1]].to_transpose:
+                tmp = [
+                    additional["dims"][1],
+                    additional["dims"][2],
+                    additional["dims"][0],
+                    additional["dims"][3],
+                ]
+                x = (
+                    np.frombuffer(additional["raw_data"], dtype=np.int32)
+                    .reshape(additional["dims"])
+                    .transpose(1, 2, 0, 3)
+                )
+                indices = x.copy()
+                for i in np.ndindex(indices.shape[:-1]):
+                    indices[i] = [
+                        indices[i][0],
+                        indices[i][2],
+                        indices[i][3],
+                        indices[i][1],
+                    ]
+                additional["dims"] = tmp
+                additional["raw_data"] = indices.flatten().tobytes()
+        else:
+            quit("[ERROR] Currently, ScatterND only supports indices of length 5.")
+        map_onnx_to_myGraph[node.input[1]] = node.input[1]
+        myGraph[node.input[1]] = {}
+        myGraph[node.input[1]]["inputs"] = []
+        myGraph[node.input[1]]["additional"] = additional
+        myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.ScatterND
+        myGraph[node.output[0]]["inputs"] = [
+            map_onnx_to_myGraph[n0name],  # data
+            map_onnx_to_myGraph[node.input[2]],  # updates
+            map_onnx_to_myGraph[node.input[1]],
+        ]  # indices
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "GridSample":
+        # Currently, the official TensorFlow does not have an implementation for GridSample.
+        align_corners = getAttribute(node, "align_corners").i
+        mode = getAttribute(node, "mode").s.decode("utf-8")
+        padding_mode = getAttribute(node, "padding_mode").s.decode("utf-8")
+
+        mode_list = ["nearest", "bilinear"]
+        if mode not in mode_list:
+            quit("[ERROR] Currently, the mode of GridSample must in", mode_list, node)
+        else:
+            mode = mode_list.index(mode)
+        padding_mode_list = ["border"]
+        if padding_mode not in padding_mode_list:
+            quit(
+                "[ERROR] Currently, the padding_mode of GridSample must in",
+                padding_mode_list,
+                node,
+            )
+        else:
+            padding_mode = padding_mode_list.index(padding_mode)
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.GridSample
+        myGraph[node.output[0]]["inputs"] = [
+            map_onnx_to_myGraph[n0name],
+            map_onnx_to_myGraph[node.input[1]],
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        myGraph[node.output[0]]["additional"]["align_corners"] = align_corners
+        myGraph[node.output[0]]["additional"]["mode"] = mode
+        myGraph[node.output[0]]["additional"]["padding_mode"] = padding_mode
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Resize":
+        # Between 2 and 4 inputs. The default input order for the Resize is X, roi, scales, and sizes.
+        input_label = 0
+        input_list = [map_onnx_to_myGraph[n0name]]
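+        # input_label is a bit mask over the constant inputs that were found:
+        # sizes -> bit 0 (value 1), scales -> bit 1 (value 2), roi -> bit 2 (value 4);
+        # only "X and sizes" (label 1) or "X and scales" (label 2) are supported below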
+        for input_index, input_name in enumerate(node.input):
+            if is_constant(input_name, model_onnx.graph.initializer):
+                additional = {}
+                additional["data"] = node
+                (
+                    additional["dims"],
+                    additional["raw_data"],
+                    additional["dtype"],
+                ) = extract_additional_data(
+                    input_name, False, model_onnx.graph, verbose
+                )
+                if (
+                    additional["raw_data"] == b""
+                ):  # When tensor data is empty, just ignore it.
+                    continue
+
+                map_onnx_to_myGraph[input_name] = input_name
+                myGraph[input_name] = {}
+                myGraph[input_name]["inputs"] = []
+                myGraph[input_name]["additional"] = additional
+                myGraph[input_name]["op_type"] = OPTYPE.Const
+                input_list.append(map_onnx_to_myGraph[input_name])
+                input_label = input_label + (1 << (3 - input_index))
+        if input_label != 1 and input_label != 2:
+            quit(
+                "[ERROR] Currently, the inputs of Resize have to be X and sizes, or X and scales."
+            )
+
+        # attribute (str -> int)
+        coordinate_transformation_mode = getAttribute(
+            node, "coordinate_transformation_mode"
+        ).s.decode("utf-8")
+        cubic_coeff_a = getAttribute(node, "cubic_coeff_a")
+        exclude_outside = getAttribute(node, "exclude_outside")
+        mode = getAttribute(node, "mode").s.decode("utf-8")
+        nearest_mode = getAttribute(node, "nearest_mode").s.decode("utf-8")
+
+        coordinate_transformation_mode_list = ["half_pixel", "asymmetric"]
+        if coordinate_transformation_mode not in coordinate_transformation_mode_list:
+            quit(
+                "[ERROR] Currently, the coordinate_transformation_mode of Resize must in",
+                coordinate_transformation_mode_list,
+                node,
+            )
+        else:
+            coordinate_transformation_mode = coordinate_transformation_mode_list.index(
+                coordinate_transformation_mode
+            )
+        if cubic_coeff_a is None:
+            cubic_coeff_a = -0.75
+        else:
+            cubic_coeff_a = cubic_coeff_a.f
+        if not cubic_coeff_a == -0.75:
+            quit(
+                "[ERROR] Currently, the cubic_coeff_a of Resize must be default -0.75.",
+                node,
+            )
+        if exclude_outside is None:
+            exclude_outside = 0
+        else:
+            exclude_outside = exclude_outside.i
+        if not exclude_outside == 0:
+            quit(
+                "[ERROR] Currently, the exclude_outside of Resize must be default 0.",
+                node,
+            )
+        mode_list = ["linear", "nearest"]
+        if mode not in mode_list:
+            quit("[ERROR] Currently, the mode of Resize must in", mode_list, node)
+        else:
+            mode = mode_list.index(mode)
+        nearest_mode_list = ["floor", "round_prefer_ceil"]
+        if nearest_mode not in nearest_mode_list:
+            quit(
+                "[ERROR] Currently, the nearest_mode of Resize must in",
+                nearest_mode_list,
+                node,
+            )
+        else:
+            nearest_mode = nearest_mode_list.index(nearest_mode)
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Resize
+        myGraph[node.output[0]]["inputs"] = input_list
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        myGraph[node.output[0]]["additional"]["input_label"] = input_label
+        myGraph[node.output[0]]["additional"][
+            "coordinate_transformation_mode"
+        ] = coordinate_transformation_mode
+        myGraph[node.output[0]]["additional"]["mode"] = mode
+        myGraph[node.output[0]]["additional"]["nearest_mode"] = nearest_mode
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Less":
+        additional = {}
+        additional["data"] = node
+        if is_constant(node.input[1], model_onnx.graph.initializer):
+            n2 = getNodesWithOutput(node.input[1], model_onnx)  # constant
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(
+                node.input[1],
+                False,
+                model_onnx.graph,
+                verbose,
+            )
+            myGraph[node.input[1]] = {}
+            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+            myGraph[node.input[1]]["inputs"] = []
+            myGraph[node.input[1]]["additional"] = additional
+            map_onnx_to_myGraph[node.input[1]] = node.input[1]
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Compare
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [
+            map_onnx_to_myGraph[node.input[1]]
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        myGraph[node.output[0]]["additional"]["mode"] = 0
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Greater":
+        additional = {}
+        additional["data"] = node
+        if is_constant(node.input[1], model_onnx.graph.initializer):
+            n2 = getNodesWithOutput(node.input[1], model_onnx)  # constant
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(
+                node.input[1],
+                False,
+                model_onnx.graph,
+                verbose,
+            )
+            myGraph[node.input[1]] = {}
+            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+            myGraph[node.input[1]]["inputs"] = []
+            myGraph[node.input[1]]["additional"] = additional
+            map_onnx_to_myGraph[node.input[1]] = node.input[1]
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Compare
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [
+            map_onnx_to_myGraph[node.input[1]]
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        myGraph[node.output[0]]["additional"]["mode"] = 1
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Where":
+        if is_constant(node.input[1], model_onnx.graph.initializer):
+            additional = {}
+            additional["data"] = node
+            n2 = getNodesWithOutput(node.input[1], model_onnx)
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(node.input[1], False, model_onnx.graph, verbose)
+            myGraph[node.input[1]] = {}
+            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+            myGraph[node.input[1]]["inputs"] = []
+            myGraph[node.input[1]]["additional"] = additional
+            map_onnx_to_myGraph[node.input[1]] = node.input[1]
+        if is_constant(node.input[2], model_onnx.graph.initializer):
+            additional = {}
+            additional["data"] = node
+            n2 = getNodesWithOutput(node.input[2], model_onnx)
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(node.input[2], False, model_onnx.graph, verbose)
+            myGraph[node.input[2]] = {}
+            myGraph[node.input[2]]["op_type"] = OPTYPE.Const
+            myGraph[node.input[2]]["inputs"] = []
+            myGraph[node.input[2]]["additional"] = additional
+            map_onnx_to_myGraph[node.input[2]] = node.input[2]
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Where
+        myGraph[node.output[0]]["inputs"] = (
+            [map_onnx_to_myGraph[n0name]]
+            + [map_onnx_to_myGraph[node.input[1]]]
+            + [map_onnx_to_myGraph[node.input[2]]]
+        )
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    elif node.op_type == "Equal":
+        additional = {}
+        additional["data"] = node
+        if is_constant(node.input[1], model_onnx.graph.initializer):
+            n2 = getNodesWithOutput(node.input[1], model_onnx)  # constant
+            (
+                additional["dims"],
+                additional["raw_data"],
+                additional["dtype"],
+            ) = extract_additional_data(
+                node.input[1],
+                False,
+                model_onnx.graph,
+                verbose,
+            )
+            myGraph[node.input[1]] = {}
+            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
+            myGraph[node.input[1]]["inputs"] = []
+            myGraph[node.input[1]]["additional"] = additional
+            map_onnx_to_myGraph[node.input[1]] = node.input[1]
+
+        myGraph[node.output[0]] = {}
+        myGraph[node.output[0]]["op_type"] = OPTYPE.Compare
+        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]] + [
+            map_onnx_to_myGraph[node.input[1]]
+        ]
+        myGraph[node.output[0]]["additional"] = {}
+        myGraph[node.output[0]]["additional"]["data"] = node
+        myGraph[node.output[0]]["additional"]["mode"] = 2
+        map_onnx_to_myGraph[node.output[0]] = node.output[0]
+
+    else:
+        raise Exception("[ERROR] node not supported:\n{})".format(node))
+
+    if node_annotation[node.name].add_transpose_after:
+        n0name = add_transpose_after(node, myGraph, map_onnx_to_myGraph)
+
+
+def parse_onnx(model_onnx, node_annotation, verbose=False):
+    myGraph, map_onnx_to_myGraph = OrderedDict(), {}
+
+    # Inputs
+    for inp in model_onnx.graph.input:
+        myGraph[inp.name] = parse_graph_input_node(
+            inp, map_onnx_to_myGraph, node_annotation[inp.name].to_transpose
+        )
+
+    # Nodes removal
+    for node in model_onnx.graph.node:
+        if node.name in node_annotation and node_annotation[node.name].to_remove:
+            curr_key = node.input[0]
+            while (
+                map_onnx_to_myGraph[curr_key] is not None
+                and map_onnx_to_myGraph[curr_key] != curr_key
+            ):
+                next_key = map_onnx_to_myGraph[curr_key]
+                curr_key = next_key
+                if curr_key not in map_onnx_to_myGraph:
+                    curr_key = node.input[0]
+                    break
+
+            map_onnx_to_myGraph[node.output[0]] = curr_key
+        else:
+            parse_graph_node(
+                node, model_onnx, myGraph, node_annotation, map_onnx_to_myGraph, verbose
+            )
+
+    myInputs = []
+    for inp in model_onnx.graph.input:
+        myInputs.append(inp.name)
+
+    myOutputs = []
+    for out in model_onnx.graph.output:
+        for key, value in map_onnx_to_myGraph.items():
+            if key == out.name:
+                myOutputs.append(value)
+
+    return myGraph, myInputs, myOutputs
+
+
+QUANTIZERS = dict()
+
+
+def dump_onnx(graph, my_inputs, my_outputs, output_filename, verbose=False):
+    # graph[my_name]={ op_type
+    #                  inputs: []
+    #                  dtype:
+    #                  onnx : model.graph.node[x]
+    #                  }
+
+    # my_input=[my_name, my_name..]
+    # outputs=[my_name, ...]
+    # print(graph)
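+    # e.g. (hypothetical layer names):
+    #   graph["Conv_0"] = {"op_type": OPTYPE.Conv2D,
+    #                      "inputs": ["input_0", "Conv_0_weight"],
+    #                      "additional": {"strides": [1, 1, 1, 1], "pads": [1, 1, 1, 1],
+    #                                     "group": 1, "data": <onnx node>}}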
+    input_q = 11
+    relu_q = 11
+    mul_q = 13
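+    # input_q/relu_q/mul_q above are fixed power-of-two quantizers (numbers of fractional bits)
+    # assigned to the network input, PReLU slopes and ".multiplier" constants in the quantizer file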
+
+    map_name_to_idx = dict()
+    for idx, (key, _value) in enumerate(graph.items()):
+        map_name_to_idx[key] = idx
+
+    # dbg print(map_name_to_idx)
+    with open(output_filename, "wb") as f:
+        f.write(str.encode("SADL0004"))
+        # output of the network type 0: int32 | 1: float | 2: int16 | default: float(1)
+        f.write(struct.pack("i", int(DTYPE_SADL.FLOAT)))
+
+        if verbose:
+            print(f"# Nb layers: {len(graph.keys())}")
+        f.write(struct.pack("i", int(len(graph.keys()))))
+
+        inputs = []
+        for name in my_inputs:
+            inputs.append(map_name_to_idx[name])
+        if verbose:
+            print(f"# Nb inputs: {len(inputs)}")
+        f.write(struct.pack("i", int(len(inputs))))
+        for i in inputs:
+            if verbose:
+                print(f"#  input {i}")
+            f.write(struct.pack("i", int(i)))
+
+        outputs = []
+        for name in my_outputs:
+            outputs.append(map_name_to_idx[name])
+        if verbose:
+            print(f"# Nb outputs: {len(outputs)}")
+        f.write(struct.pack("i", int(len(outputs))))
+        for i in outputs:
+            if verbose:
+                print(f"#  output {i}")
+            f.write(struct.pack("i", int(i)))
+
+        for name, node in graph.items():
+            quantizer_idx = map_name_to_idx[name]
+            QUANTIZERS[quantizer_idx] = 0
+
+            if verbose:
+                print(f"# Layer id {map_name_to_idx[name]}")
+            f.write(struct.pack("i", int(map_name_to_idx[name])))
+
+            if verbose:
+                print("#\t op " + str(node["op_type"]))
+            f.write(struct.pack("i", int(node["op_type"].value)))
+
+            # Name size
+            if verbose:
+                print(f"#\t name_size {len(name)}")
+            f.write(struct.pack("i", int(len(name))))
+
+            # Name
+            if verbose:
+                print(f"#\t name {name}")
+            f.write(str.encode(str(name)))
+
+            # Nb inputs
+            if verbose:
+                print(f"#\t nb_inputs {len(node['inputs'])}")
+            f.write(struct.pack("i", int(len(node["inputs"]))))
+
+            for name_i in node["inputs"]:
+                idx = map_name_to_idx[name_i]
+                if verbose:
+                    print(f"#\t\t {idx} ({name_i})")
+                f.write(struct.pack("i", int(idx)))
+
+            # Additional info depending on OPTYPE
+            if node["op_type"] == OPTYPE.Const:
+                if verbose:
+                    print(f"#\t nb_dim {len(node['additional']['dims'])}")
+                f.write(struct.pack("i", int(len(node["additional"]["dims"]))))
+
+                for dim in node["additional"]["dims"]:
+                    if verbose:
+                        print(f"#\t\t {dim}")
+                    f.write(struct.pack("i", int(dim)))
+
+                if verbose:
+                    print(f"#\t dtype {node['additional']['dtype']}")
+                f.write(struct.pack("i", int(node["additional"]["dtype"])))
+
+                if node["additional"]["dtype"] != DTYPE_SADL.FLOAT:  # not float
+                    if verbose:
+                        print("#\t quantizer 0")
+                    f.write(struct.pack("i", int(0)))
+
+                f.write(node["additional"]["raw_data"])
+
+                # update names
+                if (
+                    args.orig_params_to_new is not None
+                    and name in map_orig_params_to_new_params
+                ):
+                    orig_name = map_orig_params_to_new_params[name]
+                else:
+                    orig_name = name
+
+                if ".multiplier" in orig_name:
+                    QUANTIZERS[quantizer_idx] = mul_q
+                elif "PRelu" in name:
+                    QUANTIZERS[quantizer_idx] = relu_q
+                else:
+                    QUANTIZERS[quantizer_idx] = base_quantizers.get(orig_name, 0)
+
+                if verbose:
+                    print(orig_name, QUANTIZERS[quantizer_idx])
+
+            # ???    if "alpha" in layer['additional']:
+            #        f.write(struct.pack('f', float(layer['additional']['alpha'])))
+
+            elif node["op_type"] == OPTYPE.Slice:
+                dim_keys = ["h", "w", "c"]
+                for dim in dim_keys:
+                    if verbose:
+                        print(
+                            f"#\t start index for {dim} slicing",
+                            node["additional"][f"start_{dim}"],
+                        )
+                        print(
+                            f"#\t end index for {dim} slicing",
+                            node["additional"][f"end_{dim}"],
+                        )
+                    f.write(struct.pack("i", int(node["additional"][f"start_{dim}"])))
+                    f.write(struct.pack("i", int(node["additional"][f"end_{dim}"])))
+
+            elif node["op_type"] == OPTYPE.Conv2D:
+                if verbose:
+                    print("#\t  nb_dim_strides", len(node["additional"]["strides"]))
+                f.write(struct.pack("i", int(len(node["additional"]["strides"]))))
+
+                for stride in node["additional"]["strides"]:
+                    if verbose:
+                        print(f"#\t\t {stride}")
+                    f.write(struct.pack("i", int(stride)))
+
+                if verbose:
+                    print("#\t  nb_dim_pads", len(node["additional"]["pads"]))
+                f.write(struct.pack("i", int(len(node["additional"]["pads"]))))
+
+                for p in node["additional"]["pads"]:
+                    if verbose:
+                        print(f"#\t\t {p}")
+                    f.write(struct.pack("i", int(p)))
+
+                if verbose:
+                    print("#\t  nb_group", node["additional"]["group"])
+                f.write(struct.pack("i", int(node["additional"]["group"])))
+
+            elif node["op_type"] == OPTYPE.Conv2DTranspose:
+                if verbose:
+                    print("#\t  nb_dim_strides", len(node["additional"]["strides"]))
+                f.write(struct.pack("i", int(len(node["additional"]["strides"]))))
+
+                for stride in node["additional"]["strides"]:
+                    if verbose:
+                        print(f"#\t\t {stride}")
+                    f.write(struct.pack("i", int(stride)))
+
+                if verbose:
+                    print("#\t  nb_dim_pads", len(node["additional"]["pads"]))
+                f.write(struct.pack("i", int(len(node["additional"]["pads"]))))
+
+                for p in node["additional"]["pads"]:
+                    if verbose:
+                        print(f"#\t\t {p}")
+                    f.write(struct.pack("i", int(p)))
+
+                if verbose:
+                    print(
+                        "#\t  nb_dim_output_padding",
+                        len(node["additional"]["output_padding"]),
+                    )
+                f.write(
+                    struct.pack("i", int(len(node["additional"]["output_padding"])))
+                )
+
+                for p in node["additional"]["output_padding"]:
+                    if verbose:
+                        print(f"#\t\t {p}")
+                    f.write(struct.pack("i", int(p)))
+
+            elif node["op_type"] == OPTYPE.Placeholder:
+                if verbose:
+                    print(f"#\t nb input dimension {len(node['additional']['dims'])}")
+                f.write(struct.pack("i", int(len(node["additional"]["dims"]))))
+
+                for dim in node["additional"]["dims"]:
+                    if verbose:
+                        print(f"#\t\t {dim}")
+                    f.write(struct.pack("i", int(dim)))
+
+                # output the quantizer of the input default: 0
+                if verbose:
+                    print("#\t quantizer_of_input 0")
+                f.write(struct.pack("i", int(0)))
+
+                if "Equal" not in name:
+                    QUANTIZERS[quantizer_idx] = input_q
+                else:
+                    QUANTIZERS[quantizer_idx] = 0
+
+            elif node["op_type"] == OPTYPE.MaxPool:
+                if verbose:
+                    print("#\t  nb_dim_strides", len(node["additional"]["strides"]))
+                f.write(struct.pack("i", int(len(node["additional"]["strides"]))))
+
+                for stride in node["additional"]["strides"]:
+                    if verbose:
+                        print(f"#\t\t {stride}")
+                    f.write(struct.pack("i", int(stride)))
+
+                if verbose:
+                    print("#\t  nb_dim_kernel", len(node["additional"]["kernel_shape"]))
+                f.write(struct.pack("i", int(len(node["additional"]["kernel_shape"]))))
+
+                for ks in node["additional"]["kernel_shape"]:
+                    if verbose:
+                        print(f"#\t\t {ks}")
+                    f.write(struct.pack("i", int(ks)))
+
+                if verbose:
+                    print("#\t  nb_dim_pads", len(node["additional"]["pads"]))
+                f.write(struct.pack("i", int(len(node["additional"]["pads"]))))
+
+                for p in node["additional"]["pads"]:
+                    if verbose:
+                        print(f"#\t\t {p}")
+                    f.write(struct.pack("i", int(p)))
+
+            elif node["op_type"] == OPTYPE.Flatten:
+                if verbose:
+                    print("#\t axis", node["additional"]["axis"])
+                f.write(struct.pack("i", int(node["additional"]["axis"])))
+
+            elif node["op_type"] == OPTYPE.GridSample:
+                if verbose:
+                    print("#\t align_corners", node["additional"]["align_corners"])
+                f.write(struct.pack("i", int(node["additional"]["align_corners"])))
+
+                if verbose:
+                    print("#\t mode", node["additional"]["mode"])
+                f.write(struct.pack("i", int(node["additional"]["mode"])))
+
+                if verbose:
+                    print("#\t padding_mode", node["additional"]["padding_mode"])
+                f.write(struct.pack("i", int(node["additional"]["padding_mode"])))
+
+            elif node["op_type"] == OPTYPE.Resize:
+                if verbose:
+                    print("#\t input_label", node["additional"]["input_label"])
+                f.write(struct.pack("i", int(node["additional"]["input_label"])))
+
+                if verbose:
+                    print(
+                        "#\t coordinate_transformation_mode",
+                        node["additional"]["coordinate_transformation_mode"],
+                    )
+                f.write(
+                    struct.pack(
+                        "i", int(node["additional"]["coordinate_transformation_mode"])
+                    )
+                )
+
+                if verbose:
+                    print("#\t mode", node["additional"]["mode"])
+                f.write(struct.pack("i", int(node["additional"]["mode"])))
+
+                if verbose:
+                    print("#\t nearest_mode", node["additional"]["nearest_mode"])
+                f.write(struct.pack("i", int(node["additional"]["nearest_mode"])))
+
+            elif node["op_type"] == OPTYPE.Compare:
+                if verbose:
+                    print("#\t mode", node["additional"]["mode"])
+                f.write(struct.pack("i", int(node["additional"]["mode"])))
+
+            if (
+                node["op_type"] == OPTYPE.Conv2D
+                or node["op_type"] == OPTYPE.Conv2DTranspose
+                or node["op_type"] == OPTYPE.MatMul
+                or node["op_type"] == OPTYPE.Mul
+            ):
+                # output the internal quantizer default: 0
+                f.write(struct.pack("i", int(0)))
+
+            if verbose:
+                print("")
+
+
+# adapt (remove/add) the current node according to the data layout and
+# recurse into its outputs
+def annotate_node(
+    node, model_onnx, node_annotation, global_data_layout, verbose
+):  # recursive
+    if node.name in node_annotation:
+        return
+    if verbose > 1:
+        print("[INFO] annotate {}".format(node.name))
+
+    if verbose:
+        print("[INFO] annotate_node {} op={}".format(node.name, node.op_type))
+    data_layout = None
+
+    # inherit from input
+    for inp in node.input:
+        n2 = getNodesWithOutputNotConst(inp, model_onnx)
+        if n2 is not None:
+            if n2.name in node_annotation:
+                if data_layout is None:
+                    data_layout = node_annotation[n2.name].layout_onnx
+                elif (
+                    node_annotation[n2.name].layout_onnx is not None
+                    and node_annotation[n2.name].layout_onnx != data_layout
+                ):
+                    quit(
+                        "[ERROR] inputs with diferent layout for\n{}Layouts: {}".format(
+                            node, node_annotation
+                        )
+                    )
+            else:  # not ready yet
+                return
+
+    if verbose > 1 and data_layout is None:
+        print(
+            "[WARNING] no data layout constraints for {}\n {}".format(node.name, node)
+        )
+
+    if node.name not in node_annotation:
+        node_annotation[node.name] = Node_Annotation()
+    node_annotation[node.name].layout_onnx = data_layout  # default
+
+    if node.op_type == "Transpose":
+        a = getAttribute(node, "perm")
+        if data_layout == "nhwc":
+            if (
+                a.ints[0] == 0 and a.ints[1] == 3 and a.ints[2] == 1 and a.ints[3] == 2
+            ):  # nhwc ->nchw
+                node_annotation[node.name].to_remove = True  # will be removed
+                node_annotation[node.name].layout_onnx = "nchw"  # new layout at output
+            else:
+                if verbose > 1:
+                    print("[WARNING] transpose not for NCHW handling in\n", node)
+        elif data_layout == "nchw":
+            if (
+                a.ints[0] == 0 and a.ints[1] == 2 and a.ints[2] == 3 and a.ints[3] == 1
+            ):  # nchw ->nhwc
+                node_annotation[node.name].to_remove = True  # will be removed
+                node_annotation[node.name].layout_onnx = "nhwc"  # new layout at output
+            else:
+                if verbose > 1:
+                    print("[WARNING] transpose not for NCHW handling in\n", node)
+
+            if node_annotation[node.name].to_remove:
+                # GridSample is usually preceded by a Transpose: the optical flow is transposed
+                # from (N,2,H,W) to (N,H,W,2), and that Transpose must not be removed.
+                # In PyTorch there are no further operations on the transposed (N,H,W,2) tensor,
+                # so keeping it here does not affect any other situation.
+                nexts = getNodesWithInput(node.output[0], model_onnx)
+                for n in nexts:
+                    if n.op_type == "GridSample":
+                        node_annotation[node.name].to_remove = False
+                        node_annotation[node.name].layout_onnx = "nchw"
+                        break
+
+    elif node.op_type == "Reshape":
+        initializer = getInitializer(node.input[1], model_onnx)
+        # Case: In pytorch, Reshape is not in model_onnx.graph.initializer but in model_onnx.graph.node
+        if initializer is None:
+            attribute = getAttribute(
+                getNodesWithOutput(node.input[1], model_onnx), "value"
+            )
+            initializer = attribute.t
+        dims = getDims(initializer)
+
+        # detect if this reshape is actually added by onnx to emulate a transpose
+        # more checks may be needed to confirm that the reshape really emulates a transpose...
+        if len(dims) == 4 and (dims[0] == 1 or dims[0] == -1):
+            if data_layout == "nhwc":
+                if dims[1] == 1:  # or dims2 * dims3 == 1 # nhwc ->nchw
+                    node_annotation[node.name].to_remove = True  # will be removed
+                    node_annotation[
+                        node.name
+                    ].layout_onnx = "nchw"  # new layout at output
+                else:
+                    if verbose > 1:
+                        print("[WARNING] reshape unknown for", node, " dims", dims)
+                    node_annotation[node.name].layout_onnx = None
+            elif data_layout == "ncwh":
+                if dims[3] == 1:  # # or dims2 * dims3 == 1 nchw ->nhwc
+                    node_annotation[node.name].to_remove = True  # will be removed
+                    node_annotation[
+                        node.name
+                    ].layout_onnx = "nhwc"  # new layout at output
+                else:
+                    if verbose > 1:
+                        print("[WARNING] reshape unknown for", node, " dims", dims)
+                    node_annotation[node.name].layout_onnx = None
+            elif data_layout is None:
+                node_annotation[
+                    node.name
+                ].layout_onnx = global_data_layout  # back to org
+                if global_data_layout == "nchw":
+                    node_annotation[
+                        node.name
+                ].add_transpose_after = True  # a bit too aggressive
+        else:
+            node_annotation[node.name].layout_onnx = None
+
+        n2 = getNodesWithOutputNotConst(node.input[0], model_onnx)
+        if (
+            node_annotation[n2.name].layout_onnx == "nchw"
+        ):  # need to go back to original layout before reshape
+            node_annotation[node.name].add_transpose_before = True
+
+    elif node.op_type == "Flatten":
+        if (
+            node_annotation[node.name].layout_onnx == "nchw"
+        ):  # need to go back to original layout before reshape
+            node_annotation[node.name].add_transpose_before = True
+
+    elif node.op_type == "Concat":
+        if data_layout == "nchw":  # nhwc -> nhwc
+            a = getAttribute(node, "axis")
+            if a.i == 1:
+                a.i = 3
+            elif a.i == 2:
+                a.i = 1
+            elif a.i == 3:
+                a.i = 2
+            elif a.i == -3:
+                a.i = -1
+
+    elif node.op_type == "Unsqueeze":
+        node_annotation[node.name].to_remove = True
+
+    elif node.op_type == "Conv":
+        n2 = getInitializer(node.input[1], model_onnx)
+        node_annotation[n2.name].to_transpose = True
+        node_annotation[n2.name].layout_onnx = "nhwc"
+
+    elif node.op_type == "ConvTranspose":
+        n2 = getInitializer(node.input[1], model_onnx)
+        node_annotation[n2.name].to_transpose = True
+        node_annotation[n2.name].layout_onnx = "nhwc"
+
+    elif node.op_type == "Gemm":
+        n2 = getInitializer(node.input[1], model_onnx)
+        if global_data_layout == "nchw":
+            node_annotation[n2.name].to_transpose = True
+        #    node_annotation[n2.name].layout_onnx = 'nhwc'
+
+    elif node.op_type == "ScatterND":
+        n2 = getInitializer(node.input[1], model_onnx)
+        if global_data_layout == "nchw":
+            node_annotation[n2.name].to_transpose = True
+
+    nexts = getNodesWithInput(node.output[0], model_onnx)
+    for n in nexts:
+        annotate_node(
+            n, model_onnx, node_annotation, global_data_layout, verbose
+        )  # rec
+
+
+def annotate_graph(model_onnx, node_annotation, data_layout, verbose):
+
+    # track the data layout in the graph and remove/add layers if necessary
+    for inp in model_onnx.graph.input:
+        node_annotation[inp.name] = Node_Annotation()
+        if len(inp.type.tensor_type.shape.dim) == 4:
+            node_annotation[inp.name].layout_onnx = data_layout
+            if data_layout == "nchw":
+                node_annotation[inp.name].to_transpose = True
+        else:
+            node_annotation[inp.name].layout_onnx = None
+
+    for inp in model_onnx.graph.initializer:
+        node_annotation[inp.name] = Node_Annotation()
+        node_annotation[inp.name].layout_onnx = None
+
+    for inp in model_onnx.graph.node:
+        if inp.op_type == "Constant":
+            node_annotation[inp.name] = Node_Annotation()
+            node_annotation[inp.name].layout_onnx = None
+
+    for inp in model_onnx.graph.input:
+        nexts = getNodesWithInput(inp.name, model_onnx)
+        for n in nexts:
+            annotate_node(
+                n, model_onnx, node_annotation, data_layout, verbose
+            )  # recursive
+
+    if verbose > 1:
+        for node in model_onnx.graph.node:
+            if node.op_type == "Transpose" and (
+                node.name not in node_annotation
+                or not node_annotation[node.name].to_remove
+            ):
+                print(
+                    "[ERROR] preprocess_onnxGraph: all transpose node should be removed but this is not the case here: {}\n{}".format(
+                        node.name, node
+                    )
+                )
+
+
+def detectDataType(model):  # more adaptation to do here if tf is using nchw
+    if model.producer_name == "tf2onnx":
+        return "nhwc"
+    elif model.producer_name == "pytorch":
+        return "nchw"
+    else:
+        quit("[ERROR] unable to detect data layout")
+
+
+def dumpModel(model_onnx, output_filename, data_layout, verbose, user_annotation):
+    """Writes the neural network model in the \"sadl\" format to binary file.
+
+    Parameters
+    ----------
+    model : onnx model
+    output_filename : either str or None
+        Path to the binary file to which the neural network model
+        is written.
+    data_type: None, 'ncwh' or 'nwhc'
+    verbose : bool
+        Is additional information printed?
+    """
+    model_onnx_copy = copy.deepcopy(model_onnx)
+    if data_layout is None:
+        data_layout = detectDataType(model_onnx_copy)
+
+    if verbose:
+        print("[INFO] assume data type", data_layout)
+
+    if verbose > 1:
+        # remove data
+        gg = copy.deepcopy(model_onnx.graph)
+        for node in gg.initializer:
+            node.raw_data = np.array(0.0).tobytes()
+        print("[INFO] original graph:\n", gg)
+        del gg
+
+    if data_layout != "nhwc" and data_layout != "nchw":
+        quit("[ERROR] unsupported layout", data_layout)
+
+    node_annotation = {}
+    annotate_graph(model_onnx_copy, node_annotation, data_layout, verbose)
+
+    for k, v in user_annotation.items():
+        if k in node_annotation:
+            if v.add_transpose_before is not None:
+                node_annotation[k].add_transpose_before = v.add_transpose_before
+            if v.add_transpose_after is not None:
+                node_annotation[k].add_transpose_after = v.add_transpose_after
+            if v.to_remove is not None:
+                node_annotation[k].to_remove = v.to_remove
+            if v.to_transpose is not None:
+                node_annotation[k].to_transpose = v.to_transpose
+        else:
+            print("[ERROR] unknown node user custom", k)
+            quit()
+
+    if verbose > 1:
+        print(
+            "INFO] annotations:\n{"
+            + "\n".join("{!r}: {!r},".format(k, v) for k, v in node_annotation.items())
+            + "}"
+        )  # print("[INFO] node annotations:", node_annotation)
+    my_graph, my_inputs, my_outputs = parse_onnx(
+        model_onnx_copy, node_annotation, verbose=verbose
+    )
+    dump_onnx(my_graph, my_inputs, my_outputs, output_filename, verbose=verbose)
+    if data_layout == "nchw":
+        print(
+            "[INFO] in SADL, your inputs and outputs has been changed from NCHW to NHWC"
+        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="onnx2sadl conversion", usage="NB: force run on CPU"
+    )
+    parser.add_argument(
+        "--input_onnx",
+        action="store",
+        nargs="?",
+        type=str,
+        help="name of the onnx file",
+    )
+    parser.add_argument(
+        "--output",
+        action="store",
+        nargs="?",
+        type=str,
+        help="name of model binary file",
+    )
+    parser.add_argument(
+        "--out_quant",
+        action="store",
+        nargs="?",
+        type=str,
+        default="",
+        help="name of quantizer file",
+    )
+    parser.add_argument(
+        "--base_quantizers",
+        action="store",
+        nargs="?",
+        type=str,
+        help="path to base quantizers file",
+    )
+    parser.add_argument(
+        "--orig_params_to_new",
+        action="store",
+        nargs="?",
+        type=str,
+        help="path to the file that contains the mapping of orig param names to new",
+    )
+    parser.add_argument("--nchw", action="store_true")
+    parser.add_argument("--nhwc", action="store_true")
+    parser.add_argument("--verbose", action="count")
+    parser.add_argument(
+        "--do_not_add_transpose_before",
+        action="store",
+        nargs="+",
+        default=[],
+        help="specify a node where add transpose before will be disable",
+    )
+    parser.add_argument(
+        "--do_not_add_transpose_after",
+        action="store",
+        nargs="+",
+        default=[],
+        help="specify a node where add transpose after will be disable",
+    )
+
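+    # example invocation (hypothetical file names):
+    #   python conversion/onnx2sadl_modified.py --input_onnx base.onnx --output base_float.sadl \
+    #       --out_quant base.quant --base_quantizers base_quantizers.json --nchw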
+    args = parser.parse_args()
+    if args.input_onnx is None:
+        quit("[ERROR] You should specify an onnx file")
+    if args.output is None:
+        quit("[ERROR] You should specify an output file")
+
+    print("[INFO] ONNX converter")
+    if args.verbose is None:
+        args.verbose = 0
+
+    model_onnx = onnx.load(args.input_onnx)
+
+    user_annotation = {}
+    for node in args.do_not_add_transpose_before:
+        if node not in user_annotation:
+            user_annotation[node] = Node_Annotation()
+            user_annotation[node].to_remove = None
+            user_annotation[node].add_transpose_before = None
+            user_annotation[node].add_transpose_after = None
+            user_annotation[node].to_transpose = None
+        user_annotation[node].add_transpose_before = False
+
+    for node in args.do_not_add_transpose_after:
+        if node not in user_annotation:
+            user_annotation[node] = Node_Annotation()
+            user_annotation[node].to_remove = None
+            user_annotation[node].add_transpose_before = None
+            user_annotation[node].add_transpose_after = None
+            user_annotation[node].to_transpose = None
+        user_annotation[node].add_transpose_after = False
+
+    data_layout = None
+    if args.nchw:
+        data_layout = "nchw"
+    elif args.nhwc:
+        data_layout = "nhwc"
+
+    base_quantizers = read_json_file(Path(args.base_quantizers))
+
+    if args.orig_params_to_new is not None:
+        map_orig_params_to_new_params = read_json_file(Path(args.orig_params_to_new))
+
+    dumpModel(model_onnx, args.output, data_layout, args.verbose, user_annotation)
+
+    original_stdout = sys.stdout
+    if args.out_quant != "":
+        with open(args.out_quant, "w") as f:
+            sys.stdout = f  # Change the standard output to the file we created.
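+            # the quantizer file is a single line of space-separated "<layer_idx> <quantizer>" pairs,
+            # later fed to the naive_quantization tool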
+            for idx in QUANTIZERS:
+                print(f"{idx} {QUANTIZERS[idx]}", end=" ")
+            sys.stdout = (
+                original_stdout  # Reset the standard output to its original value
+            )
+
+    # Modification: quantizers are collected while dumping the ONNX model (in dump_onnx()) and written to the quantizer file above
diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/sadl2torch.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/sadl2torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..32f58f7004b8a84bb2f515f22c3d6d486b66f78f
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/conversion/sadl2torch.py
@@ -0,0 +1,539 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import hashlib
+import os
+import sys
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from typing import Dict
+
+import numpy as np
+import torch
+
+from conversion.onnx2sadl_modified import DTYPE_SADL, OPTYPE
+from models import load_torch_model, model_lop2, model_lop2_with_multiplier
+from util.file_system import read_json_file, write_json_file
+
+# when loading SADL tensors, the maximum number of tensor dimensions is 6
+MAX_DIMENSION = 6
+verbose = 0
+
+
+def md5file(filename):
+    md5 = hashlib.md5()
+    # read the file in binary chunks and accumulate the digest
+    with open(filename, "rb") as f:
+        while True:
+            chunk = f.read(4096)
+            if not chunk:
+                break
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def load_prefix(sadl_file):
+    L = int.from_bytes(sadl_file.read(4), byteorder="little")
+    maxLength = 2048
+    assert L > 0 and L + 1 < maxLength  # max name size
+    layer_name = sadl_file.read(L).decode("utf-8")
+    if verbose:
+        print(f"  - name: {layer_name}")
+
+    L = int.from_bytes(sadl_file.read(4), byteorder="little")
+    assert 0 <= L < 8
+    inputs_id = [
+        int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(L)
+    ]
+    if verbose:
+        print("  - inputs:", inputs_id)
+    return layer_name
+
+
+def load_internal_const(sadl_file):
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    d = [0] * x
+    for k in range(len(d)):
+        d[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - tensor: {d}")
+
+    num_elements = np.prod(d)
+    ele_type = int.from_bytes(sadl_file.read(4), byteorder="little")
+
+    if ele_type == DTYPE_SADL.FLOAT:
+        ele_dtype = np.float32
+        quantizer = 0
+    elif ele_type == DTYPE_SADL.INT16:
+        ele_dtype = np.int16
+        quantizer = int.from_bytes(sadl_file.read(4), byteorder="little")
+    elif ele_type == DTYPE_SADL.INT32:
+        ele_dtype = np.int32
+        quantizer = int.from_bytes(sadl_file.read(4), byteorder="little")
+    else:
+        quit(f"[ERROR] unknown internal type: {ele_type}")
+
+    element_size = np.dtype(ele_dtype).itemsize
+    b_data = sadl_file.read(element_size * num_elements)
+    data = np.frombuffer(b_data, ele_dtype)
+    # dequantize data
+    data = data / float(2**quantizer)
+
+    if verbose:
+        print("  - data:", end=" ")
+    for k in range(4 if len(data) >= 4 else len(data)):
+        if verbose:
+            print(data[k], end=" ")
+    if verbose:
+        print("...")
+    return d, data.copy(), quantizer
+
+
+def load_internal_conv2d(sadl_file):
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    strides = [0] * x
+    for k in range(len(strides)):
+        strides[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+
+    if len(strides) == 2:
+        strides = [1, strides[0], strides[1], 1]
+
+    if len(strides) != 4:
+        quit(f"[ERROR] invalid strides: {len(strides)}")
+
+    if strides[0] != 1:
+        quit(f"[ERROR] invalid strides[0]: {strides[0]}")
+
+    if strides[3] != 1:
+        quit(f"[ERROR] invalid strides[3]: {strides[3]}")
+
+    if strides[1] != 1 and strides[1] != 2:
+        quit(f"[ERROR] not 1 or 2: to check {strides}")
+
+    if strides[2] != 1 and strides[2] != 2:
+        quit(f"[ERROR] not 1 or 2: to check {strides}")
+
+    if verbose:
+        print(f"  - strides: {strides}")
+
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    pads = [0] * x
+    for k in range(len(pads)):
+        pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - pads: {pads}")
+
+    groups = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - groups: {groups}")
+
+    quantizer = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - q: {quantizer}")
+
+
+def load_internal_conv2dTranspose(sadl_file):
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:  # Dimensions.MaxDim
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    strides = [0] * x
+    for k in range(len(strides)):
+        strides[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+
+    if len(strides) == 2:
+        strides = [1, strides[0], strides[1], 1]
+
+    if len(strides) != 4:
+        quit(f"[ERROR] invalid strides: {len(strides)}")
+
+    if strides[0] != 1:
+        quit(f"[ERROR] invalid strides[0]: {strides[0]}")
+
+    if strides[3] != 1:
+        quit(f"[ERROR] invalid strides[3]: {strides[3]}")
+
+    if strides[1] != 2 or strides[2] != 2:
+        quit(f"[ERROR] stride not 2: to check {strides}")
+
+    if verbose:
+        print(f"  - strides: {strides}")
+
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:  # Dimensions.MaxDim
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    pads = [0] * x
+    for k in range(len(pads)):
+        pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+        if pads[k] != 1 and pads[k] != 2:
+            quit(f"[ERROR] pads values not supported: {pads[k]}")
+
+    if verbose:
+        print(f"  - pads: {pads}")
+
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    out_pads = [0] * x
+    for k in range(len(out_pads)):
+        out_pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+        if out_pads[k] != 1:
+            quit(f"[ERROR] output pads !=1 {out_pads[k]}")
+
+    if verbose:
+        print(f"  - out_pads: {out_pads}")
+
+    quantizer = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - q: {quantizer}")
+
+
+def load_internal_flatten(sadl_file):
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:  # Dimensions.MaxDim
+        quit(f"[ERROR] invalid axis: {x}")
+
+    m_axis = x
+    if verbose:
+        print(f"  - start axis: {m_axis}")
+
+
+def load_internal_matmul(sadl_file):
+    quantizer = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - q: {quantizer}")
+
+
+def load_internal_maxpool(sadl_file):
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:
+        quit(f"[ERROR] invalid nb of dimensions strides: {x}")
+
+    strides = [0] * x
+    for k in range(len(strides)):
+        strides[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - strides: {strides}")
+
+    if len(strides) != 4:
+        quit(f"[ERROR] invalid strides: {len(strides)}")
+
+    if strides[0] != 1:
+        quit(f"[ERROR] invalid strides[0]: {strides[0]}")
+
+    if strides[3] != 1:
+        quit(f"[ERROR] invalid strides[3]: {strides[3]}")
+
+    if strides[1] != strides[2]:
+        quit(f"[ERROR] invalid stride H Vs: {strides}")
+
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:
+        quit(f"[ERROR] invalid nb of dimensions kernel: {x}")
+
+    kernel = [0] * x
+    for k in range(len(kernel)):
+        kernel[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - kernel: {kernel}")
+
+    if len(kernel) != 4:
+        quit(f"[ERROR] invalid kernel: {len(kernel)}")
+
+    if kernel[0] != 1:
+        quit(f"[ERROR] invalid kernel[0]: {kernel[0]}")
+
+    if kernel[3] != 1:
+        quit(f"[ERROR] invalid kernel[3]: {kernel[3]}")
+
+    if kernel[1] != kernel[2]:
+        quit(f"[ERROR] invalid kernel H V: {kernel}")
+
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    pads = [0] * x
+    for k in range(len(pads)):
+        pads[k] = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - pads: {pads}")
+
+
+def load_internal_mul(sadl_file):
+    quantizer = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - q: {quantizer}")
+
+
+def load_internal_placeholder(sadl_file):
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if x <= 0 or x > MAX_DIMENSION:  # Dimensions.MaxDim
+        quit(f"[ERROR] invalid nb of dimensions: {x}")
+
+    dims = [int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(x)]
+    if len(dims) == 1:
+        dims = [1] + dims
+
+    quantizer = int.from_bytes(sadl_file.read(4), byteorder="little")
+
+    if verbose:
+        print(f"  - dim: {dims}")
+    if verbose:
+        print(f"  - q: {quantizer}")
+
+
+def load_internal_slice(sadl_file):
+    start_h = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - start_h: {start_h}")
+
+    end_h = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - end_h: {end_h}")
+
+    start_w = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - start_w: {start_w}")
+
+    end_w = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - end_w: {end_w}")
+
+    start_c = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - start_c: {start_c}")
+
+    end_c = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"  - end_c: {end_c}")
+
+
+MAP_OPERATION_TO_LOAD_FUNC = {
+    OPTYPE.Const: load_internal_const,
+    OPTYPE.Conv2D: load_internal_conv2d,
+    OPTYPE.Conv2DTranspose: load_internal_conv2dTranspose,
+    OPTYPE.Flatten: load_internal_flatten,
+    OPTYPE.MatMul: load_internal_matmul,
+    OPTYPE.MaxPool: load_internal_maxpool,
+    OPTYPE.Mul: load_internal_mul,
+    OPTYPE.Placeholder: load_internal_placeholder,
+    OPTYPE.Slice: load_internal_slice,
+}
+
+
+def load_internal(sadl_file, operation):
+    op_func = MAP_OPERATION_TO_LOAD_FUNC.get(operation)
+    if op_func is None:
+        # no loader registered for this layer type: nothing extra to read
+        return None
+    return op_func(sadl_file)
+
+
+def read_sadl_file(sadl_model_path):
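+    # SADL binary layout (mirrors dump_onnx): 8-byte magic (e.g. "SADL0004"), model dtype (int32),
+    # number of layers, number of inputs followed by their ids, number of outputs followed by their ids,
+    # then one record per layer: layer id, op type, name, input ids and an op-specific payload.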
+    base_model_quantizers = {}
+    sadl_file = open(sadl_model_path, "rb")
+    magic = sadl_file.read(8).decode()
+    if verbose:
+        print(f"[INFO] read magic {magic}")
+
+    x = int.from_bytes(sadl_file.read(4), byteorder="little")
+    model_type = DTYPE_SADL(x).name
+    if verbose:
+        print(f"[INFO] Model type: {model_type}")
+
+    nb_layers = int.from_bytes(sadl_file.read(4), byteorder="little")
+    if verbose:
+        print(f"[INFO] Num layers: {nb_layers}")
+
+    # get input and output id
+    nb = int.from_bytes(sadl_file.read(4), byteorder="little")
+    ids_input = list(
+        int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(nb)
+    )
+    nb = int.from_bytes(sadl_file.read(4), byteorder="little")
+    ids_output = list(
+        int.from_bytes(sadl_file.read(4), byteorder="little") for _ in range(nb)
+    )
+    if verbose:
+        print("[INFO] input id:", ids_input)
+    if verbose:
+        print("[INFO] output id:", ids_output)
+
+    # load layers
+    const_layers = {}
+    for k in range(nb_layers):
+        layer_id = int.from_bytes(sadl_file.read(4), byteorder="little")
+        layer_op = int.from_bytes(sadl_file.read(4), byteorder="little")
+        if not (0 < layer_op < OPTYPE.Count):
+            quit(f"[ERROR] Pb reading model: layer op {layer_op}")
+
+        if verbose:
+            print(f"[INFO] id: {layer_id} op {OPTYPE(layer_op).name}")
+        layer_name = load_prefix(sadl_file)
+        out = load_internal(sadl_file, OPTYPE(layer_op))
+        if OPTYPE(layer_op) == OPTYPE.Const:
+            if (
+                ".weight" in layer_name
+                or ".bias" in layer_name
+                or "PRelu" in layer_name
+            ):
+                const_layers[layer_name] = {}
+                const_layers[layer_name]["dimension"] = out[0]
+                const_layers[layer_name]["data"] = out[1]
+                base_model_quantizers[layer_name] = out[2]
+
+    if verbose:
+        print("[INFO] == end model loading ==\n")
+
+    return const_layers, base_model_quantizers
+
+
+def map_sadl_param_name_to_torch(layer_data: Dict):
+    # only names of PRelu need to be modified
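+    # e.g. (hypothetical names): a SADL const named "PRelu_7" that follows the parameter
+    # "layers.3.bias" is renamed to the Torch key "layers.4.weight"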
+    param_names = list(layer_data.keys())
+    for i, k in enumerate(param_names):
+        if "PRelu" in k:
+            name_of_last_param = param_names[i - 1].split(".")
+            name_of_current_param = (
+                ".".join(name_of_last_param[:-2])
+                + "."
+                + str(int(name_of_last_param[-2]) + 1)
+                + ".weight"
+            )
+            layer_data[name_of_current_param] = layer_data[k]
+            layer_data.pop(k, None)
+
+
+def create_model(orig_sadl_model, orig_model_config):
+    layer_data, base_model_quantizers = read_sadl_file(orig_sadl_model)
+    map_sadl_param_name_to_torch(layer_data)
+
+    torch_model = load_torch_model(orig_model_config, model_lop2)
+    for name, _ in torch_model.SADL_model.state_dict().items():
+        dimension = layer_data[name]["dimension"]
+        weights = torch.Tensor(layer_data[name]["data"])
+        # in SADL, conv weights are stored as W H C_in C_out; convert to Torch's C_out C_in W H here
+        if len(dimension) == 4:
+            weights = weights.reshape(dimension)
+            weights = weights.permute(3, 2, 0, 1)
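+            # e.g. (hypothetical) a 3x3 conv with 16 input and 32 output channels
+            # arrives with shape (3, 3, 16, 32) and leaves as (32, 16, 3, 3)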
+        torch_model.SADL_model.state_dict()[name][:] = weights
+
+    return torch_model, base_model_quantizers
+
+
+def map_params(orig_model_config) -> Dict[str, str]:
+    orig_model = load_torch_model(orig_model_config, model_lop2)
+    orig_params = orig_model.state_dict()
+
+    new_model = load_torch_model(orig_model_config, model_lop2_with_multiplier)
+    new_params = new_model.state_dict()
+
+    param_map = {}
+
+    tmp_params = list()
+
+    for new_param in new_params:
+        if "multiplier" not in new_param:
+            tmp_params.append(new_param)
+
+    for (orig_param, new_param) in zip(orig_params.keys(), tmp_params):
+        tmp_orig_param = orig_param.replace("SADL_model.", "")
+        tmp_new_param = new_param.replace("SADL_model.", "")
+        param_map[tmp_new_param] = tmp_orig_param
+
+    return param_map
+
+
+if __name__ == "__main__":
+    config = read_json_file("resources/config.json")
+
+    if "lop2" != config["architecture"]:
+        print(f"{config['architecture']} not supported")
+        sys.exit(-1)
+
+    model_config = read_json_file("models/models.json")[config["architecture"]]
+
+    orig_sadl_model_i = config["sadl2torch"]["base_model_sadl_i"]
+    output_path = config["sadl2torch"]["base_model_torch"]
+
+    verbose = 0
+
+    # dequantized model
+    torch_model_dq, base_quantizers = create_model(orig_sadl_model_i, model_config)
+    torch.save(torch_model_dq.state_dict(), output_path)
+
+    # save quantizers of base model
+    base_quantizers_file = config["sadl2torch"]["base_model_quantizers"]
+    write_json_file(base_quantizers, base_quantizers_file)
+
+    org_to_new_params = map_params(model_config)
+    write_json_file(
+        org_to_new_params, Path(config["sadl2torch"]["base_model_params_to_new"])
+    )
+
+    # NOTE: current scripts can be used to convert SADL0004 version to Torch
+    if verbose:
+        with NamedTemporaryFile() as tmp_onnx, NamedTemporaryFile() as tmp_sadl_f, NamedTemporaryFile() as tmp_sadl_i, NamedTemporaryFile() as tmp_quantizers:
+            torch_model_dq.SADL_model.to_onnx(tmp_onnx.name)
+
+            os.system(
+                f"python conversion/onnx2sadl_modified.py --input {tmp_onnx.name} --output {tmp_sadl_f.name} --out_quant {tmp_quantizers.name} --base_model_quantizers {base_quantizers_file}"
+            )
+            os.system(
+                f"tail -n 1 {tmp_quantizers.name} | ./naive_quantization {tmp_sadl_f.name} {tmp_sadl_i.name}"
+            )
+
+            # check equivalence of base int model and requantized dequantized base model
+            src_md5 = md5file(orig_sadl_model_i)
+            dst_md5 = md5file(tmp_sadl_i.name)
+
+            assert (
+                src_md5 == dst_md5
+            ), "Something went wrong in new custom SADL model reader"
diff --git a/training/training_scripts/NN_Adaptive_Filtering/conversion/torch2sadl.py b/training/training_scripts/NN_Adaptive_Filtering/conversion/torch2sadl.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46251c8f9f937fa9918a9be3bae5d7687355532
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/conversion/torch2sadl.py
@@ -0,0 +1,130 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import os
+import sys
+from pathlib import Path
+from typing import Union
+
+import torch
+
+from models import load_torch_model, model_lop2_with_multiplier
+from util.file_system import create_directory, read_json_file
+
+
+def torch_to_onnx(
+    torch_model_file: Union[Path, str],
+    out_onnx_path: Union[Path, str],
+    model_config: Union[Path, str],
+    model_module: torch.nn.Module = model_lop2_with_multiplier,
+):
+    # Load model
+    model = load_torch_model(model_config, model_module)
+    model.load_state_dict(torch.load(torch_model_file))
+
+    # Torch to Onnx
+    model.SADL_model.to_onnx(out_onnx_path)
+
+    # remove identity layers (they cause errors in the SADL conversion)
+    exit_code = os.system(
+        f"python -m onnxoptimizer {out_onnx_path} {out_onnx_path} -p eliminate_identity"
+    )
+    if exit_code:
+        sys.exit(-1)
+
+
+def onnx_to_sadl(
+    sadl_models_dir,
+    onnx_model_path,
+    quantizers_dir,
+    sadl_int_models_dir,
+    base_quantizers,
+    orig_params_to_new=None,
+):
+    sadl_model_file = sadl_models_dir / f"{onnx_model_path.stem}.sadl"
+    q_out = quantizers_dir / f"Q_{onnx_model_path.stem}.log"
+    command = (
+        f"python conversion/onnx2sadl_modified.py --input_onnx {onnx_model_path} --output {sadl_model_file} "
+        f"--out_quant {q_out} --base_quantizers {base_quantizers}"
+    )
+    if orig_params_to_new is not None:
+        command += f" --orig_params_to_new {orig_params_to_new}"
+    if os.system(command):
+        sys.exit(-1)
+
+    # convert float SADL to int16
+    output_file = sadl_int_models_dir / f"{onnx_model_path.stem}.sadl"
+    command_int = (
+        f"tail -n 1 {q_out} | ./naive_quantization {sadl_model_file} {output_file}"
+    )
+    if os.system(command_int):
+        sys.exit(-1)
+
+
+def torch_model_to_sadl():
+    input_models_files = (root_dir / "nnr_models").glob("*.pt")
+
+    sadl_float_dir = root_dir / "sadl_float_models"
+    sadl_int_dir = root_dir / "sadl_int_models"
+    quantizers_dir = root_dir / "quantizers"
+    onnx_models_dir = root_dir / "onnx_models"
+    create_directory(onnx_models_dir)
+    create_directory(sadl_float_dir)
+    create_directory(sadl_int_dir)
+    create_directory(quantizers_dir)
+
+    # Torch to ONNX
+    for torch_model in input_models_files:
+        model_name = torch_model.stem
+
+        print("-----------------------------")
+        print(model_name)
+        print("-----------------------------")
+
+        # Torch to Onnx
+        out_onnx_path = onnx_models_dir / f"{model_name}.onnx"
+        torch_to_onnx(torch_model, out_onnx_path, model_config_file)
+
+        # Onnx to SADL (base-model quantizers are read from the config file;
+        # this argument is required by onnx_to_sadl)
+        base_quantizers = test_cfg["sadl2torch"]["base_model_quantizers"]
+        onnx_to_sadl(
+            sadl_float_dir, out_onnx_path, quantizers_dir, sadl_int_dir, base_quantizers
+        )
+
+
+if __name__ == "__main__":
+    test_cfg = read_json_file(Path("resources/config.json"))
+    root_dir = Path(test_cfg["training"]["output_path"])
+    model_config_file = Path("models/models.json")
+
+    torch_model_to_sadl()
+
+    sys.exit(0)
diff --git a/training/training_scripts/NN_Adaptive_Filtering/create_dataset_dirs.py b/training/training_scripts/NN_Adaptive_Filtering/create_dataset_dirs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c190cc642127f6095e93f2c1c4a37a218934854a
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/create_dataset_dirs.py
@@ -0,0 +1,57 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from pathlib import Path
+
+from util.file_system import create_directory, read_json_file
+
+
+def create_directories() -> None:
+    dataset_orig_path = create_directory(config["dataset"]["orig_path"])
+    dataset_deco_path = create_directory(config["dataset"]["deco_path"])
+
+    seq_labels = read_json_file(Path("resources/datasets/jvet_labels.json"))
+    qps = ["22", "27", "32", "37", "42"]
+
+    for seq, label in seq_labels.items():
+        create_directory(dataset_orig_path / label)
+
+        deco_seq_path = create_directory(dataset_deco_path / label)
+        for qp in qps:
+            create_directory(deco_seq_path / qp)
+
+
+if __name__ == "__main__":
+    config = read_json_file(Path("resources/config.json"))
+    create_directories()
diff --git a/training/training_scripts/NN_Adaptive_Filtering/create_env.sh b/training/training_scripts/NN_Adaptive_Filtering/create_env.sh
new file mode 100755
index 0000000000000000000000000000000000000000..218d0e8581b9a4c2dd252579ed4cf2f2efb5de49
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/create_env.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+python3 -m venv ${HOME}/lop2_overf
+source ${HOME}/lop2_overf/bin/activate
+
+# Torch dependencies
+pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
+
+pip install click==8.1.7 openpyxl
+
+# NCTM dependencies
+pip install protobuf==3.20.2 memory_profiler==0.55.0 dcase-util==0.2.10 pandas opencv-python scikit-image scikit-learn tensorflow-gpu==2.8.0
+
+# SADL dependencies
+pip install tqdm onnxoptimizer
+
+deactivate
diff --git a/training/training_scripts/NN_Adaptive_Filtering/data_preparation.py b/training/training_scripts/NN_Adaptive_Filtering/data_preparation.py
new file mode 100644
index 0000000000000000000000000000000000000000..caa9571caf2aa70dbf7a99c74b4e9852da4d1150
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/data_preparation.py
@@ -0,0 +1,121 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import os
+from pathlib import Path
+
+from util.file_system import (
+    check_directory,
+    check_file,
+    list_dirs,
+    read_json_file,
+    write_json_file,
+)
+from util.regex import get_frame_info_from_decoder_log_line
+
+
+def extract_coding_info(root_dir: Path) -> None:
+    """
+    Extracts per-frame coding info (POC, temporal layer, slice type, slice QP) from the decoder log
+    :param root_dir: Directory that contains the coding simulation results
+    """
+    seq_dirs = list_dirs(root_dir)
+    for seq_dir in seq_dirs:
+        qp_dirs = list_dirs(seq_dir)
+        for qp_dir in qp_dirs:
+            coding_info = {
+                "base_qp": int(qp_dir.name),
+                "POC": {},
+            }
+
+            log_file = check_file(qp_dir / "log_dec.txt")
+            with open(log_file, "r") as stream:
+                for line in stream:
+                    if line.startswith("POC"):
+                        (
+                            poc,
+                            temporal_layer,
+                            slice_type,
+                            slice_qp,
+                        ) = get_frame_info_from_decoder_log_line(line)
+
+                        coding_info["POC"][poc] = {
+                            "temporal_layer": temporal_layer,
+                            "slice_type": slice_type,
+                            "slice_qp": slice_qp,
+                        }
+
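+            # coding_info now looks like (illustrative values; the exact
+            # slice_type string depends on util.regex):
+            #   {"base_qp": 37, "POC": {0: {"temporal_layer": 0,
+            #    "slice_type": "I", "slice_qp": 37}, ...}}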
+            write_json_file(coding_info, qp_dir / "coding_info.json")
+
+
+def decode_nnvc_bitstream(
+    dataset_dir: Path, decoder_bin: Path, lop_filter_path: Path, intra_filter_dir: Path
+) -> None:
+    """
+    Decodes video bitstreams using NNVC
+    :param dataset_dir: Directory that contains the bitstream (e.g. root_dir -> seq_dir -> qp -> bitstream.266)
+    :param decoder_bin: Video decoder binary
+    :param lop_filter_path: Default LOP2 filter
+    :param intra_filter_dir: Directory that contains intra filters
+    """
+    seq_dirs = list_dirs(dataset_dir)
+
+    for seq_dir in seq_dirs:
+        qp_dirs = list_dirs(seq_dir)
+        for qp_dir in qp_dirs:
+            print(seq_dir.name, qp_dir.name)
+            bitstream = qp_dir / f"{seq_dir.name}_{qp_dir.name}.266"
+            check_file(bitstream)
+            log_file = qp_dir / "log_dec.txt"
+            command = f"{decoder_bin} -b {bitstream} --DumpBasename={qp_dir / seq_dir.name} --NnlfModelName={lop_filter_path} --PrefixAbsolutePathsToGraphsOutput={intra_filter_dir} > {log_file}"
+            os.system(command)
+
+
+def run_data_preparation() -> None:
+    dataset_deco_path = check_directory(config["dataset"]["deco_path"])
+    decode_nnvc_bitstream(
+        dataset_deco_path,
+        config["decoder_bin"],
+        config["sadl2torch"]["base_model_sadl_i"],
+        config["intra_filters"],
+    )
+
+    # Extract POC, frame QP and temporal layer
+    extract_coding_info(dataset_deco_path)
+
+
+if __name__ == "__main__":
+    config = read_json_file(Path("resources/config.json"))
+    seq_cfg = read_json_file(Path("resources/datasets/jvet.json"))
+    run_data_preparation()
diff --git a/training/training_scripts/NN_Adaptive_Filtering/environment.yml b/training/training_scripts/NN_Adaptive_Filtering/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..569d237f7fd837a9e1d19a689a6b22814ec45b6e
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/environment.yml
@@ -0,0 +1,163 @@
+name: lop2_overf
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - blas=1.0=mkl
+  - brotli-python=1.0.9=py39h6a678d5_7
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2023.12.12=h06a4308_0
+  - certifi=2023.11.17=py39h06a4308_0
+  - cffi=1.16.0=py39h5eee18b_0
+  - charset-normalizer=2.0.4=pyhd3eb1b0_0
+  - cryptography=41.0.7=py39hdda0065_0
+  - cudatoolkit=11.3.1=h2bc3f7f_2
+  - ffmpeg=4.3=hf484d3e_0
+  - freetype=2.12.1=h4a9f257_0
+  - giflib=5.2.1=h5eee18b_3
+  - gmp=6.2.1=h295c915_3
+  - gnutls=3.6.15=he1e5248_0
+  - idna=3.4=py39h06a4308_0
+  - intel-openmp=2023.1.0=hdb19cb5_46306
+  - jpeg=9e=h5eee18b_1
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libdeflate=1.17=h5eee18b_1
+  - libffi=3.4.4=h6a678d5_0
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libiconv=1.16=h7f8727e_2
+  - libidn2=2.3.4=h5eee18b_0
+  - libpng=1.6.39=h5eee18b_0
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtasn1=4.19.0=h5eee18b_0
+  - libtiff=4.5.1=h6a678d5_0
+  - libunistring=0.9.10=h27cfd23_0
+  - libwebp=1.3.2=h11a3e52_0
+  - libwebp-base=1.3.2=h5eee18b_0
+  - lz4-c=1.9.4=h6a678d5_0
+  - mkl=2023.1.0=h213fc3f_46344
+  - mkl-service=2.4.0=py39h5eee18b_1
+  - mkl_fft=1.3.8=py39h5eee18b_0
+  - mkl_random=1.2.4=py39hdb19cb5_0
+  - ncurses=6.4=h6a678d5_0
+  - nettle=3.7.3=hbbd107a_1
+  - numpy=1.26.3=py39h5f9d8c6_0
+  - numpy-base=1.26.3=py39hb5e798b_0
+  - openh264=2.1.1=h4ff587b_0
+  - openjpeg=2.4.0=h3ad879b_0
+  - openssl=3.0.12=h7f8727e_0
+  - pillow=10.0.1=py39ha6cbd5a_0
+  - pip=23.3.1=py39h06a4308_0
+  - pycparser=2.21=pyhd3eb1b0_0
+  - pyopenssl=23.2.0=py39h06a4308_0
+  - pysocks=1.7.1=py39h06a4308_0
+  - python=3.9.18=h955ad1f_0
+  - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
+  - pytorch-mutex=1.0=cuda
+  - readline=8.2=h5eee18b_0
+  - requests=2.31.0=py39h06a4308_0
+  - setuptools=68.2.2=py39h06a4308_0
+  - sqlite=3.41.2=h5eee18b_0
+  - tbb=2021.8.0=hdb19cb5_0
+  - tk=8.6.12=h1ccaba5_0
+  - torchaudio=0.12.1=py39_cu113
+  - torchvision=0.13.1=py39_cu113
+  - typing_extensions=4.9.0=py39h06a4308_0
+  - urllib3=1.26.18=py39h06a4308_0
+  - wheel=0.41.2=py39h06a4308_0
+  - xz=5.4.5=h5eee18b_0
+  - zlib=1.2.13=h5eee18b_0
+  - zstd=1.5.5=hc292b87_0
+  - pip:
+    - absl-py==2.0.0
+    - astunparse==1.6.3
+    - audioread==3.0.1
+    - black==23.3.0
+    - cachetools==5.3.2
+    - click==8.1.7
+    - contourpy==1.2.0
+    - cycler==0.12.1
+    - dcase-util==0.2.10
+    - decorator==5.1.1
+    - flake8==6.1.0
+    - flatbuffers==23.5.26
+    - fonttools==4.47.2
+    - future==0.18.3
+    - gast==0.5.4
+    - google-auth==2.26.2
+    - google-auth-oauthlib==0.4.6
+    - google-pasta==0.2.0
+    - grpcio==1.60.0
+    - h5py==3.10.0
+    - imageio==2.33.1
+    - importlib-metadata==7.0.1
+    - importlib-resources==6.1.1
+    - joblib==1.3.2
+    - keras==2.8.0
+    - keras-preprocessing==1.1.2
+    - kiwisolver==1.4.5
+    - lazy-loader==0.3
+    - libclang==16.0.6
+    - librosa==0.10.1
+    - llvmlite==0.41.1
+    - markdown==3.5.2
+    - markupsafe==2.1.3
+    - matplotlib==3.8.2
+    - mccabe==0.7.0
+    - memory-profiler==0.55.0
+    - msgpack==1.0.7
+    - mypy-extensions==1.0.0
+    - networkx==3.2.1
+    - numba==0.58.1
+    - oauthlib==3.2.2
+    - onnx==1.15.0
+    - onnxoptimizer==0.3.13
+    - opencv-python==4.9.0.80
+    - opt-einsum==3.3.0
+    - packaging==23.2
+    - pandas==2.1.4
+    - pathspec==0.12.1
+    - platformdirs==4.1.0
+    - pooch==1.8.0
+    - protobuf==3.20.2
+    - psutil==5.9.7
+    - pyasn1==0.5.1
+    - pyasn1-modules==0.3.0
+    - pycodestyle==2.11.1
+    - pydot-ng==2.0.0
+    - pyflakes==3.1.0
+    - pyparsing==3.1.1
+    - python-dateutil==2.8.2
+    - python-magic==0.4.27
+    - pytz==2023.3.post1
+    - pyyaml==6.0.1
+    - requests-oauthlib==1.3.1
+    - rsa==4.9
+    - scikit-image==0.22.0
+    - scikit-learn==1.3.2
+    - scipy==1.11.4
+    - six==1.16.0
+    - soundfile==0.12.1
+    - soxr==0.3.7
+    - tensorboard==2.8.0
+    - tensorboard-data-server==0.6.1
+    - tensorboard-plugin-wit==1.8.1
+    - tensorflow-gpu==2.8.0
+    - tensorflow-io-gcs-filesystem==0.35.0
+    - termcolor==2.4.0
+    - tf-estimator-nightly==2.8.0.dev2021122109
+    - threadpoolctl==3.2.0
+    - tifffile==2023.12.9
+    - tomli==2.0.1
+    - tqdm==4.66.1
+    - tzdata==2023.4
+    - validators==0.22.0
+    - werkzeug==3.0.1
+    - wrapt==1.16.0
+    - zipp==3.17.0
+prefix: ~/anaconda3/envs/lop2_overf
diff --git a/training/training_scripts/NN_Adaptive_Filtering/launch_pipeline.py b/training/training_scripts/NN_Adaptive_Filtering/launch_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4818f7115241b5ef98632806a8ff0ce37dd72017
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/launch_pipeline.py
@@ -0,0 +1,164 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+from util.file_system import create_directory, read_json_file
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(
+        description="Launch the whole overfitting pipeline, including training, NNR compression and conversion to SADL"
+    )
+    parser.add_argument(
+        "-vs",
+        "--video_sequences",
+        nargs="+",
+        type=str,
+        required=True,
+        help="Sequence name/tag",
+    )
+    parser.add_argument("--qps", nargs="+", type=str, required=True, help="Sequence QP")
+    return parser.parse_args()
+
+
+def launch_pipeline():
+    data_cfg = read_json_file(data_cfg_file)
+    for seq_name in args.video_sequences:
+        num_frames = data_cfg[seq_name]["frames"]
+        intra_period = 64 if data_cfg[seq_name]["fps"] > 32 else 32
+        num_part = int((num_frames - 1) // intra_period) + 1
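+        # e.g. (hypothetical) a 300-frame sequence at 50 fps uses intra_period 64
+        # and is overfitted in 5 parts; part 0 covers POCs 0..64 (min_poc=0,
+        # max_poc=65) and part 4 covers POCs 256..299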
+
+        for qp in args.qps:
+            num_param = min(
+                num_layer_config[seq_name][qp], config["training"]["num_layers"]
+            )
+
+            for part in range(num_part):
+                min_poc = part * intra_period
+                max_poc = min(min_poc + intra_period + 1, num_frames)
+
+                print(
+                    f"\x1b[6;30;42m Overfitting for video: {seq_name}, qp: {qp}, part: {part} \x1b[0m"
+                )
+
+                curr_overfitting_dir = overfittings_dir / f"{seq_name}_{qp}_part{part}"
+
+                command = (
+                    f"python {script_to_execute} --max_epochs={epochs} "
+                    f"--deco_dir_train={dataset_deco_path} --orig_dir_train={dataset_orig_path} --prop_file_train={data_cfg_file} "
+                    f"--seq_qp_train={qp} --seq_name={seq_name} --slice_type=* --base_nn_params={base_model_path} "
+                    f"--output_dir={curr_overfitting_dir} --arch={config['architecture']} "
+                    f"--stop_patience {stop_patience} --lr_patience {lr_patience} --lr={lr} "
+                    f"--min_poc {min_poc} --max_poc {max_poc} --part_idx {part} --select_layer --num_param {num_param} "
+                    f"--nnr_models_dir {nnr_models_dir} --onnx_models_dir {onnx_models_dir} --sadl_float_dir {sadl_float_dir} --sadl_int_dir {sadl_int_dir} --quantizers_dir {quantizers_dir} "
+                    f"--base_quantizers {base_quantizers} --map_orig_params_to_new {map_orig_params_to_new} "
+                )
+
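+                # Resume an interrupted overfitting from a saved checkpoint
+                # together with its optimiser and LR-scheduler state (start_epoch
+                # is 0 by default, so this branch is normally skipped)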
+                if start_epoch:
+                    nn_params = (
+                        curr_overfitting_dir
+                        / f"models/checkpoints/model_{start_epoch:03d}.pt"
+                    )
+                    optimiser_params = (
+                        curr_overfitting_dir
+                        / f"models/checkpoints/model_{start_epoch:03d}_optimiser.pt"
+                    )
+                    lr_scheduler_params = (
+                        curr_overfitting_dir
+                        / f"models/checkpoints/model_{start_epoch:03d}_lr_scheduler.pt"
+                    )
+                    command = (
+                        command
+                        + f"--start_epoch {start_epoch} --nn_params {nn_params} --optimiser_params {optimiser_params} --lr_scheduler_params {lr_scheduler_params}"
+                    )
+
+                os.system(command)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+
+    config = read_json_file("resources/config.json")
+
+    if "lop2" != config["architecture"]:
+        print(f"{config['architecture']} architecture not supported")
+        sys.exit(-1)
+
+    model_config = "models/models.json"
+
+    dataset_orig_path = config["dataset"]["orig_path"]
+    dataset_deco_path = config["dataset"]["deco_path"]
+    data_cfg_file = Path("resources/datasets/jvet.json")
+    output_dir = Path(config["training"]["output_path"])
+
+    # path to overfitted models
+    overfittings_dir = output_dir / "overfittings"
+
+    # path to nnr models
+    nnr_models_dir = output_dir / "nnr_models"
+
+    # path to SADL models
+    sadl_float_dir = output_dir / "sadl_float_models"
+    sadl_int_dir = output_dir / "sadl_int_models"
+    quantizers_dir = output_dir / "quantizers"
+    onnx_models_dir = output_dir / "onnx_models"
+
+    create_directory(overfittings_dir)
+    create_directory(nnr_models_dir)
+    create_directory(nnr_models_dir / "nnr")
+    create_directory(sadl_float_dir)
+    create_directory(sadl_int_dir)
+    create_directory(quantizers_dir)
+    create_directory(onnx_models_dir)
+
+    base_model_path = config["sadl2torch"]["base_model_torch"]
+    base_quantizers = config["sadl2torch"]["base_model_quantizers"]
+    map_orig_params_to_new = config["sadl2torch"]["base_model_params_to_new"]
+
+    num_layer_config = read_json_file("resources/num_layers.json")
+
+    script_to_execute = "./overfitting_pipeline.py"
+    epochs = 1
+    lr = 1e-3
+    stop_patience = 50
+    lr_patience = 100
+    weight_decay = 0
+    total_terms = 100
+    start_epoch = 0
+
+    launch_pipeline()
diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..91510001de79b9d6a786493d53c25f129cff915c
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/models/__init__.py
@@ -0,0 +1,61 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Union
+
+from util.file_system import read_json_file
+
+
+def instantiate_from_dict(
+    namespace, config: Dict[str, Any], *args: Any, **kwargs: Dict[str, Any]
+) -> Any:
+    """Instantiate instance from config dict. Value of 'class' in dict should be a class available in namespace.
+    Args:
+        namespace: load class definition from this namespace
+        config: dict containing kwargs for instantiation
+        *args: args for instantiation
+        **kwargs: additional kwargs for instantiation
+    """
+    return getattr(namespace, config.pop("class"))(*args, **config, **kwargs)
+
+
+def load_torch_model(model_config: Union[Path, Dict], namespace):
+    config = (
+        model_config if isinstance(model_config, Dict) else read_json_file(model_config)
+    )
+    copy_config = deepcopy(config)
+    model = instantiate_from_dict(namespace, copy_config["model"])
+    return model
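+
+
+# Illustrative use (hypothetical config values): given a model config dict such as
+# {"model": {"class": "Net", "C": 16}}, load_torch_model(config, model_lop2)
+# resolves to model_lop2.Net(C=16); the "class" entry is popped and the remaining
+# keys are passed as keyword arguments.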
diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2.py b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2.py
new file mode 100644
index 0000000000000000000000000000000000000000..47e781fba82674b101bc0991be341be92d2423f1
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2.py
@@ -0,0 +1,360 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Conv(nn.Sequential):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Optional[Union[int, Tuple[int, int]]] = None,
+        is_separable: bool = False,
+        hidden_separable_channels: Optional[int] = None,
+        post_activation: Optional[Type] = nn.PReLU,
+        **kwargs,
+    ):
+        """
+        Args:
+            in_channels: the number of input channels
+            out_channels: the number of output channels
+            kernel_size: the convolution's kernel size
+            stride: the convolution's stride(s)
+            padding: the convolution's padding
+            is_separable: whether to implement convolution separably
+            hidden_separable_channels: If is_separable, the number of hidden channels between convolutions. If None, use out_channels
+            post_activation: activation function to use after convolution. If None, no activation after convolution
+            **kwargs: additional kwargs to pass to nn.Conv2d
+        """
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = (
+            (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+        )
+        self.stride = (stride, stride) if isinstance(stride, int) else stride
+        if padding is not None:
+            self.padding = (padding, padding) if isinstance(padding, int) else padding
+        else:
+            self.padding = tuple([k // 2 for k in self.kernel_size])
+        self.is_separable = is_separable
+        self.post_activation = post_activation
+
+        if self.is_separable:
+            self.hidden_separable_channels = hidden_separable_channels or out_channels
+            modules = [
+                nn.Conv2d(
+                    self.in_channels,
+                    self.hidden_separable_channels,
+                    (self.kernel_size[0], 1),
+                    (self.stride[0], 1),
+                    (self.padding[0], 0),
+                    groups=self.hidden_separable_channels,
+                    **kwargs,
+                ),
+                nn.Conv2d(
+                    self.hidden_separable_channels,
+                    self.out_channels,
+                    (1, self.kernel_size[1]),
+                    (1, self.stride[1]),
+                    (0, self.padding[1]),
+                    groups=self.hidden_separable_channels,
+                    **kwargs,
+                ),
+                nn.Conv2d(
+                    self.out_channels,
+                    self.out_channels,
+                    (1, 1),
+                    (1, 1),
+                    (0, 0),
+                    **kwargs,
+                ),
+            ]
+        else:
+            modules = [
+                nn.Conv2d(
+                    self.in_channels,
+                    self.out_channels,
+                    self.kernel_size,
+                    self.stride,
+                    self.padding,
+                    **kwargs,
+                )
+            ]
+
+        if self.post_activation is not None:
+            modules.append(self.post_activation())
+
+        super(Conv, self).__init__(*modules)
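+
+
+# Illustrative expansion (default arguments assumed): Conv(16, 16, 3,
+# is_separable=True, hidden_separable_channels=16) stacks a 3x1 grouped conv, a
+# 1x3 grouped conv and a 1x1 pointwise conv, followed by a PReLU (the default
+# post_activation); the non-separable case is a single 3x3 conv plus PReLU.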
+
+
+class MultiBranchModule(nn.Module):
+    """A module representing multple, parallel branches. If the input is a list, each element in the list is fed into the corresponding branch,
+    otherwise the input is fed into every branch. The outputs of each branch are then merged.
+    """
+
+    def __init__(self, *branch_modules, merge_dimension: int = -3):
+        """
+        Args:
+            branch_modules: modules to run in parallel
+            merge_dimension: the dimension to merge outputs from each branch
+        """
+        super().__init__()
+        self.branches = nn.ModuleList(branch_modules)
+        self.merge_dimension = merge_dimension
+
+    def forward(self, args: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
+        inputs = args if isinstance(args, list) else len(self.branches) * [args]
+        branch_outputs = [branch(input) for branch, input in zip(self.branches, inputs)]
+        return torch.cat(branch_outputs, dim=self.merge_dimension)
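+
+
+# Illustrative behaviour: with two branches producing 4 and 8 channels from the
+# same (N, C, H, W) input, the merged output has 12 channels, since the default
+# merge_dimension -3 is the channel axis.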
+
+
+class ResidualBlock(nn.Sequential):
+    def __init__(self, C: int = 64, C1: int = 160, C21: int = 32):
+        super(ResidualBlock, self).__init__(
+            Conv(C, C1, kernel_size=1),
+            Conv(C1, C, kernel_size=1, post_activation=None),
+            Conv(
+                C,
+                C,
+                kernel_size=3,
+                post_activation=None,
+                is_separable=True,
+                hidden_separable_channels=C21,
+            ),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + super(ResidualBlock, self).forward(x)
+
+
+class SplitLumaChromaBlocks(nn.Sequential):
+    def __init__(
+        self,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
+        output_channels_y: int = 4,
+        output_channels_uv: int = 2,
+    ):
+        super().__init__()
+
+        self.split_y_path = nn.Sequential(
+            *[ResidualBlock(C, C1_Y, C21) for _ in range(N_Y)],
+            Conv(C, C, kernel_size=3, is_separable=True, hidden_separable_channels=C21),
+            Conv(C, output_channels_y, kernel_size=3, post_activation=None),
+        )
+
+        self.split_uv_path = nn.Sequential(
+            *[ResidualBlock(C, C1_UV, C21) for _ in range(N_UV)],
+            Conv(C, C, kernel_size=3, is_separable=True, hidden_separable_channels=C21),
+            Conv(C, output_channels_uv, kernel_size=3, post_activation=None),
+        )
+
+        self.Cy = C
+        self.Cuv = C
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        split_y_input = x[:, : self.Cy, :, :]
+        split_uv_input = x[:, self.Cy : self.Cy + self.Cuv, :, :]
+
+        y_output = self.split_y_path.forward(split_y_input)
+        uv_output = self.split_uv_path.forward(split_uv_input)
+        return torch.cat((y_output, uv_output), dim=1)
+
+
+class SADLNet(nn.Sequential):
+    """The network used during SADL inference"""
+
+    def __init__(
+        self,
+        input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1],
+        input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
+        D1: int = 12,
+        D2: int = 8,
+        D3: int = 4,
+        D4: int = 2,
+        D5: int = 2,
+        D6: int = 24,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
+        output_channels_y: int = 4,
+        output_channels_uv: int = 2,
+    ):
+        """
+        Args:
+            input_channels: the number of channels expected for each input
+            input_kernels: the kernel size for each input convolution
+            output_channels_y: the number of output channels for the luma path
+            output_channels_uv: the number of output channels for the chroma path
+        """
+        self.input_channels = input_channels
+        self.input_kernels = input_kernels
+        self.input_features = [D1, D2, D3, D4, D4, D5]
+
+        super(SADLNet, self).__init__(
+            MultiBranchModule(
+                *[
+                    Conv(c, d, kernel_size=k, post_activation=None)
+                    for c, d, k in zip(
+                        self.input_channels, self.input_features, self.input_kernels
+                    )
+                ]
+            ),
+            Conv(sum(self.input_features), D6, kernel_size=1),
+            Conv(D6, C + C, kernel_size=3, stride=2),
+            SplitLumaChromaBlocks(
+                N_Y, N_UV, C, C1_Y, C1_UV, C21, output_channels_y, output_channels_uv
+            ),
+        )
+
+    def get_example_inputs(
+        self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1
+    ):
+        patch_size = (
+            (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+        )
+        return [
+            torch.rand(
+                batch_size, conv.in_channels, *patch_size, device=conv[0].weight.device
+            )
+            for conv in self[0].branches
+        ]
+
+    def to_onnx(
+        self, filename: str, patch_size: int = 144, batch_size: int = 1, **kwargs
+    ) -> None:
+        mode = self.training
+        self.eval()
+        torch.onnx.export(
+            self, self.get_example_inputs(patch_size, batch_size), filename, **kwargs
+        )
+        self.train(mode)
+
+
+class Net(nn.Module):
+    """Wrapper for SADL model that implements input pre- and post-processing for training."""
+
+    def __init__(
+        self,
+        input_channels: Iterable[Iterable[str]] = [
+            ["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V"],
+            ["pred_Y", "pred_U", "pred_V"],
+            ["bs_Y", "bs_U", "bs_V"],
+            ["qp_base"],
+            ["qp_slice"],
+            ["ipb_Y"],
+        ],
+        input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
+        D1: int = 12,
+        D2: int = 8,
+        D3: int = 4,
+        D4: int = 2,
+        D5: int = 2,
+        D6: int = 24,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
+    ):
+        super(Net, self).__init__()
+        assert len(input_channels) == len(
+            input_kernels
+        ), "[ERROR] input_channels and input_kernels must have the same length"
+        self.input_channels = input_channels
+        sizes = [len(a) for a in input_channels]
+        self.SADL_model = SADLNet(
+            sizes,
+            input_kernels,
+            D1,
+            D2,
+            D3,
+            D4,
+            D5,
+            D6,
+            N_Y,
+            N_UV,
+            C,
+            C1_Y,
+            C1_UV,
+            C21,
+            4,
+            2,
+        )
+        self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest")
+
+    def preprocess_args(
+        self, batch: Dict[str, torch.Tensor]
+    ) -> List[torch.Tensor]:
+        # upsample chroma planes to luma resolution before stacking the inputs
+        for name, data in batch.items():
+            if "U" in name or "V" in name:
+                batch[name] = self.chroma_upsampler(data)
+
+        return [
+            torch.cat([batch[name] for name in input_], dim=1)
+            for input_ in self.input_channels
+        ]
+
+    def postprocess_outputs(
+        self, batch: Dict[str, torch.Tensor], out: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        Y_res, UV_res = out.split([4, 2], dim=1)
+        return (
+            batch["rec_before_dbf_Y"] + F.pixel_shuffle(Y_res, 2),
+            torch.cat((batch["rec_before_dbf_U"], batch["rec_before_dbf_V"]), dim=1)[
+                ..., ::2, ::2
+            ]
+            + UV_res,
+        )
+
+    def forward(
+        self, batch: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        args = self.preprocess_args(batch)
+        out = self.SADL_model(args)
+        return self.postprocess_outputs(batch, out)
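+
+
+# Illustrative shapes (hypothetical 128x128 luma patch): the SADL network runs at
+# half resolution after its stride-2 convolution, so its output is (N, 6, 64, 64);
+# postprocess_outputs splits this into a 4-channel luma residual, pixel-shuffled
+# back to (N, 1, 128, 128) and added to rec_before_dbf_Y, and a 2-channel chroma
+# residual added to the 2x-subsampled U/V reconstruction at (N, 2, 64, 64).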
diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_nnr.py b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_nnr.py
new file mode 100644
index 0000000000000000000000000000000000000000..aafefb2f50b7d7d00d25446fa12d0f03c8c9e3b5
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_nnr.py
@@ -0,0 +1,279 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import copy
+import logging
+from collections import OrderedDict
+from pathlib import Path
+from typing import Dict, Tuple, Union
+
+import nnc_core
+import torch
+
+import config
+from models import load_torch_model, model_lop2_with_multiplier
+from util import Colour
+from util.metrics import compute_loss_psnr, mse_loss
+
+LOGGER = logging.getLogger()
+
+
+def shorten_multiplier_name(in_var: str) -> str:
+    out_var = in_var.replace("SADL_model.", "")
+    out_var = out_var.replace(".branches.", "")
+    out_var = out_var.replace("split.", "")
+    out_var = out_var.replace("split_", "")
+    out_var = out_var.replace("_path", "")
+    out_var = out_var.replace("conv", "")
+    out_var = out_var.replace("nonsep_mul", "")
+    out_var = out_var.replace("sep_mul", "")
+    out_var = out_var.replace("..", ".")
+    out_var = out_var.replace(".multiplier", "")
+    return out_var
+
+
+class PytorchModel(nnc_core.nnr_model.NNRModel):
+    def __init__(
+        self, model_config: Union[Path, Dict], block_size: int, border_size: int
+    ) -> None:
+        self.model = load_torch_model(model_config, model_lop2_with_multiplier)
+        self.__model_info = None
+        self.__model_parameters = None
+        self.dataset = None
+
+        self.y_block_size = block_size
+        self.uv_block_size = block_size // 2
+        self.y_border_size = border_size
+        self.uv_border_size = border_size // 2
+
+        self.device = None
+        self.dataloader = None
+
+    def load_model(
+        self,
+        model_path,
+        is_ovef_model: bool,
+    ):
+        # Defining device and device's map location
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model_file = torch.load(model_path, map_location=self.device)
+
+        if is_ovef_model:
+            self.model.load_state_dict(model_file, strict=True)
+        else:
+            tmp_params = OrderedDict()
+            dst_params = OrderedDict()
+
+            for name, value in self.model.state_dict().items():
+                if "multiplier" not in name:
+                    tmp_params[name] = value
+
+            for src_param, dst_param in zip(model_file.keys(), tmp_params.keys()):
+                dst_params[dst_param] = model_file[src_param]
+
+            self.model.load_state_dict(dst_params, strict=False)
+        self.disable_all_grads()
+        org_model, self.__model_info = self.pytorch_to_nctm(self.model.state_dict())
+        return org_model["parameters"]
+
+    def set_dataset(self, dataloader):
+        self.dataloader = dataloader
+
+    @property
+    def model_info(self):
+        return self.__model_info
+
+    def disable_all_grads(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def eval_model(
+        self, parameters, bn_folding=False, verbose=False
+    ) -> Tuple[float, float, float]:
+        # torch.set_num_threads(1)
+        if len(parameters) == 0:
+            return -100.0, -100.0, 100.0
+
+        Model = copy.deepcopy(self.model)
+        base_model_arch = Model.state_dict()
+        model_dict = OrderedDict()
+        for module_name in base_model_arch:
+            var_name = module_name
+            if ".multiplier" in var_name:
+                var_name = shorten_multiplier_name(var_name)
+
+            if module_name in parameters or var_name in parameters:
+                model_dict[module_name] = (
+                    parameters[var_name]
+                    if torch.is_tensor(parameters[var_name])
+                    else torch.tensor(parameters[var_name])
+                )
+            elif (
+                config.PUT_SYNTAX()
+            ):  # load unchanged weights from the base model
+                model_dict[module_name] = base_model_arch[module_name]
+
+        Model.load_state_dict(model_dict)
+        Model = Model.to(self.device)
+
+        avg_loss = 0.0
+        avg_psnr = 0.0
+        avg_psnr_gain = 0.0
+        num_elements = 0
+
+        use_multiplier = torch.ones(1).to(self.device)
+        cond_target_value = torch.ones(1).to(self.device)
+
+        for input_data, label_data in self.dataloader:
+            for k, v in input_data.items():
+                input_data[k] = v.to(self.device)
+            for k, v in label_data.items():
+                label_data[k] = v.to(self.device)
+
+            reco_Y = input_data["rec_before_dbf_Y"][
+                :,
+                :,
+                self.y_border_size : self.y_border_size + self.y_block_size,
+                self.y_border_size : self.y_border_size + self.y_block_size,
+            ]
+            reco_U = input_data["rec_before_dbf_U"][
+                :,
+                :,
+                self.uv_border_size : self.uv_border_size + self.uv_block_size,
+                self.uv_border_size : self.uv_border_size + self.uv_block_size,
+            ]
+            reco_V = input_data["rec_before_dbf_V"][
+                :,
+                :,
+                self.uv_border_size : self.uv_border_size + self.uv_block_size,
+                self.uv_border_size : self.uv_border_size + self.uv_block_size,
+            ]
+
+            prediction = Model(input_data, use_multiplier, cond_target_value)
+            prediction_Y = prediction[0][
+                :,
+                :,
+                self.y_border_size : self.y_border_size + self.y_block_size,
+                self.y_border_size : self.y_border_size + self.y_block_size,
+            ]
+            prediction_U, prediction_V = prediction[1][
+                :,
+                :,
+                self.uv_border_size : self.uv_border_size + self.uv_block_size,
+                self.uv_border_size : self.uv_border_size + self.uv_block_size,
+            ].split([1, 1], dim=1)
+
+            _, vtm_psnr, vtm_mask = compute_loss_psnr(
+                label_data, reco_Y, reco_U, reco_V, mse_loss
+            )
+            pred_loss, pred_psnr, pred_mask = compute_loss_psnr(
+                label_data, prediction_Y, prediction_U, prediction_V, mse_loss
+            )
+
+            mask = vtm_mask * pred_mask
+            vtm_psnr *= mask
+            pred_psnr *= mask
+            delta_psnr_wrt_vtm = pred_psnr - vtm_psnr
+
+            num_elements += torch.sum(mask[Colour.YCbCr])
+            avg_loss += torch.sum(pred_loss[Colour.YCbCr])
+            avg_psnr += torch.sum(pred_psnr[Colour.YCbCr])
+            avg_psnr_gain += torch.sum(delta_psnr_wrt_vtm[Colour.YCbCr])
+
+        avg_loss /= len(self.dataloader) * self.dataloader.batch_size
+        avg_psnr /= num_elements
+        avg_psnr_gain /= num_elements
+
+        del Model
+
+        return avg_psnr_gain, avg_psnr, avg_loss
+
+    def pytorch_to_nctm(self, model_file):
+        model_dict = model_file
+        model_data = {"parameters": {}, "reduction_method": "uniform"}
+        model_info = {
+            "parameter_type": {},
+            "parameter_dimensions": {},
+            "parameter_index": {},
+            "block_identifier": {},
+            "topology_storage_format": nnc_core.nnr_model.TopologyStorageFormat.NNR_TPL_TEF,
+            "topology_compression_format": nnc_core.nnr_model.TopologyCompressionFormat.NNR_PT_RAW,
+        }
+
+        for i, module_name in enumerate(model_dict):
+            var_data = model_dict[module_name].data
+            var_name = module_name
+            # shorten parameter names for multipliers
+            if ".multiplier" in module_name:
+                var_name = shorten_multiplier_name(var_name)
+
+            model_data["parameters"][var_name] = var_data.cpu().detach().numpy()
+            model_info["parameter_dimensions"][var_name] = model_data["parameters"][
+                var_name
+            ].shape
+            model_info["parameter_index"][var_name] = i
+
+            if ".weight" in module_name:
+                model_info["parameter_type"][var_name] = "conv.weight"
+            else:
+                model_info["parameter_type"][var_name] = "conv.bias"
+
+        return model_data, model_info
+
+    def save_state_dict(self, path, model_data):
+        pass
+
+    def train_model(
+        self,
+        parameters,
+    ):
+        pass
+
+    def restore_and_save(self, parameters: Dict, output_path: str) -> None:
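+        # Rebuild the model state dict from the given parameter dictionary and save it to disk.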
+        model_dict = OrderedDict()
+        Model = copy.deepcopy(self.model)
+
+        for module_name in Model.state_dict():
+            var_name = module_name
+            # shorten parameter names for multipliers
+            if ".multiplier" in module_name:
+                var_name = shorten_multiplier_name(var_name)
+
+            var_data = torch.tensor(parameters[var_name])
+
+            model_dict[module_name] = var_data
+
+        Model.load_state_dict(model_dict)
+        torch.save(Model.state_dict(), output_path)
+        del Model
diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_with_multiplier.py b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_with_multiplier.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c6c5414d32c07599c21b5c6a572c1b75875c54b
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/models/model_lop2_with_multiplier.py
@@ -0,0 +1,468 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Multiplier(nn.Module):
+    """
+    Multiplier layer. It is initialised with ones
+    """
+
+    def __init__(self, units=1):
+        super(Multiplier, self).__init__()
+        self.units = units
+        self._multiplier = torch.ones(size=(self.units, 1, 1), dtype=torch.float32)
+        self.multiplier = torch.nn.Parameter(self._multiplier, requires_grad=True)
+
+    def forward(self, x):
+        return x * self.multiplier
+
+
+class Conv(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]] = 1,
+        padding: Optional[Union[int, Tuple[int, int]]] = None,
+        is_separable: bool = False,
+        hidden_separable_channels: Optional[int] = None,
+        post_activation: Optional[Type] = nn.PReLU,
+        mul: bool = False,
+        **kwargs,
+    ):
+        """
+        Args:
+            in_channels: the number of input channels
+            out_channels: the number of output channels
+            kernel_size: the convolution's kernel size
+            stride: the convolution's stride(s)
+            padding: the convolution's padding
+            is_separable: whether to implement convolution separably
+            hidden_separable_channels: If is_separable, the number of hidden channels between convolutions. If None, use out_channels
+            post_activation: activation function to use after convolution. If None, no activation is applied
+            mul: whether to add a Multiplier layer after each convolution
+            **kwargs: additional kwargs to pass to nn.Conv2d
+        """
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = (
+            (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+        )
+        self.stride = (stride, stride) if isinstance(stride, int) else stride
+        if padding is not None:
+            self.padding = (padding, padding) if isinstance(padding, int) else padding
+        else:
+            self.padding = tuple([k // 2 for k in self.kernel_size])
+        self.is_separable = is_separable
+        self.mul = mul
+
+        if self.is_separable:
+            self.hidden_separable_channels = hidden_separable_channels or out_channels
+            self.sep_conv1 = nn.Conv2d(
+                self.in_channels,
+                self.hidden_separable_channels,
+                (self.kernel_size[0], 1),
+                (self.stride[0], 1),
+                (self.padding[0], 0),
+                groups=self.hidden_separable_channels,
+                **kwargs,
+            )
+            if self.mul:
+                self.sep_mul1 = Multiplier(self.hidden_separable_channels)
+            self.sep_conv2 = nn.Conv2d(
+                self.hidden_separable_channels,
+                self.out_channels,
+                (1, self.kernel_size[1]),
+                (1, self.stride[1]),
+                (0, self.padding[1]),
+                groups=self.hidden_separable_channels,
+                **kwargs,
+            )
+            if self.mul:
+                self.sep_mul2 = Multiplier(self.out_channels)
+            self.sep_conv3 = nn.Conv2d(
+                self.out_channels,
+                self.out_channels,
+                (1, 1),
+                (1, 1),
+                (0, 0),
+                **kwargs,
+            )
+            if self.mul:
+                self.sep_mul3 = Multiplier(self.out_channels)
+        else:
+            self.nonsep_conv = nn.Conv2d(
+                self.in_channels,
+                self.out_channels,
+                self.kernel_size,
+                self.stride,
+                self.padding,
+                **kwargs,
+            )
+            if self.mul:
+                self.nonsep_mul = Multiplier(self.out_channels)
+
+        self.post_activation = None if post_activation is None else post_activation()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        activate_multiplier: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if self.is_separable:
+            y = self.sep_conv1(x)
+            if self.mul and activate_multiplier is not None:
+                m = self.sep_mul1(y)
+                y = torch.where(activate_multiplier, m, y)
+
+            y = self.sep_conv2(y)
+            if self.mul and activate_multiplier is not None:
+                m = self.sep_mul2(y)
+                y = torch.where(activate_multiplier, m, y)
+
+            y = self.sep_conv3(y)
+            if self.mul and activate_multiplier is not None:
+                m = self.sep_mul3(y)
+                y = torch.where(activate_multiplier, m, y)
+        else:
+            y = self.nonsep_conv(x)
+            if self.mul and activate_multiplier is not None:
+                m = self.nonsep_mul(y)
+                y = torch.where(activate_multiplier, m, y)
+        if self.post_activation is not None:
+            y = self.post_activation(y)
+
+        return y
+
+
+class MultiBranchModule(nn.Module):
+    """A module representing multple, parallel branches. If the input is a list, each element in the list is fed into the corresponding branch,
+    otherwise the input is fed into every branch. The outputs of each branch are then merged.
+    """
+
+    def __init__(self, *branch_modules, merge_dimension: int = -3):
+        """
+        Args:
+            branch_modules: modules to run in parallel
+            merge_dimension: the dimension to merge outputs from each branch
+        """
+        super().__init__()
+        self.branches = nn.ModuleList(branch_modules)
+        self.merge_dimension = merge_dimension
+
+    def forward(self, args: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
+        inputs = args if isinstance(args, list) else len(self.branches) * [args]
+        branch_outputs = [branch(input) for branch, input in zip(self.branches, inputs)]
+        return torch.cat(branch_outputs, dim=self.merge_dimension)
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, C: int = 64, C1: int = 160, C21: int = 32, mul: bool = False):
+        super().__init__()
+        self.conv1 = Conv(C, C1, kernel_size=1, mul=mul)
+        self.conv2 = Conv(C1, C, kernel_size=1, post_activation=None, mul=mul)
+        self.conv3 = Conv(
+            C,
+            C,
+            kernel_size=3,
+            post_activation=None,
+            is_separable=True,
+            hidden_separable_channels=C21,
+            mul=mul,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        activate_multiplier: torch.Tensor,
+    ) -> torch.Tensor:
+        y = self.conv1(x, activate_multiplier)
+        y = self.conv2(y, activate_multiplier)
+        y = self.conv3(y, activate_multiplier)
+        return x + y
+
+
+class SplitLumaChromaBlocks(nn.Module):
+    def __init__(
+        self,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
+        output_channels_y: int = 4,
+        output_channels_uv: int = 2,
+    ):
+        super().__init__()
+
+        self.split_y_path = nn.ModuleList(
+            m
+            for m in [
+                *[ResidualBlock(C, C1_Y, C21, True) for _ in range(N_Y)],
+                Conv(
+                    C,
+                    C,
+                    kernel_size=3,
+                    is_separable=True,
+                    hidden_separable_channels=C21,
+                    mul=True,
+                ),
+                Conv(
+                    C, output_channels_y, kernel_size=3, post_activation=None, mul=True
+                ),
+            ]
+        )
+
+        self.split_uv_path = nn.ModuleList(
+            m
+            for m in [
+                *[ResidualBlock(C, C1_UV, C21, True) for _ in range(N_UV)],
+                Conv(
+                    C,
+                    C,
+                    kernel_size=3,
+                    is_separable=True,
+                    hidden_separable_channels=C21,
+                    mul=True,
+                ),
+                Conv(
+                    C, output_channels_uv, kernel_size=3, post_activation=None, mul=True
+                ),
+            ]
+        )
+
+        self.Cy = C
+        self.Cuv = C
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        activate_multiplier: torch.Tensor = None,
+    ) -> torch.Tensor:
+        split_y_input = x[:, : self.Cy, :, :]
+        split_uv_input = x[:, self.Cy : self.Cy + self.Cuv, :, :]
+
+        y_output = split_y_input
+        for m in self.split_y_path:
+            y_output = m(y_output, activate_multiplier)
+
+        uv_output = split_uv_input
+        for m in self.split_uv_path:
+            uv_output = m(uv_output, activate_multiplier)
+
+        return torch.cat((y_output, uv_output), dim=1)
+
+
+class SADLNet(nn.Module):
+    """The network used during SADL inference"""
+
+    def __init__(
+        self,
+        input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1],
+        input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
+        D1: int = 12,
+        D2: int = 8,
+        D3: int = 4,
+        D4: int = 2,
+        D5: int = 2,
+        D6: int = 24,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
+        output_channels_y: int = 4,
+        output_channels_uv: int = 2,
+    ):
+        """
+        Args:
+            input_channels: the number of channels expected for each input
+            input_kernels: the kernel size for each input convolution
+            output_channels_y: the number of output channels of the luma path
+            output_channels_uv: the number of output channels of the chroma path
+        """
+        super().__init__()
+        self.input_channels = input_channels
+        self.input_kernels = input_kernels
+        self.input_features = [D1, D2, D3, D4, D4, D5]
+
+        self.multi_branch = MultiBranchModule(
+            *[
+                Conv(c, d, kernel_size=k, post_activation=None)
+                for c, d, k in zip(
+                    self.input_channels, self.input_features, self.input_kernels
+                )
+            ]
+        )
+        self.conv1 = Conv(sum(self.input_features), D6, kernel_size=1)
+        self.conv2 = Conv(D6, C + C, kernel_size=3, stride=2)
+        self.split = SplitLumaChromaBlocks(
+            N_Y, N_UV, C, C1_Y, C1_UV, C21, output_channels_y, output_channels_uv
+        )
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        y = self.multi_branch(inputs[:-2])
+        y = self.conv1(y)
+        y = self.conv2(y)
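+        # The last two inputs act as a switch: the Multiplier layers are applied only where these tensors are equal.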
+        activate_multipliers = torch.eq(inputs[-2], inputs[-1])
+        y = self.split(y, activate_multipliers)
+        return y
+
+    def get_example_inputs(
+        self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1
+    ) -> List[torch.Tensor]:
+        patch_size = (
+            (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+        )
+        x = [
+            torch.rand(
+                batch_size,
+                conv.in_channels,
+                *patch_size,
+                device=conv.nonsep_conv.weight.device,
+            )
+            for conv in self.multi_branch.branches
+        ]
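+        # Two scalar placeholders for the multiplier-activation inputs (equal values enable the multipliers).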
+        x.append(torch.ones(1))
+        x.append(torch.ones(1))
+        return x
+
+    def to_onnx(
+        self, filename: str, patch_size: int = 144, batch_size: int = 1, **kwargs
+    ) -> None:
+        mode = self.training
+        self.eval()
+        torch.onnx.export(
+            self,
+            self.get_example_inputs(patch_size, batch_size),
+            filename,
+            **kwargs,
+        )
+        self.train(mode)
+
+
+class Net(nn.Module):
+    """Wrapper for SADL model that implements input pre- and post-processing for training."""
+
+    def __init__(
+        self,
+        input_channels: Iterable[Iterable[str]] = [
+            ["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V"],
+            ["pred_Y", "pred_U", "pred_V"],
+            ["bs_Y", "bs_U", "bs_V"],
+            ["qp_base"],
+            ["qp_slice"],
+            ["ipb_Y"],
+        ],
+        input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
+        D1: int = 12,
+        D2: int = 8,
+        D3: int = 4,
+        D4: int = 2,
+        D5: int = 2,
+        D6: int = 24,
+        N_Y: int = 12,
+        N_UV: int = 6,
+        C: int = 16,
+        C1_Y: int = 64,
+        C1_UV: int = 48,
+        C21: int = 16,
+    ):
+        super(Net, self).__init__()
+        assert len(input_channels) == len(
+            input_kernels
+        ), "[ERROR] input size and kernels size not equal"
+        self.input_channels = input_channels
+        sizes = [len(a) for a in input_channels]
+        self.SADL_model = SADLNet(
+            sizes,
+            input_kernels,
+            D1,
+            D2,
+            D3,
+            D4,
+            D5,
+            D6,
+            N_Y,
+            N_UV,
+            C,
+            C1_Y,
+            C1_UV,
+            C21,
+            4,
+            2,
+        )
+        self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest")
+
+    def preprocess_args(self, batch: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
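+        # Upsample chroma planes to luma resolution so each input group can be concatenated channel-wise.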
+        for name, data in batch.items():
+            if "U" in name or "V" in name:
+                batch[name] = self.chroma_upsampler(batch[name])
+
+        return [
+            torch.cat([batch[name] for name in input_], dim=1)
+            for input_ in self.input_channels
+        ]
+
+    def postprocess_outputs(
+        self, batch: Dict[str, torch.Tensor], out: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
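+        # The 4 luma residual channels are pixel-shuffled back to full resolution; chroma was upsampled
+        # for the network input, so it is subsampled back to its original resolution before adding the residual.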
+        Y_res, UV_res = out.split([4, 2], dim=1)
+        return (
+            batch["rec_before_dbf_Y"] + F.pixel_shuffle(Y_res, 2),
+            torch.cat((batch["rec_before_dbf_U"], batch["rec_before_dbf_V"]), dim=1)[
+                ..., ::2, ::2
+            ]
+            + UV_res,
+        )
+
+    def forward(
+        self,
+        batch: Dict[str, torch.Tensor],
+        equal_left_side: torch.Tensor,
+        equal_right_side: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        args = self.preprocess_args(batch)
+        args.append(equal_left_side)
+        args.append(equal_right_side)
+        out = self.SADL_model(args)
+        return self.postprocess_outputs(batch, out)
diff --git a/training/training_scripts/NN_Adaptive_Filtering/models/models.json b/training/training_scripts/NN_Adaptive_Filtering/models/models.json
new file mode 100755
index 0000000000000000000000000000000000000000..8afb5299ce24929cad0e3b09a975203456feecb4
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/models/models.json
@@ -0,0 +1,51 @@
+{
+  "lop2":
+  {
+    "model":
+    {
+      "class" : "Net",
+      "input_channels" :
+      [
+        [
+          "rec_before_dbf_Y",
+          "rec_before_dbf_U",
+          "rec_before_dbf_V"
+        ],
+        [
+          "pred_Y",
+          "pred_U",
+          "pred_V"
+        ],
+        [
+          "bs_Y",
+          "bs_U",
+          "bs_V"
+        ],
+        [ "qp_base" ],
+        [ "qp_slice" ],
+        [ "ipb_Y" ]
+      ],
+      "input_kernels" :
+      [
+          3,
+          3,
+          1,
+          1,
+          1,
+          1
+      ],
+      "D1" : 12,
+      "D2" : 8,
+      "D3" : 4,
+      "D4" : 2,
+      "D5" : 2,
+      "D6" : 24,
+      "N_Y" : 14,
+      "N_UV" : 4,  
+      "C" : 16,
+      "C1_Y" : 64,
+      "C1_UV" : 32,
+      "C21" : 16
+    }
+  }
+}
diff --git a/training/training_scripts/NN_Adaptive_Filtering/overfitting_pipeline.py b/training/training_scripts/NN_Adaptive_Filtering/overfitting_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..293dd18380130f04d081c44aaa35d288a66c5f30
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/overfitting_pipeline.py
@@ -0,0 +1,396 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import sys
+from tempfile import NamedTemporaryFile
+
+import click
+import numpy as np
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from trainer.nn_filter import NnFilter
+from util.dataset_bin import NnFilterBinDataset
+from util.dataset_yuv import NnFilterYuvDataset
+from util.file_system import read_json_file
+from wu_encoding import compress_one_model_and_2sadl
+
+
+@click.command()
+@click.option(
+    "--deco_dir_train",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains decoded training data",
+)
+@click.option(
+    "--orig_dir_train",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains original training data",
+)
+@click.option(
+    "--prop_file_train",
+    default=None,
+    type=click.Path(),
+    help="Properties file for the training data",
+)
+@click.option(
+    "--seq_qp_train", default=list(), multiple=True, help="List of training QPs"
+)
+@click.option(
+    "--slice_type",
+    default="B",
+    type=click.Choice(["I", "B", "*"]),
+    help="Slice type",
+)
+@click.option("--min_slice_qp", default=0, type=int, help="Minimum slice QP (included)")
+@click.option(
+    "--max_slice_qp", default=63, type=int, help="Maximum slice QP (included)"
+)
+@click.option(
+    "--block_size",
+    default=128,
+    type=int,
+    help="Size of the luma block",
+)
+@click.option(
+    "--border_size",
+    default=8,
+    type=int,
+    help="Size of the border to be added to the input patch (one side for chroma)",
+)
+@click.option(
+    "--random_patches",
+    default=False,
+    is_flag=True,
+    help="Patches are selected randomly",
+)
+@click.option(
+    "--num_patches",
+    default=10,
+    type=int,
+    help="Number of patches to extract from each frame",
+)
+@click.option("--bit_depth", default=10, type=int, help="Sequence bit depth")
+@click.option("--seq_name", default=None, type=str, help="Sequence name/tag")
+@click.option(
+    "--num_workers", default=16, type=int, help="Number of threads used to load data"
+)
+@click.option("--start_epoch", default=0, type=int, help="Starting epoch")
+@click.option("--max_epochs", default=100, type=int, help="Max number of epochs")
+@click.option("--lr", default=0.001, type=float, help="Learning rate")
+@click.option(
+    "--lr_patience",
+    default=30,
+    type=int,
+    help="Number of epochs to drop the learning rate when the loss stops improving",
+)
+@click.option(
+    "--lr_gamma",
+    default=0.1,
+    type=float,
+    help="Learning rate drop factor for the scheduler",
+)
+@click.option(
+    "--weight_decay",
+    default=0.0,
+    type=float,
+    help="Weight decay for regularisation",
+)
+@click.option("--batch_size", default=64, type=int, help="Batch size")
+@click.option("--arch", default=None, type=str, help="Model architecture")
+@click.option(
+    "--base_nn_params",
+    default=None,
+    type=click.Path(),
+    help="Abs path to NN parameters of base model",
+)
+@click.option(
+    "--base_quantizers",
+    default=None,
+    type=click.Path(),
+    help="Abs path to base model quantizers",
+)
+@click.option(
+    "--map_orig_params_to_new",
+    default=None,
+    type=click.Path(),
+    help="Abs path to file that maps original parameters to new parameters",
+)
+@click.option(
+    "--nn_params", default=None, type=click.Path(), help="Abs path to NN parameters"
+)
+@click.option(
+    "--optimiser_params",
+    default=None,
+    type=click.Path(),
+    help="Abs path to optimiser parameters",
+)
+@click.option(
+    "--lr_scheduler_params",
+    default=None,
+    type=click.Path(),
+    help="Abs path to learning rate scheduler parameters",
+)
+@click.option(
+    "--loss_op", default="mse", type=click.Choice(["mse", "mae"]), help="Loss operation"
+)
+@click.option(
+    "--stop_patience",
+    default=50,
+    type=int,
+    help="Number of epochs to stop training when the metrics stops improving",
+)
+@click.option(
+    "--output_dir",
+    default="/tmp/output_dir",
+    type=click.Path(),
+    help="Output directory",
+)
+@click.option(
+    "--min_poc",
+    default=0,
+    type=int,
+    help="minimum poc idx",
+)
+@click.option(
+    "--max_poc",
+    default=1000,
+    type=int,
+    help="max poc idx",
+)
+@click.option(
+    "--part_idx",
+    default=0,
+    type=int,
+    help="Random Access segment idx",
+)
+@click.option(
+    "--select_layer",
+    is_flag=True,
+    help="if select layer during the training",
+)
+@click.option(
+    "--num_param",
+    default=1000,
+    type=int,
+    help="number of parameters to be selected during the training",
+)
+@click.option(
+    "--nnr_models_dir", type=click.Path(), default=None, help="Path to nnr models"
+)
+@click.option(
+    "--onnx_models_dir", type=click.Path(), default=None, help="Path to onnx models"
+)
+@click.option(
+    "--sadl_float_dir",
+    type=click.Path(),
+    default=None,
+    help="Path to SADL float models",
+)
+@click.option(
+    "--sadl_int_dir",
+    type=click.Path(),
+    default=None,
+    help="Path to SADL integer models",
+)
+@click.option(
+    "--quantizers_dir", type=click.Path(), default=None, help="Path to quantizers"
+)
+def run_pipeline(
+    deco_dir_train,
+    orig_dir_train,
+    prop_file_train,
+    seq_qp_train,
+    slice_type,
+    min_slice_qp,
+    max_slice_qp,
+    block_size,
+    border_size,
+    random_patches,
+    num_patches,
+    bit_depth,
+    seq_name,
+    num_workers,
+    start_epoch,
+    max_epochs,
+    lr,
+    lr_patience,
+    lr_gamma,
+    weight_decay,
+    batch_size,
+    arch,
+    base_nn_params,
+    base_quantizers,
+    map_orig_params_to_new,
+    nn_params,
+    optimiser_params,
+    lr_scheduler_params,
+    loss_op,
+    stop_patience,
+    output_dir,
+    min_poc,
+    max_poc,
+    part_idx,
+    select_layer,
+    num_param,
+    nnr_models_dir,
+    onnx_models_dir,
+    sadl_float_dir,
+    sadl_int_dir,
+    quantizers_dir,
+):
+    assert block_size % 2 == 0, "The block size must be a multiple of 2"
+    assert border_size % 2 == 0, "The border size must be a multiple of 2"
+
+    if "lop2" != arch:
+        print(f"{arch} architecture not supported")
+        sys.exit(-1)
+
+    model_config = read_json_file("models/models.json")[arch]
+
+    train_data_yuv = NnFilterYuvDataset(
+        deco_dir_train,
+        orig_dir_train,
+        prop_file_train,
+        seq_name,
+        seq_qp_train,
+        bit_depth,
+        block_size,
+        border_size,
+        slice_type,
+        min_slice_qp,
+        max_slice_qp,
+        random_patches,
+        num_patches,
+        min_poc,
+        max_poc,
+    )
+
+    input_keys = [
+        "rec_before_dbf_Y",
+        "rec_before_dbf_U",
+        "rec_before_dbf_V",
+        "pred_Y",
+        "pred_U",
+        "pred_V",
+        "bs_Y",
+        "bs_U",
+        "bs_V",
+        "ipb_Y",
+    ]
+    label_keys = ["orig_Y", "orig_U", "orig_V", "mask_Y", "mask_U", "mask_V"]
+    num_patches = len(train_data_yuv)
+    # Store pre-processed patches in a temporary binary file to save training time
+    with NamedTemporaryFile(
+        dir=".", suffix=".bin", prefix=f"{seq_name}_{seq_qp_train[0]}_part{part_idx}_"
+    ) as tmp_data:
+        print("[INFO] Dump {} patches in {}".format(num_patches, tmp_data.name))
+
+        for idx in tqdm(range(num_patches)):
+            input_data, label_data, base_qp, slice_qp = train_data_yuv[idx]
+            for k in input_keys:
+                data = input_data[k].cpu().detach().numpy()
+                data.tofile(tmp_data)
+            for k in label_keys:
+                data = label_data[k].cpu().detach().numpy()
+                data.tofile(tmp_data)
+            np.array([base_qp, slice_qp], dtype=np.float32).tofile(tmp_data)
+
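+        # Training then reads the dumped patches back through the binary dataset.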
+        train_data = NnFilterBinDataset(
+            tmp_data.name,
+            block_size,
+            border_size,
+            input_keys,
+            label_keys,
+            num_patches,
+            np.float32,
+        )
+
+        train_loader = DataLoader(
+            train_data,
+            batch_size=batch_size,
+            shuffle=True,
+            drop_last=True,
+            num_workers=num_workers,
+            pin_memory=True,
+        )
+
+        print("Size of training data: ", num_patches)
+        nn_filter = NnFilter(
+            model_config,
+            lr,
+            lr_patience,
+            lr_gamma,
+            weight_decay,
+            block_size,
+            border_size,
+            base_nn_params,
+            nn_params,
+            optimiser_params,
+            lr_scheduler_params,
+            loss_op,
+            stop_patience,
+            output_dir,
+            select_layer,
+        )
+
+        nn_filter.train_loop(start_epoch, max_epochs, train_loader, num_param)
+
+        # NNR Compression
+        compress_one_model_and_2sadl(
+            output_dir,
+            base_nn_params,
+            model_config,
+            base_quantizers,
+            map_orig_params_to_new,
+            nnr_models_dir,
+            onnx_models_dir,
+            sadl_float_dir,
+            sadl_int_dir,
+            quantizers_dir,
+            train_loader,
+            block_size,
+            border_size,
+        )
+
+
+def main(*args, **kwargs):
+    print("call: {}".format(" ".join(sys.argv)))
+    run_pipeline()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/config.json b/training/training_scripts/NN_Adaptive_Filtering/resources/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..a65c6bdefbe5a37ec270a8b652e241f07c2dbd33
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/resources/config.json
@@ -0,0 +1,36 @@
+{
+  "architecture": "lop2",
+  "sadl2torch": {
+    "base_model_sadl_i": "/src/models/nnlf_lop2_model_int16.sadl",
+    "base_model_torch": "resources/nnlf_lop2_model.pt",
+    "base_model_quantizers": "resources/nnlf_lop2_quantizers.json",
+    "base_model_params_to_new": "resources/nnlf_lop2_params_to_new_model.json"
+  },
+  "decoder_bin": "/path/to/decoder/bin",
+  "intra_filters": "/path/to/intra/filters",
+  "dataset": {
+    "orig_path": "/path/to/dataset/orig",
+    "deco_path": "/path/to/dataset/deco",
+    "bit-depth": 10,
+    "data_cfg": "resources/datasets/jvet.json"
+  },
+  "nnvc": {
+    "log": "log_dec.txt",
+    "bs": "bs.yuv",
+    "rec": "rec_before_dbf.yuv",
+    "pred": "pred.yuv",
+    "ipb": "bpm.yuv"
+  },
+  "training": {
+    "output_path": "/path/to/output",
+    "num_models": 1,
+    "batch-size": 64,
+    "block-size": 128,
+    "pad-size": 8,
+    "min-tid": 0,
+    "num_layers": 106,
+    "slice-type": "*",
+    "input-keys": ["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V", "pred_Y", "pred_U", "pred_V", "bs_Y", "bs_U", "bs_V", "ipb_Y"],
+    "label-keys": ["orig_Y", "orig_U", "orig_V", "mask_Y", "mask_U", "mask_V"]
+  }
+}
diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet.json b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet.json
new file mode 100644
index 0000000000000000000000000000000000000000..ebc62b38bacb61cdef2b9cbfe20317267c1228cf
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet.json
@@ -0,0 +1,212 @@
+{
+  "A1_Tango": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 294,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "A1_FoodMarket": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "A1_CampfireParty": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 30,
+    "bit-depth": 10
+  },
+  "A2_CatRobot": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "A2_DaylightRoad": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "A2_ParkRunning": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 50,
+    "bit-depth": 10
+  },
+  "B_MarketPlace": {
+    "width": 1920,
+    "height": 1080,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "B_RitualDance": {
+    "width": 1920,
+    "height": 1080,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "B_Cactus": {
+    "width": 1920,
+    "height": 1080,
+    "frames": 500,
+    "fps": 50,
+    "bit-depth": 8
+  },
+  "B_BasketBallDrive": {
+    "width": 1920,
+    "height": 1080,
+    "frames": 500,
+    "fps": 50,
+    "bit-depth": 8
+  },
+  "B_BQTerrace": {
+    "width": 1920,
+    "height": 1080,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 8
+  },
+  "C_BasketballDrill": {
+    "width": 832,
+    "height": 480,
+    "frames": 500,
+    "fps": 50,
+    "bit-depth": 8
+  },
+  "C_BQMall": {
+    "width": 832,
+    "height": 480,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 8
+  },
+  "C_PartyScene": {
+    "width": 832,
+    "height": 480,
+    "frames": 500,
+    "fps": 50,
+    "bit-depth": 8
+  },
+  "C_RaceHorses_big": {
+    "width": 832,
+    "height": 480,
+    "frames": 300,
+    "fps": 30,
+    "bit-depth": 8
+  },
+  "D_BasketBallPass": {
+    "width": 416,
+    "height": 240,
+    "frames": 500,
+    "fps": 50,
+    "bit-depth": 8
+  },
+  "D_BQSquare": {
+    "width": 416,
+    "height": 240,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 8
+  },
+  "D_BlowingBubbles": {
+    "width": 416,
+    "height": 240,
+    "frames": 500,
+    "fps": 50,
+    "bit-depth": 8
+  },
+  "D_RaceHorses_s": {
+    "width": 416,
+    "height": 240,
+    "frames": 300,
+    "fps": 30,
+    "bit-depth": 8
+  },
+  "E_FourPeople": {
+    "width": 1280,
+    "height": 720,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 8
+  },
+  "E_Johnny": {
+    "width": 1280,
+    "height": 720,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 8
+  },
+  "E_KristenAndSara": {
+    "width": 1280,
+    "height": 720,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 8
+  },
+  "F_BBDrillText": {
+    "width": 832,
+    "height": 480,
+    "frames": 500,
+    "fps": 50,
+    "bit-depth": 8
+  },
+  "F_ArenaOfValor": {
+    "width": 1920,
+    "height": 1080,
+    "frames": 600,
+    "fps": 60,
+    "bit-depth": 8
+  },
+  "F_SlideEditing": {
+    "width": 1280,
+    "height": 720,
+    "frames": 300,
+    "fps": 30,
+    "bit-depth": 8
+  },
+  "F_SlideShow": {
+    "width": 1280,
+    "height": 720,
+    "frames": 500,
+    "fps": 20,
+    "bit-depth": 8
+  },
+  "H2_DayStreet": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "H2_FlyingBirds2": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "H2_PeopleInShop": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 60,
+    "bit-depth": 10
+  },
+  "H2_SunsetBeach2": {
+    "width": 3840,
+    "height": 2160,
+    "frames": 300,
+    "fps": 60,
+    "bit-depth": 10
+  }
+}
diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet_labels.json b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet_labels.json
new file mode 100644
index 0000000000000000000000000000000000000000..013994fe657b34ac15b378d3ed4cc2bae8463d0c
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/resources/datasets/jvet_labels.json
@@ -0,0 +1,25 @@
+{
+  "Tango2_3840x2160_60fps_10bit_420": "A1_Tango",
+  "FoodMarket4_3840x2160_60fps_10bit_420": "A1_FoodMarket",
+  "Campfire_3840x2160_30fps_bt709_420_videoRange": "A1_CampfireParty",
+  "CatRobot1_3840x2160p_60_10_709_420": "A2_CatRobot",
+  "DaylightRoad2_3840x2160_60fps_10bit_420": "A2_DaylightRoad",
+  "ParkRunning3_3840x2160_50fps_10bit_420": "A2_ParkRunning",
+  "MarketPlace_1920x1080_60fps_10bit_420": "B_MarketPlace",
+  "RitualDance_1920x1080_60fps_10bit_420": "B_RitualDance",
+  "Cactus_1920x1080_50": "B_Cactus",
+  "BasketballDrive_1920x1080_50": "B_BasketBallDrive",
+  "BQTerrace_1920x1080_60": "B_BQTerrace",
+  "BasketballDrill_832x480_50": "C_BasketballDrill",
+  "BQMall_832x480_60": "C_BQMall",
+  "PartyScene_832x480_50": "C_PartyScene",
+  "RaceHorses_832x480_30": "C_RaceHorses_big",
+  "BasketballPass_416x240_50": "D_BasketBallPass",
+  "BQSquare_416x240_60": "D_BQSquare",
+  "BlowingBubbles_416x240_50": "D_BlowingBubbles",
+  "RaceHorses_416x240_30": "D_RaceHorses_s",
+  "BasketballDrillText_832x480_50": "F_BBDrillText",
+  "ArenaOfValor_1920x1080_60_8bit_420": "F_ArenaOfValor",
+  "SlideEditing_1280x720_30": "F_SlideEditing",
+  "SlideShow_1280x720_20": "F_SlideShow"
+}
diff --git a/training/training_scripts/NN_Adaptive_Filtering/resources/num_layers.json b/training/training_scripts/NN_Adaptive_Filtering/resources/num_layers.json
new file mode 100644
index 0000000000000000000000000000000000000000..edbdd9166128cae5d0e7bb5d94ad453af7f9c7b4
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/resources/num_layers.json
@@ -0,0 +1,163 @@
+{
+    "A1_Tango": {
+        "22": 154,
+        "27": 64,
+        "32": 33,
+        "37": 19,
+        "42": 10
+    },
+    "A1_FoodMarket": {
+        "22": 144,
+        "27": 72,
+        "32": 37,
+        "37": 20,
+        "42": 10
+    },
+    "A1_CampfireParty": {
+        "22": 457,
+        "27": 120,
+        "32": 64,
+        "37": 34,
+        "42": 16
+    },
+    "A2_CatRobot": {
+        "22": 165,
+        "27": 77,
+        "32": 39,
+        "37": 21,
+        "42": 11
+    },
+    "A2_DaylightRoad": {
+        "22": 206,
+        "27": 79,
+        "32": 38,
+        "37": 20,
+        "42": 10
+    },
+    "A2_ParkRunning": {
+        "22": 1166,
+        "27": 468,
+        "32": 207,
+        "37": 95,
+        "42": 40
+    },
+    "B_MarketPlace": {
+        "22": 398,
+        "27": 161,
+        "32": 73,
+        "37": 34,
+        "42": 15
+    },
+    "B_RitualDance": {
+        "22": 291,
+        "27": 148,
+        "32": 78,
+        "37": 42,
+        "42": 22
+    },
+    "B_Cactus": {
+        "22": 362,
+        "27": 145,
+        "32": 70,
+        "37": 35,
+        "42": 18
+    },
+    "B_BasketBallDrive": {
+        "22": 437,
+        "27": 178,
+        "32": 86,
+        "37": 44,
+        "42": 23
+    },
+    "B_BQTerrace": {
+        "22": 678,
+        "27": 120,
+        "32": 49,
+        "37": 24,
+        "42": 12
+    },
+    "C_BasketballDrill": {
+        "22": 49,
+        "27": 23,
+        "32": 11,
+        "37": 6,
+        "42": 3
+    },
+    "C_BQMall": {
+        "22": 41,
+        "27": 19,
+        "32": 9,
+        "37": 5,
+        "42": 2
+    },
+    "C_PartyScene": {
+        "22": 99,
+        "27": 44,
+        "32": 20,
+        "37": 10,
+        "42": 4
+    },
+    "C_RaceHorses_big": {
+        "22": 57,
+        "27": 24,
+        "32": 11,
+        "37": 5,
+        "42": 2
+    },
+    "D_BasketBallPass": {
+        "22": 23,
+        "27": 11,
+        "32": 5,
+        "37": 2,
+        "42": 1
+    },
+    "D_BQSquare": {
+        "22": 18,
+        "27": 6,
+        "32": 3,
+        "37": 1,
+        "42": 1
+    },
+    "D_BlowingBubbles": {
+        "22": 23,
+        "27": 10,
+        "32": 5,
+        "37": 2,
+        "42": 1
+    },
+    "D_RaceHorses_s": {
+        "22": 14,
+        "27": 7,
+        "32": 3,
+        "37": 1,
+        "42": 1
+    },
+    "F_BBDrillText": {
+        "22": 65,
+        "27": 30,
+        "32": 15,
+        "37": 8,
+        "42": 4
+    },
+    "F_ArenaOfValor": {
+        "22": 228,
+        "27": 99,
+        "32": 47,
+        "37": 24,
+        "42": 12
+    },
+    "F_SlideEditing": {
+        "22": 15,
+        "27": 11,
+        "32": 8,
+        "37": 6,
+        "42": 4
+    },
+    "F_SlideShow": {
+        "22": 13,
+        "27": 7,
+        "32": 4,
+        "37": 2,
+        "42": 1
+    }
+}
\ No newline at end of file
diff --git a/training/training_scripts/NN_Adaptive_Filtering/run/TMLogger.py b/training/training_scripts/NN_Adaptive_Filtering/run/TMLogger.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b1a5d880ccdb61313da1a11671aa4f939fd77b
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/run/TMLogger.py
@@ -0,0 +1,252 @@
+"""
+The copyright in this software is being made available under this Software
+Copyright License. This software may be subject to other third party and
+contributor rights, including patent rights, and no such rights are
+granted under this license.
+Copyright (c) 1995 - 2021 Fraunhofer-Gesellschaft zur Förderung der
+angewandten Forschung e.V. (Fraunhofer)
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted for purpose of testing the functionalities of
+this software provided that the following conditions are met:
+*     Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+*     Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+*     Neither the names of the copyright holders nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
+NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+NO EXPRESS OR IMPLIED LICENSES TO ANY PATENT CLAIMS, INCLUDING
+WITHOUT LIMITATION THE PATENTS OF THE COPYRIGHT HOLDERS AND
+CONTRIBUTORS, ARE GRANTED BY THIS SOFTWARE LICENSE. THE
+COPYRIGHT HOLDERS AND CONTRIBUTORS PROVIDE NO WARRANTY OF PATENT
+NON-INFRINGEMENT WITH RESPECT TO THIS SOFTWARE.
+"""
+
+
+# Format changed
+
+
+import logging
+import os
+import sys
+from timeit import default_timer as timer
+
+# custom timing class
+from memory_profiler import MemTimer, Pipe, choose_backend
+
+
+class TMLogger:
+    def __init__(self, logfile, overwrite):
+        self.timings = []
+        self.memory = []
+        # LOGGER
+        self.FORMAT = logging.Formatter("iNCTM - %(levelname)s - %(message)s")
+        logging.root.setLevel(logging.NOTSET)
+        self.LOGGER = logging.getLogger()
+        if overwrite and os.path.exists(logfile):
+            os.remove(logfile)
+
+        if os.path.isfile(logfile):
+            print("Logfile {} exists. Exiting...".format(logfile))
+            sys.exit(1)
+
+        _fileh = logging.FileHandler(logfile)
+        _fileh.setFormatter(self.FORMAT)
+        _fileh.setLevel(logging.INFO)
+        self.LOGGER.addHandler(_fileh)
+
+
+class TimerMemory:
+    def __init__(
+        self,
+        tm_logger,
+        msg,
+        scope="",
+        interval=0.01,
+        timestamps=False,
+        include_children=False,
+        max_usage=False,
+        backend=None,
+    ):
+        self.tm_logger = tm_logger
+        self.msg = msg
+        self.scope = msg if scope == "" else scope
+        self.child_conn, self.parent_conn = Pipe()
+        self.backend = choose_backend(backend)
+        self.interval = interval
+        self.timestamps = timestamps
+        self.include_children = include_children
+        self.max_usage = max_usage
+
+    def __enter__(self):
+        self.t1 = timer()
+        p = MemTimer(
+            os.getpid(),
+            self.interval,
+            self.child_conn,
+            self.backend,
+            timestamps=self.timestamps,
+            max_usage=self.max_usage,
+            include_children=self.include_children,
+        )
+        p.start()
+        self.parent_conn.recv()  # wait until we start getting memory
+        return self
+
+    def __exit__(self, *args):
+        self.t2 = timer()
+        res = self.t2 - self.t1
+        if self.tm_logger is not None:
+            self.tm_logger.LOGGER.info(self.msg + ": {}".format(res))
+        self.parent_conn.send(0)  # finish timing
+        ret = self.parent_conn.recv()
+        # n_measurements = self.parent_conn.recv()
+        # LOGGER.info('Memory usage ({}): min: {}, max: {}, peak increase: {}'.format(self.scope, min(ret), max(ret), max(ret) - min(ret)))
+        if self.tm_logger is not None:
+            self.tm_logger.LOGGER.info(
+                "      Memory usage {} (peak increase): {}".format(
+                    self.scope, max(ret) - min(ret)
+                )
+            )
+            self.tm_logger.memory.append(max(ret) - min(ret))
+            self.tm_logger.timings.append(res)
+
+
+class TMLoggerICNN:
+    def __init__(self, logfile, overwrite):
+        self.timings = {}
+        self.memory = {}
+        # LOGGER
+        self.FORMAT = logging.Formatter("ICNN - %(levelname)s - %(message)s")
+        logging.root.setLevel(logging.NOTSET)
+        self.LOGGER = logging.getLogger()
+        if overwrite and os.path.exists(logfile):
+            os.remove(logfile)
+
+        if os.path.isfile(logfile):
+            print("Logfile {} exists. Exiting...".format(logfile))
+            sys.exit(1)
+
+        _fileh = logging.FileHandler(logfile)
+        _fileh.setFormatter(self.FORMAT)
+        _fileh.setLevel(logging.INFO)
+        self.LOGGER.addHandler(_fileh)
+
+
+class TimerMemoryICNN:
+    def __init__(
+        self,
+        tm_logger,
+        epoch,
+        client,
+        tag,
+        msg,
+        scope="",
+        interval=0.01,
+        timestamps=False,
+        include_children=False,
+        max_usage=False,
+        backend=None,
+    ):
+        self.tm_logger = tm_logger
+        self.msg = msg
+        self.scope = msg if scope == "" else scope
+        self.child_conn, self.parent_conn = Pipe()
+        self.backend = choose_backend(backend)
+        self.interval = interval
+        self.timestamps = timestamps
+        self.include_children = include_children
+        self.max_usage = max_usage
+        self.tag = tag
+        self.epoch = epoch
+        self.client = client
+
+    def __enter__(self):
+        self.t1 = timer()
+        p = MemTimer(
+            os.getpid(),
+            self.interval,
+            self.child_conn,
+            self.backend,
+            timestamps=self.timestamps,
+            max_usage=self.max_usage,
+            include_children=self.include_children,
+        )
+        p.start()
+        self.parent_conn.recv()  # wait until we start getting memory
+        return self
+        # pass
+
+    def __exit__(self, *args):
+        self.t2 = timer()
+        res = self.t2 - self.t1
+        if self.tm_logger is not None:
+            self.tm_logger.LOGGER.info(self.msg + ": {}".format(res))
+        self.parent_conn.send(0)  # finish timing
+        ret = self.parent_conn.recv()
+        # n_measurements = self.parent_conn.recv()
+        # LOGGER.info('Memory usage ({}): min: {}, max: {}, peak increase: {}'.format(self.scope, min(ret), max(ret), max(ret) - min(ret)))
+        if self.tm_logger is not None:
+            self.tm_logger.LOGGER.info(
+                "      Memory usage {} (peak increase): {}".format(
+                    self.scope, max(ret) - min(ret)
+                )
+            )
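+            # Store peak memory increase and elapsed time per epoch -> client -> tag, requiring unique tags.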
+            if (
+                self.epoch in self.tm_logger.memory.keys()
+                and self.epoch in self.tm_logger.timings.keys()
+            ):
+                if self.client in self.tm_logger.memory[self.epoch].keys():
+                    if (
+                        self.tag + "_m"
+                        in self.tm_logger.memory[self.epoch][self.client].keys()
+                        or self.tag + "_t"
+                        in self.tm_logger.timings[self.epoch][self.client].keys()
+                    ):
+                        assert (
+                            0
+                        ), "Tag already used: {} ! Logging requires unique tags!".format(
+                            self.tag
+                        )
+                    else:
+                        self.tm_logger.memory[self.epoch][self.client][
+                            self.tag + "_m"
+                        ] = max(ret) - min(ret)
+                        self.tm_logger.timings[self.epoch][self.client][
+                            self.tag + "_t"
+                        ] = res
+                else:
+                    self.tm_logger.memory[self.epoch][self.client] = {}
+                    self.tm_logger.timings[self.epoch][self.client] = {}
+                    self.tm_logger.memory[self.epoch][self.client][
+                        self.tag + "_m"
+                    ] = max(ret) - min(ret)
+                    self.tm_logger.timings[self.epoch][self.client][
+                        self.tag + "_t"
+                    ] = res
+            else:
+                self.tm_logger.memory[self.epoch] = {}
+                self.tm_logger.timings[self.epoch] = {}
+                self.tm_logger.memory[self.epoch][self.client] = {}
+                self.tm_logger.timings[self.epoch][self.client] = {}
+                self.tm_logger.memory[self.epoch][self.client][self.tag + "_m"] = max(
+                    ret
+                ) - min(ret)
+                self.tm_logger.timings[self.epoch][self.client][self.tag + "_t"] = res
+
+        # pass
diff --git a/training/training_scripts/NN_Adaptive_Filtering/run/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/run/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/training_scripts/NN_Adaptive_Filtering/segment_on_off.py b/training/training_scripts/NN_Adaptive_Filtering/segment_on_off.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9e7a496cd91bc1ecb0526c0f6aa990e8d9969f1
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/segment_on_off.py
@@ -0,0 +1,325 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import argparse
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from util.file_system import (
+    check_directory,
+    check_file,
+    create_directory,
+    list_dirs,
+    list_selected_dirs,
+    read_json_file,
+)
+from util.regex import get_data_from_encoder_log, ra_segment_rgx
+
+
+def get_encoding_data(root_dir: Path) -> pd.DataFrame:
+    seq_dirs = list()
+    sim_dirs = list_selected_dirs(root_dir, "vtm_ra")
+    for sim_dir in sim_dirs:
+        seq_dirs.extend(list_dirs(sim_dir))
+
+    sims_data = list()
+    for seq_dir in seq_dirs:
+        seq_name = seq_dir.name
+        if args.video_sequences is not None and seq_name not in args.video_sequences:
+            continue
+
+        fps = config[seq_name]["fps"]
+        for qp in args.qps:
+            enc_logs = list(seq_dir.glob(f"**/log_enc_{qp}*.txt"))
+            enc_logs = sorted(
+                enc_logs,
+                key=lambda s: int(ra_segment_rgx.search(str(s)).groups()[0]),
+            )
+            for enc_log in enc_logs:
+                part = int(ra_segment_rgx.search(enc_log.stem).group(1))
+                bitstream = seq_dir / f"{seq_dir.name}_{qp}_p{part}.266"
+
+                check_file(bitstream)
+                total_bits = 0
+
+                sum_psnr_y = 0
+                sum_psnr_u = 0
+                sum_psnr_v = 0
+
+                sum_mssim_y = 0
+                sum_mssim_u = 0
+                sum_mssim_v = 0
+
+                num_frames = 0
+
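+                # Accumulate per-frame bits, PSNR and MS-SSIM from the encoder log's POC lines.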
+                with open(enc_log, "r") as stream:
+                    poc_line = 0
+                    for line in stream:
+                        if line.startswith("POC"):
+                            poc_line += 1
+                            if poc_line == 1 and part > 0:
+                                continue
+                            (
+                                slice_type,
+                                slice_qp,
+                                bits,
+                                psnr_y,
+                                psnr_u,
+                                psnr_v,
+                                mssim_y,
+                                mssim_u,
+                                mssim_v,
+                            ) = get_data_from_encoder_log(line)
+
+                            sum_psnr_y += psnr_y
+                            sum_psnr_u += psnr_u
+                            sum_psnr_v += psnr_v
+
+                            sum_mssim_y += mssim_y
+                            sum_mssim_u += mssim_u
+                            sum_mssim_v += mssim_v
+
+                            total_bits += bits
+
+                            num_frames += 1
+
+                sims_data.append(
+                    [
+                        seq_name,
+                        qp,
+                        part,
+                        len(enc_logs),
+                        num_frames,
+                        fps,
+                        total_bits,
+                        sum_psnr_y,
+                        sum_psnr_u,
+                        sum_psnr_v,
+                        sum_mssim_y,
+                        sum_mssim_u,
+                        sum_mssim_v,
+                        bitstream,
+                    ]
+                )
+    df = pd.DataFrame(
+        sims_data,
+        columns=[
+            "sequence",
+            "qp",
+            "part",
+            "num_parts",
+            "num_frames",
+            "fps",
+            "bits",
+            "sum_psnr_y",
+            "sum_psnr_u",
+            "sum_psnr_v",
+            "sum_mssim_y",
+            "sum_mssim_u",
+            "sum_mssim_v",
+            "bitstream",
+        ],
+    )
+
+    df = df.set_index(["sequence", "qp", "part"])
+    return df
+
+
+def create_parcat_script(df: pd.DataFrame) -> None:
+    with open(output_dir / "run.sh", "w") as main_stream:
+        main_stream.write("#!/bin/bash\n\n")
+        main_stream.write(f"PARCAT={parcat_bin}\n\n")
+        tmp = df[["num_parts", "bitstream"]]
+        cmd = list()
+        for it, (index, row) in enumerate(tmp.iterrows()):
+            curr_seq, curr_qp, curr_part = index
+
+            curr_part = int(curr_part)
+            bitstream = row["bitstream"]
+            num_parts = row["num_parts"]
+
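+            # A new parcat command is started at the first RA segment of each (sequence, QP) pair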
+            if curr_part == 0:
+                cmd = ["$PARCAT"]
+
+            # Adding best RA segment bitstream
+            cmd.append(bitstream)
+
+            if curr_part + 1 == num_parts:
+                # Adding merge output bitstream to main script
+                output_bitstream = output_dir / f"{curr_seq}_{curr_qp}.266"
+                cmd.append(output_bitstream)
+                main_stream.write(" ".join(str(s) for s in cmd))
+                main_stream.write("\n\n")
+
+                # Script to decode single merged bitstream
+                seq_script = output_dir / f"{curr_seq}_{curr_qp}.sh"
+                with open(seq_script, "w") as sequence_stream:
+                    log_file = output_dir / f"{curr_seq}_{curr_qp}_log_dec.txt"
+                    sequence_stream.write("#!/bin/bash\n\n")
+                    sequence_stream.write(f"DECODER={decoder_bin}\n\n")
+
+                    output_stem = output_dir / f"{curr_seq}_{curr_qp}"
+                    sequence_stream.write(
+                        f"$DECODER -b {output_bitstream} -o /dev/null --PrefixAbsolutePathsToGraphsOutput={intra_models_dir} --NnlfModelName={lop2_model} --NnfuOutputFileStem={output_stem} > {log_file}"
+                    )
+
+                # Adding cluster params and calling script to decode single merged bitstream
+                main_stream.write(
+                    f"qsub -q batch.q -terse -cwd -S /bin/bash -V -o {log_file} -j y -N {curr_seq}_{curr_qp} -l vf=8000M {seq_script}\n\n"
+                )
+
+
+def segment_on_off() -> None:
+    pass1_df = get_encoding_data(pass1_dir)
+    pass2_df = get_encoding_data(pass2_dir)
+
+    df = pass1_df.join(pass2_df, lsuffix="_pass1", rsuffix="_pass2")
+
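+    # Keep the pass-1 segment whenever pass 2 improves average luma PSNR by less than 0.02 dB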
+    df["pass1"] = (
+        df["sum_psnr_y_pass2"] / df["num_frames_pass2"]
+        - df["sum_psnr_y_pass1"] / df["num_frames_pass1"]
+        < 0.02
+    )
+
+    best = df.filter(["sequence", "qp", "part"], axis=1)
+    metrics = [
+        "bits",
+        "sum_psnr_y",
+        "sum_psnr_u",
+        "sum_psnr_v",
+        "sum_mssim_y",
+        "sum_mssim_u",
+        "sum_mssim_v",
+        "bitstream",
+        "num_parts",
+    ]
+    for metric in metrics:
+        best[metric] = np.where(
+            df["pass1"], df[f"{metric}_pass1"], df[f"{metric}_pass2"]
+        )
+
+    create_parcat_script(best)
+
+    best["fps"] = df["fps_pass1"]
+    best["num_frames"] = df["num_frames_pass1"]
+    best = best.drop(columns=["num_parts", "bitstream"])
+
+    best = best.groupby(["sequence", "qp", "fps"]).sum()
+    best = best.reset_index()
+
+    metrics = metrics[:-2]
+    for metric in metrics:
+        if "bits" in metric:
+            best["bitrate"] = best.get("bits").astype("float") * (
+                best.get("fps").astype("float")
+                / 1000
+                / best.get("num_frames").astype("float")
+            )
+        else:
+            tokens = metric.split("_")
+            best[f"avg_{tokens[1]}_{tokens[2]}"] = best.get(metric) / best.get(
+                "num_frames"
+            ).astype("float")
+
+    best["seq_qp"] = best["sequence"] + "_" + best["qp"]
+    best = best.set_index("seq_qp").sort_values(
+        by=["sequence", "qp"], key=lambda x: x.map(custom_order)
+    )
+    best = best.drop(columns=metrics).drop(
+        columns=["sequence", "qp", "fps", "num_frames"]
+    )
+    best.to_excel(output_dir / "metrics.xlsx")
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-vs",
+        "--video_sequences",
+        nargs="+",
+        type=str,
+        help="Sequence name/tag",
+    )
+    parser.add_argument("--qps", nargs="+", type=str, required=True, help="NNVC QP")
+    parser.add_argument("--pass1", type=str, required=True, help="First pass directory")
+    parser.add_argument(
+        "--pass2", type=str, required=True, help="Second pass directory"
+    )
+    parser.add_argument(
+        "--parcat_bin", type=str, required=True, help="Path to parcat binary"
+    )
+    parser.add_argument(
+        "--decoder_bin", type=str, required=True, help="Path to decoder binary"
+    )
+    parser.add_argument(
+        "--intra_models_dir",
+        type=str,
+        required=True,
+        help="Path to NN intra pred models dir",
+    )
+    parser.add_argument(
+        "--lop2_model", type=str, required=True, help="Path to LOP2 model"
+    )
+    parser.add_argument(
+        "--output_dir", type=str, required=True, help="Output directory"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    # If video_sequences arg is not given, all sequences in the directories are processed
+    # pass1 and pass2 point to the main simulation directory (the one that contains the
+    # directories vtm_ra* or vtm_rasplitall*)
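+    #
+    # Example invocation (script name and paths are placeholders):
+    #   python <this_script>.py --qps 22 27 --pass1 /sims/pass1 --pass2 /sims/pass2 \
+    #       --parcat_bin ./parcat --decoder_bin ./DecoderAppStatic \
+    #       --intra_models_dir ./models/intra --lop2_model ./models/lop2.sadl --output_dir ./merged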
+    args = parse_arguments()
+
+    pass1_dir = check_directory(args.pass1)
+    pass2_dir = check_directory(args.pass2)
+
+    parcat_bin = check_file(args.parcat_bin)
+    decoder_bin = check_file(args.decoder_bin)
+
+    intra_models_dir = check_directory(args.intra_models_dir)
+    lop2_model = check_file(args.lop2_model)
+
+    output_dir = create_directory(args.output_dir)
+
+    config = read_json_file("resources/datasets/jvet.json")
+
+    custom_order = {}
+    for idx, seq in enumerate(config.keys()):
+        custom_order[seq] = idx + 1
+
+    segment_on_off()
diff --git a/training/training_scripts/NN_Adaptive_Filtering/trainer/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/training_scripts/NN_Adaptive_Filtering/trainer/nn_filter.py b/training/training_scripts/NN_Adaptive_Filtering/trainer/nn_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f2455bd721a8214b65c54bc848a38a88a34fc7
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/trainer/nn_filter.py
@@ -0,0 +1,429 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from collections import OrderedDict
+from typing import Dict, Tuple
+
+import numpy as np
+import torch
+from torch import Tensor, nn
+
+import models.model_lop2_with_multiplier as model_with_multiplier
+from models import load_torch_model
+from util import DEVICE, Colour
+from util.file_system import check_file
+from util.logger import Logger
+from util.metrics import LOSS_OP, compute_loss_psnr, zero_out_negative
+
+
+class NnFilter:
+    def __init__(
+        self,
+        model_config: Dict,
+        lr: float,
+        lr_patience: int,
+        lr_gamma: float,
+        weight_decay: float,
+        block_size: int,
+        border_size: int,
+        base_nn_params: str,
+        nn_params: str,
+        optimiser_params: str,
+        scheduler_params: str,
+        loss_op: str,
+        stop_patience: int,
+        output_dir: str,
+        select_layer: bool,
+    ):
+        torch.manual_seed(1234)
+        np.random.seed(1234)
+
+        self._select_layer = select_layer
+        self._y_block_size = block_size
+        self._uv_block_size = block_size // 2
+        self._y_border_size = border_size
+        self._uv_border_size = border_size // 2
+
+        self._model = nn.DataParallel(
+            load_torch_model(model_config, model_with_multiplier)
+        ).to(DEVICE)
+        self._base_model = nn.DataParallel(
+            load_torch_model(model_config, model_with_multiplier)
+        ).to(DEVICE)
+
+        self._logger = Logger(output_dir, stop_patience)
+        self._trainable_parameter_names = self._get_trainable_parameter_names(nn_params)
+        self._trainable_parameters = self._initialise_training()
+
+        self._optimiser = torch.optim.Adam(
+            self._trainable_parameters, lr=lr, weight_decay=weight_decay
+        )
+
+        self._lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            self._optimiser, patience=lr_patience, verbose=True, factor=lr_gamma
+        )
+        self._loss_op = LOSS_OP[loss_op]
+
+        self._load_models(base_nn_params, nn_params, optimiser_params, scheduler_params)
+
+    def _get_trainable_parameter_names(self, nn_params):
+        # try to load from txt file, stored after online selection
+        param_names = self._logger.load_param_names()
+
+        # if no parameter names were stored before, select the multipliers whose loaded values differ from their initial value (1.0)
+        if len(param_names) == 0 and nn_params:
+            check_file(nn_params)
+            parameters = torch.load(nn_params)
+            for name in parameters:
+                if ".multiplier" in name and not (parameters[name] == 1.0).all():
+                    param_names.append(name)
+
+        # if nn_params is None or no multipliers were overfitted, return the names of all multipliers
+        if len(param_names) == 0:
+            for name, _ in self._model.module.named_parameters():
+                if ".multiplier" in name:
+                    param_names.append(name)
+
+        return param_names
+
+    def _initialise_training(self):
+        trainable_parameters = []
+        for name, param in self._model.module.named_parameters():
+            if name not in self._trainable_parameter_names:
+                param.requires_grad = False
+            else:
+                trainable_parameters.append(param)
+
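+        # The base model stays frozen in eval mode; it only provides reference predictions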
+        for name, param in self._base_model.named_parameters():
+            param.requires_grad = False
+        self._base_model.eval()
+        print("num of trainable terms", len(trainable_parameters))
+        return trainable_parameters
+
+    def _load_models(
+        self,
+        base_nn_params: str,
+        nn_params: str,
+        optimiser_params: str,
+        scheduler_params: str,
+    ) -> None:
+        check_file(base_nn_params)
+
+        src_params = torch.load(base_nn_params)
+        tmp_params = OrderedDict()
+        dst_params = OrderedDict()
+
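+        # Base checkpoint weights are mapped positionally onto the non-multiplier
+        # entries of the model state dict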
+        for name, value in self._base_model.module.state_dict().items():
+            if "multiplier" not in name:
+                tmp_params[name] = value
+
+        for src_param, dst_param in zip(src_params.keys(), tmp_params.keys()):
+            dst_params[dst_param] = src_params[src_param]
+
+        self._base_model.module.load_state_dict(dst_params, strict=False)
+        print(f"Base model params loaded from {base_nn_params}")
+
+        if nn_params is None:
+            self._model.module.load_state_dict(dst_params, strict=False)
+            print(f"Model params loaded from {base_nn_params}")
+        else:
+            check_file(nn_params)
+            self._model.module.load_state_dict(torch.load(nn_params))
+            print(f"Model params loaded from {nn_params}")
+
+        if optimiser_params is not None:
+            check_file(optimiser_params)
+            self._optimiser.load_state_dict(torch.load(optimiser_params))
+            print(f"Optimiser params loaded from {optimiser_params}")
+
+        if scheduler_params is not None:
+            check_file(scheduler_params)
+            self._lr_scheduler.load_state_dict(torch.load(scheduler_params))
+            print(f"Scheduler params loaded from {scheduler_params}")
+
+    def _compute_base_metrics(
+        self,
+        input_data: Dict,
+        label_data: Dict,
+        use_multiplier: Tensor,
+        cond_target_value: Tensor,
+        loss_op,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        """
+        Computes PSNR for the VTM reconstruction and the initial (base) model
+        :param input_data: 4D tensors that include the reconstruction, QP and boundary strength
+        :param label_data: Ground truth
+        :param use_multiplier: multiplier enable flag passed to the model
+        :param cond_target_value: conditioning target value passed to the model
+        :param loss_op: loss function
+        :return: VTM PSNR, VTM mask, base-model PSNR and base-model mask (sample-wise [B])
+        """
+        reco_Y = input_data["rec_before_dbf_Y"][
+            :,
+            :,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+        ]
+        reco_U = input_data["rec_before_dbf_U"][
+            :,
+            :,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+        ]
+        reco_V = input_data["rec_before_dbf_V"][
+            :,
+            :,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+        ]
+
+        prediction = self._base_model(input_data, use_multiplier, cond_target_value)
+        prediction_Y = prediction[0][
+            :,
+            :,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+            self._y_border_size : self._y_border_size + self._y_block_size,
+        ]
+        prediction_U, prediction_V = prediction[1][
+            :,
+            :,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            self._uv_border_size : self._uv_border_size + self._uv_block_size,
+        ].split([1, 1], dim=1)
+
+        _, vtm_psnr, vtm_mask = compute_loss_psnr(
+            label_data, reco_Y, reco_U, reco_V, loss_op
+        )
+        _, base_psnr, base_mask = compute_loss_psnr(
+            label_data, prediction_Y, prediction_U, prediction_V, loss_op
+        )
+
+        return vtm_psnr, vtm_mask, base_psnr, base_mask
+
+    def _select_parameters_to_be_overfitted(self, top_k_luma, top_k_chroma) -> None:
+        """
+        Selects the parameters to be over-fitted based on the l1-norm of the weight-update (normalised by the number of units
+        in the tensor). The highest top_K parameters out of the variable_lists are selected.
+
+        :param top_k: The number of parameters to be selected
+        """
+        energy_luma = []
+        energy_chroma = []
+        var_names_luma = []
+        var_names_chroma = []
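+        # Energy = mean absolute weight-update per tensor, gathered separately for the
+        # luma (y_path) and chroma (uv_path) branches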
+        for (org_name, org_param), (new_name, new_param) in zip(
+            self._base_model.named_parameters(), self._model.named_parameters()
+        ):
+            assert org_name == new_name, "The models have different variables"
+            num_ele = torch.numel(org_param)
+            if "y_path" in new_name:
+                energy_luma.append(
+                    torch.sum(torch.abs(new_param - org_param)) / num_ele
+                )
+                var_names_luma.append(org_name)
+            elif "uv_path" in new_name:
+                energy_chroma.append(
+                    torch.sum(torch.abs(new_param - org_param)) / num_ele
+                )
+                var_names_chroma.append(org_name)
+
+        energy_luma = torch.Tensor(energy_luma)
+        _, e_idx = torch.topk(energy_luma, k=top_k_luma)
+        var_names_luma = [var_names_luma[idx] for idx in e_idx]
+
+        energy_chroma = torch.Tensor(energy_chroma)
+        _, e_idx = torch.topk(energy_chroma, k=top_k_chroma)
+        var_names_chroma = [var_names_chroma[idx] for idx in e_idx]
+
+        var_names = var_names_luma + var_names_chroma
+
+        # reset model
+        self._model.module.load_state_dict(self._base_model.module.state_dict())
+        self._trainable_parameters = []
+        for name, param in self._model.named_parameters():
+            # multiplier
+            if name in var_names:
+                self._trainable_parameters.append(param)
+            else:
+                param.requires_grad = False
+
+        self._optimiser = torch.optim.Adam(
+            self._trainable_parameters,
+            lr=self._optimiser.param_groups[0]["lr"],
+            weight_decay=self._optimiser.param_groups[0]["weight_decay"],
+        )
+
+        # store parameter names
+        self._logger.save_param_names(var_names)
+
+    def train_one_epoch(self, epoch: int, dataloader) -> float:
+        avg_loss = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_psnr = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_delta_psnr_wrt_vtm = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_delta_psnr_wrt_base = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        num_elements = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        avg_positive_delta_psnr_wrt_vtm = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        num_elements_positive_psnr = torch.zeros(Colour.YCbCr + 1).to(DEVICE)
+        use_multiplier = torch.ones(1).to(DEVICE)
+        cond_target_value = torch.ones(1).to(DEVICE)
+
+        for batch_idx, (input_data, label_data) in enumerate(dataloader):
+            for k, v in input_data.items():
+                input_data[k] = v.to(DEVICE)
+            for k, v in label_data.items():
+                label_data[k] = v.to(DEVICE)
+
+            vtm_psnr, vtm_mask, base_psnr, base_mask = self._compute_base_metrics(
+                input_data, label_data, use_multiplier, cond_target_value, self._loss_op
+            )
+
+            self._optimiser.zero_grad()
+            prediction = self._model(input_data, use_multiplier, cond_target_value)
+            pred_Y = prediction[0][
+                :,
+                :,
+                self._y_border_size : self._y_border_size + self._y_block_size,
+                self._y_border_size : self._y_border_size + self._y_block_size,
+            ]
+            pred_U, pred_V = prediction[1][
+                :,
+                :,
+                self._uv_border_size : self._uv_border_size + self._uv_block_size,
+                self._uv_border_size : self._uv_border_size + self._uv_block_size,
+            ].split([1, 1], dim=1)
+
+            pred_loss, pred_psnr, pred_mask = compute_loss_psnr(
+                label_data, pred_Y, pred_U, pred_V, self._loss_op
+            )
+            loss = torch.mean(pred_loss[Colour.YCbCr])
+            loss.backward()
+            self._optimiser.step()
+
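+            # Combine the validity masks of the VTM, base-model and overfitted predictions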
+            mask = vtm_mask * base_mask * pred_mask
+
+            vtm_psnr *= mask
+            base_psnr *= mask
+            pred_psnr *= mask
+            delta_psnr_wrt_vtm = pred_psnr - vtm_psnr
+            delta_psnr_wrt_base = pred_psnr - base_psnr
+
+            positive_delta_psnr_wrt_vtm = [0.0] * (Colour.YCbCr + 1)
+            pp_mask = [0.0] * (Colour.YCbCr + 1)
+
+            for color in range(Colour.YCbCr + 1):
+                positive_delta_psnr_wrt_vtm[color], pp_mask[color] = zero_out_negative(
+                    delta_psnr_wrt_vtm[color]
+                )
+
+            positive_delta_psnr_wrt_vtm = torch.stack(positive_delta_psnr_wrt_vtm)
+            pp_mask = torch.stack(pp_mask)
+
+            num_elements_positive_psnr += torch.sum(pp_mask, dim=1)
+            num_elements += torch.sum(mask, dim=1)
+            avg_loss += torch.sum(pred_loss, dim=1)
+            avg_psnr += torch.sum(pred_psnr, dim=1)
+            avg_delta_psnr_wrt_vtm += torch.sum(delta_psnr_wrt_vtm, dim=1)
+            avg_delta_psnr_wrt_base += torch.sum(delta_psnr_wrt_base, dim=1)
+            avg_positive_delta_psnr_wrt_vtm += torch.sum(
+                positive_delta_psnr_wrt_vtm, dim=1
+            )
+
+        avg_loss /= len(dataloader) * dataloader.batch_size
+        avg_psnr /= num_elements
+        avg_delta_psnr_wrt_vtm /= num_elements
+        avg_delta_psnr_wrt_base /= num_elements
+        avg_positive_delta_psnr_wrt_vtm /= num_elements_positive_psnr
+        portion_positive_delta_psnr = num_elements_positive_psnr / num_elements
+
+        avg_loss = avg_loss.tolist()
+        avg_psnr = avg_psnr.tolist()
+        avg_delta_psnr_wrt_vtm = avg_delta_psnr_wrt_vtm.tolist()
+        avg_delta_psnr_wrt_base = avg_delta_psnr_wrt_base.tolist()
+        avg_positive_delta_psnr_wrt_vtm = avg_positive_delta_psnr_wrt_vtm.tolist()
+        portion_positive_delta_psnr = portion_positive_delta_psnr.tolist()
+
+        self._logger.log_metrics(
+            epoch,
+            0,
+            avg_loss,
+            avg_psnr,
+            avg_delta_psnr_wrt_vtm,
+            avg_delta_psnr_wrt_base,
+            avg_positive_delta_psnr_wrt_vtm,
+            portion_positive_delta_psnr,
+        )
+        return avg_loss[Colour.YCbCr]
+
+    def train_loop(
+        self,
+        start_epoch: int,
+        max_epochs: int,
+        train_dataloader,
+        num_trainable_parameters: int,
+    ):
+        print(f"Starting at epoch {start_epoch}...")
+        self._model.train(True)
+        select_epoch = max(1, int(max_epochs * 0.1)) if self._select_layer else -1
+
+        for epoch in range(start_epoch, max_epochs):
+            if self._select_layer and epoch == select_epoch:
+                num_trainable_parameters_luma = max(
+                    1, int(0.8 * num_trainable_parameters)
+                )
+                num_trainable_parameters_chroma = (
+                    num_trainable_parameters - num_trainable_parameters_luma
+                )
+                self._select_parameters_to_be_overfitted(
+                    num_trainable_parameters_luma, num_trainable_parameters_chroma
+                )
+                self._logger.reset()
+                print(
+                    f"After selection, the amount of trainable variables is : {len(self._trainable_parameters)}"
+                )
+
+            epoch_loss = self.train_one_epoch(epoch, train_dataloader)
+            self._lr_scheduler.step(epoch_loss)
+
+            stop_train = self._logger.save_model(
+                epoch, epoch_loss, self._model, self._optimiser, self._lr_scheduler, 30
+            )
+
+            if stop_train:
+                print("No improvements... Stopping the training now")
+                break
+
+            if self._optimiser.param_groups[0]["lr"] < 1e-8:
+                print("Learning rate too small... Stopping the training now")
+                break
+
+        self._logger.save_training_time()
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/__init__.py b/training/training_scripts/NN_Adaptive_Filtering/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b214a3c1ef57a31e0ec871f36da8bf71759695
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/__init__.py
@@ -0,0 +1,63 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from typing import NamedTuple
+
+import torch
+
+
+class Colour:
+    Y = 0
+    Cb = 1
+    Cr = 2
+    YCbCr = 3
+
+
+class Position(NamedTuple):
+    x: int
+    y: int
+
+
+COLOUR_LABEL = {
+    Colour.Y: "Y",
+    Colour.Cb: "Cb",
+    Colour.Cr: "Cr",
+    Colour.YCbCr: "YCbCr",
+}
+
+
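+# Component weights (Y:Cb:Cr = 12:1:1) that sum to one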
+COLOUR_WEIGHT = {Colour.Y: 12.0 / 14.0, Colour.Cb: 1.0 / 14.0, Colour.Cr: 1.0 / 14.0}
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+DEVICE_CPU = torch.device("cpu")
+TIME_FORMAT: str = "%Y%m%d-%H%M%S"
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/dataset_bin.py b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_bin.py
new file mode 100644
index 0000000000000000000000000000000000000000..57876746fde97ebe2a118fbd8bc99c01f35244e0
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_bin.py
@@ -0,0 +1,153 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data.dataset import Dataset
+
+
+class NnFilterBinDataset(Dataset):
+    def __init__(
+        self,
+        data_path: str,
+        block_size: int,
+        pad_size: int,
+        input_keys: List[str],
+        label_keys: List[str],
+        num_patches: int,
+        data_type: np.dtype,
+    ):
+        """
+        Constructor
+        :param data_path: Directory that contains the decoded video data
+        :param block_size: Block size, without padding
+        :param pad_size: Padding size (# of samples taken for a luma block at each side)
+        :param input_keys: Keys of input data dictionary
+        :param label_keys: Keys of label data dictionary
+        :param num_patches: Number of patches in the dataset
+        :param data_type: Type of data in binary data file
+        """
+
+        self._data_path = data_path
+        self._block_size_y = block_size
+        self._block_size_uv = block_size // 2
+        self._pad_size = pad_size
+        self._num_patches = num_patches
+        self._data_type = data_type
+        self._input_keys = input_keys
+        self._label_keys = label_keys
+        self._patch_size_y = self._block_size_y + self._pad_size * 2
+        self._patch_size_uv = self._patch_size_y // 2
+        num_channels_input_Y = 4
+        num_channels_input_uv = 6
+        num_channels_label_Y = 2
+        num_channels_label_uv = 4
+        self._num_scalar = 2
+        self._input_block_volume = (
+            num_channels_input_Y * self._patch_size_y**2
+            + num_channels_input_uv * self._patch_size_uv**2
+        )
+        self._label_block_volume = (
+            num_channels_label_Y * self._block_size_y**2
+            + num_channels_label_uv * self._block_size_uv**2
+        )
+        self._shape_input_y = (1, self._patch_size_y, self._patch_size_y)
+        self._shape_input_uv = (1, self._patch_size_uv, self._patch_size_uv)
+        self._shape_label_y = (1, self._block_size_y, self._block_size_y)
+        self._shape_label_uv = (1, self._block_size_uv, self._block_size_uv)
+
+    def __len__(self) -> int:
+        """
+        Dataset length
+        """
+        return self._num_patches
+
+    def __getitem__(self, idx: int) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
+        """
+        Returns a single item (input and label)
+        :param idx: Index
+        :return: input and label data
+        """
+
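+        # Each record stores the input planes, then the label planes, then two
+        # scalar QP values (base and slice)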
+        input_offset = (
+            self._input_block_volume * idx * np.dtype(self._data_type).itemsize
+        )
+        label_offset = (
+            self._label_block_volume * idx * np.dtype(self._data_type).itemsize
+        )
+        qp_offset = self._num_scalar * idx * np.dtype(self._data_type).itemsize
+
+        with open(self._data_path, "rb") as f:  # binary mode required by np.fromfile with offset
+            data = np.fromfile(
+                f,
+                dtype=self._data_type,
+                count=self._input_block_volume
+                + self._label_block_volume
+                + self._num_scalar,
+                offset=input_offset + label_offset + qp_offset,
+            )
+
+        input_data = {}
+        label_data = {}
+        for k in self._input_keys:
+            if "Y" in k:
+                input_data[k] = data[: self._patch_size_y**2]
+                input_data[k] = input_data[k].reshape(self._shape_input_y)
+                data = np.delete(data, range(self._patch_size_y**2))
+            else:
+                input_data[k] = data[: self._patch_size_uv**2]
+                input_data[k] = input_data[k].reshape(self._shape_input_uv)
+                data = np.delete(data, range(self._patch_size_uv**2))
+
+            input_data[k] = torch.from_numpy(input_data[k])
+
+        for k in self._label_keys:
+            if "Y" in k:
+                label_data[k] = data[: self._block_size_y**2]
+                label_data[k] = label_data[k].reshape(self._shape_label_y)
+                data = np.delete(data, range(self._block_size_y**2))
+            else:
+                label_data[k] = data[: self._block_size_uv**2]
+                label_data[k] = label_data[k].reshape(self._shape_label_uv)
+                data = np.delete(data, range(self._block_size_uv**2))
+
+            label_data[k] = torch.from_numpy(label_data[k])
+
+        input_data["qp_base"] = data[-2] * torch.ones(self._shape_input_y)
+        input_data["qp_slice"] = data[-1] * torch.ones(self._shape_input_y)
+
+        return input_data, label_data
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/dataset_yuv.py b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_yuv.py
new file mode 100644
index 0000000000000000000000000000000000000000..97c03c6cb4b02177db42308fe71d559c26d0e4c7
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/dataset_yuv.py
@@ -0,0 +1,533 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from typing import Dict, List, NamedTuple, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.utils.data.dataset import Dataset
+from torchvision.transforms.functional import pad
+
+from util import Position
+from util.file_system import (
+    check_directory,
+    check_file,
+    list_dirs,
+    list_selected_dirs,
+    read_json_file,
+)
+from util.image_ops import pad_image, read_yuv_frame
+
+
+class PatchInfo(NamedTuple):
+    poc: int
+    seq_width: int
+    seq_height: int
+    reco_yuv_path: str
+    boun_yuv_path: str
+    pred_yuv_path: str
+    ipb_yuv_path: str
+    orig_yuv_path: str
+    slice_qp: int
+    base_qp: int
+    temporal_layer: int
+    slice_type: str
+    bit_depth: int
+    pos: Position
+
+
+class NnFilterYuvDataset(Dataset):
+    def __init__(
+        self,
+        deco_dir: str,
+        orig_dir: str,
+        prop_file: str,
+        seq_name: str,
+        seq_qps: Union[Tuple[str], List[str]],
+        bit_depth: int,
+        block_size: int,
+        pad_size: int,
+        slice_type: str,
+        min_slice_qp: int,
+        max_slice_qp: int,
+        random_patches: bool,
+        num_patches: int,
+        min_poc: int,
+        max_poc: int,
+    ):
+        """
+        Constructor
+        :param deco_dir: Directory that contains the decoded video data
+        :param orig_dir: Directory that contains the original video data
+        :param prop_file: Properties file for the training data
+        :param seq_name: Sequence name
+        :param seq_qps: list of sequence QPs
+        :param bit_depth: Bit-depth for the data
+        :param block_size: Block size, without padding
+        :param pad_size: Padding size (# of samples taken for a luma block at each side)
+        :param slice_type: Slice type (i.e. I, B, *)
+        :param min_slice_qp: Minimum frame QP (inclusive)
+        :param max_slice_qp: Maximum frame QP (inclusive)
+        :param random_patches: Extract patches at random position
+        :param num_patches: Number of patches to extract from each frame
+        :param min_poc: Minimum index of picture (inclusive)
+        :param max_poc: Maximum index of picture (exclusive)
+        """
+        self._bit_depth = bit_depth
+        self._block_size = block_size
+        self._pad_size = pad_size
+        self._random_patches = random_patches
+        self._num_patches = num_patches
+        prop_file = check_file(prop_file)
+        self._seq_name = seq_name
+        self._base_qp = None
+        self._seq_info = read_json_file(prop_file)
+        self._create_patch_lists(
+            deco_dir,
+            orig_dir,
+            seq_name,
+            seq_qps,
+            slice_type,
+            min_slice_qp,
+            max_slice_qp,
+            min_poc,
+            max_poc,
+        )
+
+    def _create_patch_lists(
+        self,
+        root_deco_dir: str,
+        root_orig_dir: str,
+        global_seq_name: str,
+        seq_qps: Union[Tuple[str], List[str]],
+        slice_type: str,
+        min_slice_qp: int,
+        max_slice_qp: int,
+        min_poc: int,
+        max_poc: int,
+    ) -> None:
+        """
+        Gets the filenames of the images to be processed (original, reconstruction) and the frame QP
+        :param root_deco_dir: Directory that contains the decoded video data
+        :param root_orig_dir: Directory that contains the original video data
+        :param global_seq_name: Sequence name
+        :param seq_qps: list of sequence QPs
+        :param slice_type: Slice type (i.e. I, P, B, *)
+        :param min_slice_qp: Minimum frame QP (inclusive)
+        :param max_slice_qp: Maximum frame QP (inclusive)
+        :param min_poc: Minimum index of picture (inclusive)
+        :param max_poc: Maximum index of picture (exclusive)
+        """
+        root_deco_dir = check_directory(root_deco_dir)
+        root_orig_dir = check_directory(root_orig_dir)
+
+        if global_seq_name is None:
+            deco_seq_dirs = list_dirs(root_deco_dir)
+        else:
+            deco_seq_dirs = [root_deco_dir / global_seq_name]
+
+        self._patches = []
+
+        for deco_seq_dir in deco_seq_dirs:
+            seq_name = deco_seq_dir.name
+            seq_width = self._seq_info[seq_name]["width"]
+            seq_height = self._seq_info[seq_name]["height"]
+            num_frames = self._seq_info[seq_name]["frames"]
+            bit_depth = self._seq_info[seq_name]["bit-depth"]
+
+            if len(seq_qps) > 0:
+                qp_dirs = list_selected_dirs(deco_seq_dir, seq_qps)
+            else:
+                qp_dirs = list_dirs(deco_seq_dir)
+
+            for qp_dir in qp_dirs:
+                frames_info = read_json_file(qp_dir / "coding_info.json")
+                base_qp = frames_info["base_qp"]
+                reco_yuv_path = qp_dir / f"{seq_name}_rec_before_dbf.yuv"
+                boun_yuv_path = qp_dir / f"{seq_name}_bs.yuv"
+                pred_yuv_path = qp_dir / f"{seq_name}_pred.yuv"
+                ipb_yuv_path = qp_dir / f"{seq_name}_bpm.yuv"
+                orig_yuv_path = list((root_orig_dir / seq_name).glob("*.yuv"))[0]
+
+                for curr_poc in range(min_poc, min(num_frames, max_poc)):
+                    curr_slice_type = frames_info["POC"][str(curr_poc)]["slice_type"]
+                    curr_slice_qp = frames_info["POC"][str(curr_poc)]["slice_qp"]
+                    curr_tid = frames_info["POC"][str(curr_poc)]["temporal_layer"]
+
+                    if (
+                        (slice_type != "*" and slice_type != curr_slice_type)
+                        or (curr_slice_qp < min_slice_qp)
+                        or (curr_slice_qp > max_slice_qp)
+                    ):
+                        continue
+
+                    # compute position
+                    if self._random_patches:
+                        positions = self._compute_random_positions(
+                            seq_width, seq_height, self._num_patches
+                        )
+                    else:
+                        positions = self._compute_predetermined_positions(
+                            seq_width, seq_height
+                        )
+
+                    for pos in positions:
+                        patch = PatchInfo(
+                            curr_poc,
+                            seq_width,
+                            seq_height,
+                            reco_yuv_path,
+                            boun_yuv_path,
+                            pred_yuv_path,
+                            ipb_yuv_path,
+                            orig_yuv_path,
+                            curr_slice_qp,
+                            base_qp,
+                            curr_tid,
+                            curr_slice_type,
+                            bit_depth,
+                            pos,
+                        )
+                        self._patches.append(patch)
+
+    def _compute_random_positions(
+        self, width: int, height: int, num_patches: int
+    ) -> List[Position]:
+        luma_bs = self._block_size
+        max_x = (
+            luma_bs * (width // luma_bs)
+            if width % luma_bs > 0
+            else luma_bs * (width // luma_bs - 1)
+        )
+        max_y = (
+            luma_bs * (height // luma_bs)
+            if height % luma_bs > 0
+            else luma_bs * (height // luma_bs - 1)
+        )
+
+        x = np.random.randint(self._pad_size, max_x + self._pad_size + 1, num_patches)
+        y = np.random.randint(self._pad_size, max_y + self._pad_size + 1, num_patches)
+
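+        # Round to even luma coordinates so the co-located 4:2:0 chroma block
+        # starts on an integer sample position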
+        x = (x // 2) * 2
+        y = (y // 2) * 2
+        positions = []
+        for idx in range(num_patches):
+            positions.append(Position(x[idx], y[idx]))
+        return positions
+
+    def _compute_predetermined_positions(
+        self, width: int, height: int
+    ) -> List[Position]:
+        luma_bs = self._block_size
+        max_x = (
+            luma_bs * (width // luma_bs)
+            if width % luma_bs > 0
+            else luma_bs * (width // luma_bs - 1)
+        )
+        max_y = (
+            luma_bs * (height // luma_bs)
+            if height % luma_bs > 0
+            else luma_bs * (height // luma_bs - 1)
+        )
+        positions = []
+        for x in range(self._pad_size, max_x + self._pad_size + 1, luma_bs):
+            for y in range(self._pad_size, max_y + self._pad_size + 1, luma_bs):
+                positions.append(Position(x, y))
+        return positions
+
+    def _create_input_data(
+        self,
+        reco_y: Tensor,
+        reco_u: Tensor,
+        reco_v: Tensor,
+        boun_y: Tensor,
+        boun_u: Tensor,
+        boun_v: Tensor,
+        pred_y: Tensor,
+        pred_u: Tensor,
+        pred_v: Tensor,
+        ipb_y: Tensor,
+        ipb_u: Tensor,
+        ipb_v: Tensor,
+        slice_qp: Tensor,
+        base_qp: Tensor,
+        pos_x: int,
+        pos_y: int,
+    ) -> Dict[str, Tensor]:
+        """
+        Creates input data
+        :param reco_y: luma reconstruction image
+        :param reco_u: cb reconstruction image
+        :param reco_v: cr reconstruction image
+        :param boun_y: luma boundary strength image
+        :param boun_u: cb boundary strength image
+        :param boun_v: cr boundary strength image
+        :param pred_y: luma prediction image
+        :param pred_u: cb prediction image
+        :param pred_v: cr prediction image
+        :param ipb_y: luma block prediction mode image
+        :param ipb_u: cb block prediction mode image (unused)
+        :param ipb_v: cr block prediction mode image (unused)
+        :param slice_qp: normalised slice QP
+        :param base_qp: normalised base QP
+        :param pos_x: left corner position
+        :param pos_y: top corner position
+        :return: Dictionary of input patches
+        """
+        if self._pad_size > 0:
+            reco_y = pad_image(reco_y, self._block_size, self._pad_size)
+            reco_u = pad_image(reco_u, self._block_size // 2, self._pad_size // 2)
+            reco_v = pad_image(reco_v, self._block_size // 2, self._pad_size // 2)
+
+            boun_y = pad_image(boun_y, self._block_size, self._pad_size)
+            boun_u = pad_image(boun_u, self._block_size // 2, self._pad_size // 2)
+            boun_v = pad_image(boun_v, self._block_size // 2, self._pad_size // 2)
+
+            pred_y = pad_image(pred_y, self._block_size, self._pad_size)
+            pred_u = pad_image(pred_u, self._block_size // 2, self._pad_size // 2)
+            pred_v = pad_image(pred_v, self._block_size // 2, self._pad_size // 2)
+
+            ipb_y = pad_image(ipb_y, self._block_size, self._pad_size)
+
+        patch_size = self._block_size + self._pad_size * 2
+
+        actual_y_pos_x = pos_x - self._pad_size
+        actual_y_pos_y = pos_y - self._pad_size
+
+        actual_uv_pos_x = (pos_x - self._pad_size) // 2
+        actual_uv_pos_y = (pos_y - self._pad_size) // 2
+
+        input_data = {}
+        input_data["rec_before_dbf_Y"] = reco_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+        input_data["rec_before_dbf_U"] = reco_u[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["rec_before_dbf_V"] = reco_v[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["pred_Y"] = pred_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+        input_data["pred_U"] = pred_u[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["pred_V"] = pred_v[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["bs_Y"] = boun_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+        input_data["bs_U"] = boun_u[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["bs_V"] = boun_v[
+            :,
+            actual_uv_pos_y : actual_uv_pos_y + patch_size // 2,
+            actual_uv_pos_x : actual_uv_pos_x + patch_size // 2,
+        ]
+        input_data["qp_base"] = base_qp * torch.ones(
+            input_data["rec_before_dbf_Y"].shape
+        )
+        input_data["qp_slice"] = slice_qp * torch.ones(
+            input_data["rec_before_dbf_Y"].shape
+        )
+
+        input_data["ipb_Y"] = ipb_y[
+            :,
+            actual_y_pos_y : actual_y_pos_y + patch_size,
+            actual_y_pos_x : actual_y_pos_x + patch_size,
+        ]
+
+        return input_data
+
+    def _create_label_data(
+        self, y: Tensor, u: Tensor, v: Tensor, pos_x: int, pos_y: int
+    ) -> Dict[str, Tensor]:
+        """
+        Creates label data
+        :param y: luma image
+        :param u: cb image
+        :param v: cr image
+        :param pos_x: left corner position
+        :param pos_y: top corner position
+        :return: Dictionary of label patches and validity masks
+        """
+
+        # The mask marks original samples (1); padded samples are 0
+        mask = torch.ones(y.shape)
+
+        if self._pad_size > 0:
+            _, height, width = y.shape
+            mod = width % self._block_size
+            right_padding = self._block_size - mod if mod > 0 else 0
+            mod = height % self._block_size
+            bottom_padding = self._block_size - mod if mod > 0 else 0
+
+            mask = pad(mask, [0, 0, right_padding, bottom_padding])
+
+            y = pad(y, [0, 0, right_padding, bottom_padding])
+            u = pad(u, [0, 0, right_padding // 2, bottom_padding // 2])
+            v = pad(v, [0, 0, right_padding // 2, bottom_padding // 2])
+
+        actual_pos_x = pos_x - self._pad_size
+        actual_pos_y = pos_y - self._pad_size
+
+        label_data = {}
+        label_data["orig_Y"] = y[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size,
+            actual_pos_x : actual_pos_x + self._block_size,
+        ]
+        label_data["mask_Y"] = mask[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size,
+            actual_pos_x : actual_pos_x + self._block_size,
+        ]
+        actual_pos_x //= 2
+        actual_pos_y //= 2
+        label_data["orig_U"] = u[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size // 2,
+            actual_pos_x : actual_pos_x + self._block_size // 2,
+        ]
+        label_data["orig_V"] = v[
+            :,
+            actual_pos_y : actual_pos_y + self._block_size // 2,
+            actual_pos_x : actual_pos_x + self._block_size // 2,
+        ]
+        mask_UV = torch.nn.functional.pixel_unshuffle(label_data["mask_Y"], 2)
+
+        label_data["mask_U"] = mask_UV[0:1, :, :]
+        label_data["mask_V"] = mask_UV[0:1, :, :]
+
+        return label_data
+
+    def __len__(self) -> int:
+        """
+        Dataset length
+        """
+        return len(self._patches)
+
+    def __getitem__(
+        self, idx: int
+    ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], float, float]:
+        """
+        Returns a single item (input, label and normalised QPs)
+        :param idx: Index
+        :return: input data, label data, normalised base QP and slice QP
+        """
+        patch = self._patches[idx]
+
+        orig_y, orig_u, orig_v = read_yuv_frame(
+            patch.orig_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            patch.bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        pos_x, pos_y = patch.pos
+
+        reco_y, reco_u, reco_v = read_yuv_frame(
+            patch.reco_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        boun_y, boun_u, boun_v = read_yuv_frame(
+            patch.boun_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        pred_y, pred_u, pred_v = read_yuv_frame(
+            patch.pred_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            1024.0,
+            False,
+            patch.poc,
+        )
+
+        ipb_y, ipb_u, ipb_v = read_yuv_frame(
+            patch.ipb_yuv_path,
+            patch.seq_width,
+            patch.seq_height,
+            self._bit_depth,
+            2.0,
+            True,
+            patch.poc,
+        )
+
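+        # QPs are normalised by 64 before being attached to the input data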
+        slice_qp = float(patch.slice_qp) / 64.0
+        base_qp = float(patch.base_qp) / 64.0
+
+        input_data = self._create_input_data(
+            reco_y,
+            reco_u,
+            reco_v,
+            boun_y,
+            boun_u,
+            boun_v,
+            pred_y,
+            pred_u,
+            pred_v,
+            ipb_y,
+            ipb_u,
+            ipb_v,
+            slice_qp,
+            base_qp,
+            pos_x,
+            pos_y,
+        )
+        label_data = self._create_label_data(orig_y, orig_u, orig_v, pos_x, pos_y)
+
+        return input_data, label_data, base_qp, slice_qp
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/file_system.py b/training/training_scripts/NN_Adaptive_Filtering/util/file_system.py
new file mode 100644
index 0000000000000000000000000000000000000000..3323f2fec1927f5a2fb6a995a783996eadd1d689
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/file_system.py
@@ -0,0 +1,292 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import json
+import sys
+from pathlib import Path
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+
+
+def check_directory(input_path: Union[Path, str]) -> Path:
+    """
+    Checks whether the given path exists and corresponds to a directory
+    :param input_path: directory path
+    :return: Path obj
+    """
+    dir_path = Path(input_path) if isinstance(input_path, str) else input_path
+    if dir_path.exists() and dir_path.is_dir():
+        return dir_path
+
+    print(f"{input_path} is not a directory")
+    sys.exit(-1)
+
+
+def check_file(input_path: Union[Path, str]) -> Path:
+    """
+    Checks whether the given path exists and corresponds to a file
+    :param input_path: file path
+    :return: Path obj
+    """
+    file_path = Path(input_path) if isinstance(input_path, str) else input_path
+    if file_path.exists() and file_path.is_file():
+        return file_path
+
+    print(f"{input_path} is not a file")
+    sys.exit(-1)
+
+
+def list_dirs(input_dir: Path) -> List[Path]:
+    """
+    Lists the subdirectories of the given directory
+    :param input_dir: Input directory
+    :return: list of directories
+    """
+    return [f for f in input_dir.iterdir() if f.is_dir()]
+
+
+def create_directory(input_path: Union[Path, str]) -> Path:
+    """
+    Creates a directory (including any missing parents) if it does not already exist
+    :param input_path: Input path
+    :return: Path obj
+    """
+    dir_path = Path(input_path) if isinstance(input_path, str) else input_path
+    dir_path.mkdir(parents=True, exist_ok=True)
+    return dir_path
+
+
+def list_selected_dirs(
+    input_dir: Path, pattern: Union[str, Tuple[str], List[str]]
+) -> List[Path]:
+    """
+    Lists the subdirectories whose names contain the given text (string pattern) or exactly match one of the given names (tuple/list pattern)
+    :param input_dir: Input directory
+    :param pattern: Substring to search for, or a collection of exact directory names
+    :return: list of directories
+    """
+    if isinstance(pattern, str):
+        return [f for f in input_dir.iterdir() if f.is_dir() and pattern in f.name]
+
+    return [f for f in input_dir.iterdir() if f.is_dir() and f.name in pattern]
+
+
+def read_json_file(json_path: Union[Path, str]) -> Dict:
+    """
+    Reads JSON file
+    :param json_path: Absolute path to the JSON file
+    :return: Dictionary containing JSON file data
+    """
+    file_path = check_file(json_path)
+    with open(file_path, "r") as stream:
+        config = json.load(stream)
+    return config
+
+
+def write_json_file(content: Dict, json_path: Union[Path, str]) -> None:
+    """
+    Writes a dictionary to a JSON file
+    :param content: Dictionary to be saved
+    :param json_path: Absolute path to the JSON file
+    """
+    file_path = Path(json_path) if isinstance(json_path, str) else json_path
+    assert (
+        file_path.parent.exists() and file_path.parent.is_dir()
+    ), f"The parent directory {file_path.parent} does not exist"
+    with open(file_path, "w") as stream:
+        json.dump(content, stream, sort_keys=True, indent=4)
+
+
+def create_vtm_config_file(
+    cfg_file: Path, filename: Path, width: int, height: int, fps: int, num_frames: int
+) -> None:
+    """
+    Creates the sequence config file for VTM encoding
+    :param cfg_file: Output file name
+    :param filename: YUV file name
+    :param width: Width of the YUV
+    :param height: Height of the YUV
+    :param fps: Frame rate of the YUV
+    :param num_frames: Number of frames to be encoded
+    """
+    with open(cfg_file, "w") as stream:
+        stream.write(f"InputFile:           {filename}\n")
+        stream.write(f"SourceWidth:         {width}\n")
+        stream.write(f"SourceHeight:        {height}\n")
+        stream.write("InputBitDepth:       10\n")
+        stream.write("InputChromaFormat:   420\n")
+        stream.write(f"FrameRate:           {fps}\n")
+        stream.write("FrameSkip:           0\n")
+        stream.write(f"FramesToBeEncoded:   {num_frames}\n")
+        stream.write("Level:               5.1\n")
+
+
+def read_yuv_frame(
+    file_path: Path,
+    width: int,
+    height: int,
+    bit_depth: int,
+    frame_idx: int,
+    frame_skip: int = 0,
+    temporal_subsample_ratio: int = 1,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Reads YUV frame
+    :param file_path: Absolute file path
+    :param width: Luma width
+    :param height: Luma height
+    :param bit_depth: Bit-depth
+    :param frame_idx: Frame index
+    :param frame_skip: Number of frames to skip
+    :param temporal_subsample_ratio: Temporal subsample ratio (see VTM config)
+    :return: Y, U, V frames
+    """
+    word = 1 if bit_depth == 8 else 2
+    pix_type = np.uint8 if bit_depth == 8 else np.uint16
+
+    frame_pos = (frame_idx + frame_skip) * temporal_subsample_ratio
+    frame_size = width * height * word * 3 // 2
+
+    uv_width = width // 2
+    uv_height = height // 2
+
+    with open(file_path, "rb") as stream:
+        stream.seek(frame_pos * frame_size)
+
+        y = np.frombuffer(stream.read(width * height * word), pix_type)
+        u = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type)
+        v = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type)
+
+    y = np.reshape(y, (height, width))
+    u = np.reshape(u, (height // 2, width // 2))
+    v = np.reshape(v, (height // 2, width // 2))
+
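+    # 8-bit sources are scaled by 4 so that all samples use a common 10-bit range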
+    if bit_depth == 8:
+        pix_type = np.uint16
+        y = y.astype(pix_type) * 4
+        u = u.astype(pix_type) * 4
+        v = v.astype(pix_type) * 4
+
+    return y, u, v
+
+
+def read_yuv_patch(
+    file_path: Path,
+    width: int,
+    height: int,
+    bit_depth: int,
+    frame_idx: int,
+    luma_pos: int,
+    luma_patch_size: int,
+    uv_pos: int,
+    uv_patch_size: int,
+    frame_skip: int = 0,
+    temporal_subsample_ratio: int = 1,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Reads YUV patch
+    :param file_path: Absolute file path
+    :param width: Luma frame width
+    :param height: Luma frame height
+    :param bit_depth: Bit-depth
+    :param frame_idx: Frame index
+    :param luma_pos: Top-left corner position of the luma patch (object exposing .x and .y)
+    :param luma_patch_size: Luma patch size (one side)
+    :param uv_pos: Top-left corner position of the chroma patches (object exposing .x and .y)
+    :param uv_patch_size: Chroma patch size (one side)
+    :param frame_skip: Number of frames to skip
+    :param temporal_subsample_ratio: Temporal subsample ratio (see VTM config)
+    :return: Y, U, V patches
+    """
+    word = 1 if bit_depth == 8 else 2
+    pix_type = np.uint8 if bit_depth == 8 else np.uint16
+
+    y_frame_size = width * height
+    uv_frame_size = width // 2 * height // 2
+    full_frame_size = y_frame_size + 2 * uv_frame_size
+
+    frame_pos = (frame_idx + frame_skip) * temporal_subsample_ratio
+
+    uv_width = width // 2
+
+    y = np.zeros((luma_patch_size, luma_patch_size), dtype=np.uint16)
+    u = np.zeros((uv_patch_size, uv_patch_size), dtype=np.uint16)
+    v = np.zeros((uv_patch_size, uv_patch_size), dtype=np.uint16)
+
+    with open(file_path, "rb") as stream:
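+        # byte offset of the patch's first luma sample: whole frames, then rows, then columns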
+        patch_start = word * (
+            frame_pos * full_frame_size + luma_pos.y * width + luma_pos.x
+        )
+        stream.seek(patch_start, 0)
+
+        for i in range(luma_patch_size):
+            row = np.frombuffer(stream.read(luma_patch_size * word), pix_type)
+            y[i, :] = row.astype(np.uint16) * 4 if bit_depth == 8 else row
+            stream.seek(word * (width - luma_patch_size), 1)
+
+        patch_start = word * (
+            frame_pos * full_frame_size + y_frame_size + uv_pos.y * uv_width + uv_pos.x
+        )
+        stream.seek(patch_start, 0)
+
+        for i in range(uv_patch_size):
+            row = np.frombuffer(stream.read(uv_patch_size * word), pix_type)
+            u[i, :] = row.astype(np.uint16) * 4 if bit_depth == 8 else row
+            stream.seek(word * (uv_width - uv_patch_size), 1)
+
+        patch_start = word * (
+            frame_pos * full_frame_size
+            + y_frame_size
+            + uv_frame_size
+            + uv_pos.y * uv_width
+            + uv_pos.x
+        )
+        stream.seek(patch_start, 0)
+
+        for i in range(uv_patch_size):
+            row = np.frombuffer(stream.read(uv_patch_size * word), pix_type)
+            v[i, :] = row.astype(np.uint16) * 4 if bit_depth == 8 else row
+            stream.seek(word * (uv_width - uv_patch_size), 1)
+
+    return y, u, v
+
+
+def get_models(root_dir: Path) -> List[Path]:
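+    """
+    Collects the best overfitted models found in the subdirectories of root_dir
+    :param root_dir: Root directory to scan
+    :return: List of paths to the existing models/best_model.pt files
+    """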
+    sub_dirs = list_dirs(root_dir)
+    return [
+        v / "models/best_model.pt"
+        for v in sub_dirs
+        if (v / "models/best_model.pt").is_file()
+    ]
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/image_ops.py b/training/training_scripts/NN_Adaptive_Filtering/util/image_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..595167e209313dfd7d09359ce605d5bd0381df71
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/image_ops.py
@@ -0,0 +1,149 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from util import DEVICE
+
+
+def normalise_image(in_img: np.ndarray, normalize_factor: int) -> Tensor:
+    """
+    Normalises an image
+    :param in_img: Input image
+    :param normalize_factor: Factor for normalization
+    :return: Normalised image
+    """
+    out_img = torch.from_numpy(in_img.astype(np.float32)).to(DEVICE)
+    out_img = out_img.div(normalize_factor)
+    out_img = out_img.unsqueeze(0)
+    return out_img
+
+
+def read_yuv_frame(
+    file_path: Path,
+    width: int,
+    height: int,
+    bit_depth: int,
+    normalize_factor: int,
+    ipb: bool,
+    frame_idx: int,
+    frame_skip: int = 0,
+    temporal_subsample_ratio: int = 1,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Reads a YUV frame and normalises it
+    :param file_path: Absolute file path
+    :param width: Luma width
+    :param height: Luma height
+    :param bit_depth: Bit-depth
+    :param normalize_factor: Factor used to normalise the samples
+    :param ipb: Whether the file contains a prediction-mode (IPB) map that must be remapped
+    :param frame_idx: Frame index
+    :param frame_skip: Number of frames to skip
+    :param temporal_subsample_ratio: Temporal subsample ratio (see VTM config)
+    :return: Normalised Y, U, V tensors
+    """
+    word = 1 if bit_depth == 8 else 2
+    pix_type = np.uint8 if bit_depth == 8 else np.uint16
+
+    frame_pos = (frame_idx + frame_skip) * temporal_subsample_ratio
+    frame_size = width * height * word * 3 // 2
+
+    uv_width = width // 2
+    uv_height = height // 2
+
+    with open(file_path, "rb") as stream:
+        stream.seek(frame_pos * frame_size)
+
+        y = np.frombuffer(stream.read(width * height * word), pix_type)
+        u = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type)
+        v = np.frombuffer(stream.read(uv_width * uv_height * word), pix_type)
+
+    y = np.reshape(y, (height, width))
+    u = np.reshape(u, (height // 2, width // 2))
+    v = np.reshape(v, (height // 2, width // 2))
+
+    if bit_depth == 8:
+        pix_type = np.uint16
+        y = y.astype(pix_type) * 4
+        u = u.astype(pix_type) * 4
+        v = v.astype(pix_type) * 4
+
+    if ipb:
+        # first right shift: Ipred=0, IBC=1, Uni-pred=2, bi-pred=3
+        y = np.right_shift(y, 1)
+        u = np.right_shift(u, 1)
+        v = np.right_shift(v, 1)
+        # Ipred/IBC=0, Unipred=1, Bi-pred=2
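+        # the expression below maps values 0, 1, 2, 3 to 0, 0, 1, 2 respectively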
+        y = (y == 0) + y - 1
+        u = (u == 0) + u - 1
+        v = (v == 0) + v - 1
+
+    y = normalise_image(y, normalize_factor)
+    u = normalise_image(u, normalize_factor)
+    v = normalise_image(v, normalize_factor)
+
+    return y, u, v
+
+
+def pad_image(in_image: Tensor, block_size: int, border_size: int) -> Tensor:
+    """
+    Applies padding to the input image
+    :param in_image: Input image
+    :param block_size: Size of the actual block (final output size)
+    :param border_size: Number of samples added to each side of the block size
+    :return: Padded image
+    """
+    _, height, width = in_image.shape
+
+    mod = width % block_size
+    if mod > 0:
+        right_border = border_size + block_size - mod
+    else:
+        right_border = border_size
+
+    mod = height % block_size
+    if mod > 0:
+        bottom_border = border_size + block_size - mod
+    else:
+        bottom_border = border_size
+
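+    # pad left/top by border_size; pad right/bottom by border_size plus the amount needed
+    # to round the width/height up to a multiple of block_size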
+    transform = torch.nn.ConstantPad2d(
+        (border_size, right_border, border_size, bottom_border), 0.0
+    )
+    out_image = transform(in_image)
+    return out_image
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/logger.py b/training/training_scripts/NN_Adaptive_Filtering/util/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..aab5c5e9ddc5980ca0a916b90874f6b4f7f96cee
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/logger.py
@@ -0,0 +1,202 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import sys
+import time
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import List, Tuple
+
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from util import COLOUR_LABEL, TIME_FORMAT, Colour
+
+
+def compute_elapsed_time(start_time: str) -> Tuple[str, timedelta]:
+    end_time = time.strftime(TIME_FORMAT)
+    elapsed_time = datetime.strptime(end_time, TIME_FORMAT) - datetime.strptime(
+        start_time, TIME_FORMAT
+    )
+    return end_time, elapsed_time
+
+
+class Logger:
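+    """
+    Logs training metrics to TensorBoard, keeps the best model and periodic checkpoints,
+    and signals early stopping once the loss has not improved for more than stop_patience epochs
+    """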
+    def __init__(self, log_dir: str, stop_patience: int):
+        log_dir = Path(log_dir)
+        log_dir.mkdir(exist_ok=True, parents=True)
+
+        self._models_dir = log_dir / "models"
+        self._models_dir.mkdir(exist_ok=True, parents=True)
+        self._checkpoint_dir = self._models_dir / "checkpoints"
+        self._checkpoint_dir.mkdir(exist_ok=True, parents=True)
+
+        self._time_file = log_dir / "time.csv"
+        self._write_idx = 0
+        self._train_writer = SummaryWriter(str(log_dir / "train"))
+
+        self._start_time = time.strftime(TIME_FORMAT)
+        self._last_time = self._start_time
+
+        self._best_loss = sys.float_info.max
+        self._num_no_improvements = 0
+        self._stop_patience = stop_patience
+
+    def log_metrics(
+        self,
+        epoch: int,
+        batch: int,
+        loss: np.ndarray,
+        psnr: np.ndarray,
+        delta_psnr_wrt_vtm: np.ndarray,
+        delta_psnr_wrt_base: np.ndarray,
+        positive_delta_psnr_wrt_base: np.ndarray,
+        portion_positive_delta_psnr: np.ndarray,
+    ) -> None:
+        _, elapsed_time = compute_elapsed_time(self._last_time)
+
+        metrics_str = ""
+        for c in range(Colour.YCbCr + 1):
+            label = COLOUR_LABEL[c]
+
+            if loss is not None:
+                self._train_writer.add_scalar(f"Loss_{label}", loss[c], epoch)
+                metrics_str += "Loss_{}: {:.6f}; ".format(label, loss[c])
+            if psnr is not None:
+                self._train_writer.add_scalar(f"PSNR_{label}", psnr[c], epoch)
+                metrics_str += "PSNR_{}: {:.6f}; ".format(label, psnr[c])
+            if delta_psnr_wrt_vtm is not None:
+                self._train_writer.add_scalar(
+                    f"dPSNR_{label}_wrt_vtm", delta_psnr_wrt_vtm[c], epoch
+                )
+                metrics_str += "dPSNR_{}_wrt_vtm: {:.6f}; ".format(
+                    label, delta_psnr_wrt_vtm[c]
+                )
+            if delta_psnr_wrt_base is not None:
+                self._train_writer.add_scalar(
+                    f"dPSNR_{label}_wrt_base", delta_psnr_wrt_base[c], epoch
+                )
+                metrics_str += "dPSNR_{}_wrt_base: {:.6f}; ".format(
+                    label, delta_psnr_wrt_base[c]
+                )
+            if positive_delta_psnr_wrt_base is not None:
+                self._train_writer.add_scalar(
+                    f"dPSNR_positive_{label}", positive_delta_psnr_wrt_base[c], epoch
+                )
+                metrics_str += "dPSNR_positive_{}: {:.6f}; ".format(
+                    label, positive_delta_psnr_wrt_base[c]
+                )
+            if portion_positive_delta_psnr is not None:
+                self._train_writer.add_scalar(
+                    f"percentage_positive_dpsnr_{label}",
+                    portion_positive_delta_psnr[c],
+                    epoch,
+                )
+                metrics_str += "percentage_positive_dpsnr_{}: {:.6f}; ".format(
+                    label, portion_positive_delta_psnr[c]
+                )
+
+        # print(
+        #     "Epoch: {:03d}; Batch: {:06d}; {}Time: {}".format(
+        #         epoch, batch, metrics_str, elapsed_time
+        #     )
+        # )
+        self._last_time = time.strftime(TIME_FORMAT)
+
+    def save_model(
+        self,
+        epoch: int,
+        loss: float,
+        model: torch.nn.Module,
+        optimiser: torch.optim.Optimizer,
+        lr_scheduler,
+        save_interval: int = 1,
+    ) -> bool:
+        # x == x to check nan
+        if loss == loss and loss < self._best_loss:
+            self._best_loss = loss
+            self._num_no_improvements = 0
+            print(f"Saving the best model at epoch {epoch}...")
+            torch.save(
+                model.module.state_dict(), str(self._models_dir / "best_model.pt")
+            )
+        else:
+            self._num_no_improvements += 1
+
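+        # early stopping: ask the caller to stop once the loss has not improved for more than stop_patience epochs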
+        stop_train = self._num_no_improvements > self._stop_patience
+
+        if (epoch + 1) % save_interval != 0:
+            return stop_train
+
+        print(f"Saving checkpoint at epoch {epoch}...")
+        torch.save(
+            model.module.state_dict(),
+            str(self._checkpoint_dir / "model_{:03d}.pt".format(epoch)),
+        )
+        torch.save(
+            optimiser.state_dict(),
+            str(self._checkpoint_dir / "model_{:03d}_optimiser.pt".format(epoch)),
+        )
+        torch.save(
+            lr_scheduler.state_dict(),
+            str(self._checkpoint_dir / "model_{:03d}_lr_scheduler.pt".format(epoch)),
+        )
+
+        return stop_train
+
+    def reset(self) -> None:
+        self._best_loss = sys.float_info.max
+        self._num_no_improvements = 0
+
+    def save_training_time(self) -> None:
+        end_time, elapsed_time = compute_elapsed_time(self._start_time)
+        with open(self._time_file, "w") as stream:
+            stream.writelines("start,end,duration\n")
+            stream.writelines(f"{self._start_time},{end_time},{elapsed_time}\n")
+
+    def save_param_names(self, param_names: List[str]) -> None:
+        with open(self._checkpoint_dir / "param_names.txt", "w") as file:
+            names_to_write = "\n".join(param_names)
+            file.write(names_to_write)
+
+    def load_param_names(self) -> List[str]:
+        param_names = []
+        try:
+            with open(self._checkpoint_dir / "param_names.txt", "r") as file:
+                param_names = file.read().splitlines()
+        except Exception:
+            print("No files with names of trainable parameters")
+
+        return param_names
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/metrics.py b/training/training_scripts/NN_Adaptive_Filtering/util/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cecabb3cf679cd844d114717063de716dadbb2e
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/metrics.py
@@ -0,0 +1,164 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import math
+from typing import Dict, Tuple
+
+import torch
+from torch import Tensor
+
+from util import COLOUR_WEIGHT, Colour
+
+
+def compute_sample_loss_psnr(
+    ground_truth: Tensor, prediction: Tensor, loss_func
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Computes the loss and PSNR per sample
+    :param ground_truth: Ground-truth
+    :param prediction: Prediction
+    :param loss_func: Loss function returning a label ("MSE"/"MAE") and the per-sample loss
+    :return: Loss, PSNR and a mask that indicates which PSNR values are valid (neither inf nor nan)
+    """
+
+    label, loss = loss_func(ground_truth, prediction)
+
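+    # PSNR is computed against a peak value of 1, i.e. the inputs are assumed to be normalised to [0, 1]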
+    if label == "MSE":
+        psnr = 10.0 * (torch.log(1.0 / loss) / math.log(10.0))
+    else:
+        mse = torch.mean((ground_truth - prediction) ** 2, dim=(1, 2, 3))
+        psnr = 10.0 * (torch.log(1.0 / mse) / math.log(10.0))
+
+    psnr, mask = zero_out_inf_nan(psnr)
+    return loss, psnr, mask
+
+
+def zero_out_negative(in_tensor: Tensor) -> Tuple[Tensor, Tensor]:
+    """
+    Zeroes out values in the input tensor that are negative
+    :param in_tensor: Input tensor
+    :return: Tensor with zero on those positions where there was previously a negative value. Also a mask that
+    indicates what are the valid values in the tensor
+    """
+    out_tensor = torch.where(in_tensor < 0.0, 0.0, in_tensor)
+    mask = torch.where(in_tensor > 0.0, 1.0, 0.0)
+    return out_tensor, mask
+
+
+def zero_out_inf_nan(in_tensor: Tensor) -> Tuple[Tensor, Tensor]:
+    """
+    Zeroes out values in the input tensor that are either inf or nan
+    :param in_tensor: Input tensor
+    :return: Tensor with zero on those positions where there was previously an inf or nan value. Also a mask that
+    indicates what are the valid values in the tensor
+    """
+    condition = torch.logical_or(torch.isinf(in_tensor), torch.isnan(in_tensor))
+    out_tensor = torch.where(condition, 0.0, in_tensor)
+    mask = torch.where(condition, 0.0, 1.0)
+    return out_tensor, mask
+
+
+def compute_loss_psnr(
+    ground_truth: Dict,
+    prediction_Y: Tensor,
+    prediction_U: Tensor,
+    prediction_V: Tensor,
+    loss_func,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Computes both the loss and PSNR colour- and sample-wise
+    :param ground_truth: Dictionary with the original frames (orig_Y/U/V) and masks (mask_Y/U/V)
+    :param prediction_Y: Luma prediction
+    :param prediction_U: Cb prediction
+    :param prediction_V: Cr prediction
+    :param loss_func: Loss function
+    :return: Loss, PSNR and a mask. The mask indicates for which samples the PSNR is valid (neither inf nor nan)
+    """
+
+    y_orig = ground_truth["orig_Y"]
+    cb_orig = ground_truth["orig_U"]
+    cr_orig = ground_truth["orig_V"]
+
+    y_pred = torch.multiply(prediction_Y, ground_truth["mask_Y"])
+    cb_pred = torch.multiply(prediction_U, ground_truth["mask_U"])
+    cr_pred = torch.multiply(prediction_V, ground_truth["mask_V"])
+
+    loss = [0.0] * (Colour.YCbCr + 1)
+    psnr = [0.0] * (Colour.YCbCr + 1)
+    mask = [0.0] * (Colour.YCbCr + 1)
+
+    loss[Colour.Y], psnr[Colour.Y], mask[Colour.Y] = compute_sample_loss_psnr(
+        y_pred, y_orig, loss_func
+    )
+    loss[Colour.Cb], psnr[Colour.Cb], mask[Colour.Cb] = compute_sample_loss_psnr(
+        cb_pred, cb_orig, loss_func
+    )
+    loss[Colour.Cr], psnr[Colour.Cr], mask[Colour.Cr] = compute_sample_loss_psnr(
+        cr_pred, cr_orig, loss_func
+    )
+
+    y_weight = COLOUR_WEIGHT[Colour.Y]
+    cb_weight = COLOUR_WEIGHT[Colour.Cb]
+    cr_weight = COLOUR_WEIGHT[Colour.Cr]
+
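+    # the combined YCbCr loss/PSNR is a weighted sum of the per-component values (weights from COLOUR_WEIGHT)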
+    loss[Colour.YCbCr] = (
+        y_weight * loss[Colour.Y]
+        + cb_weight * loss[Colour.Cb]
+        + cr_weight * loss[Colour.Cr]
+    )
+    psnr[Colour.YCbCr] = (
+        y_weight * psnr[Colour.Y]
+        + cb_weight * psnr[Colour.Cb]
+        + cr_weight * psnr[Colour.Cr]
+    )
+
+    loss = torch.stack(loss)
+    psnr = torch.stack(psnr)
+
+    mask[Colour.YCbCr] = mask[Colour.Y] * mask[Colour.Cb] * mask[Colour.Cr]
+    mask = torch.stack(mask)
+
+    return loss, psnr, mask
+
+
+def mse_loss(orig, pred):
+    return "MSE", torch.mean((orig - pred) ** 2, dim=(1, 2, 3))
+
+
+def mae_loss(orig, pred):
+    return "MAE", torch.mean(torch.abs(orig - pred), dim=(1, 2, 3))
+
+
+LOSS_OP = {
+    "mse": mse_loss,
+    "mae": mae_loss,
+}
diff --git a/training/training_scripts/NN_Adaptive_Filtering/util/regex.py b/training/training_scripts/NN_Adaptive_Filtering/util/regex.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fdeeddab0c69ccab7426bf92fa246ae537f687e
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/util/regex.py
@@ -0,0 +1,127 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import re
+from typing import Tuple
+
+# regular expression to identify the segments
+ra_segment_rgx = re.compile(r"p(\d+)")
+
+
+def get_regex_groups_for_pattern(
+    text: str, pattern: str, expected_number_of_groups: int
+) -> re.Match:
+    """
+    Searches for a regular expression in the given text and checks the number of captured groups
+    :param text: Input text
+    :param pattern: Regular expression
+    :param expected_number_of_groups: number of groups to extract
+    :return: match
+    """
+    match = re.search(pattern, text)
+    if match is None or len(match.groups()) != expected_number_of_groups:
+        raise ValueError(f"Couldn't find pattern {pattern} in '{text}'")
+    return match
+
+
+def look_for_string_pattern_in_text(text: str, pattern: str) -> str:
+    """
+    Looks for a pattern in a text
+    :param text: Input text
+    :param pattern: Pattern
+    :return: String in the text that matches the pattern
+    """
+    matches = re.findall(pattern, text)
+    if len(matches) > 0:
+        return matches[0]
+    return "0"
+
+
+def get_frame_info_from_decoder_log_line(text: str) -> Tuple[int, int, str, int]:
+    """
+    Given an input text, it extracts video frame information
+    :param text: Input text
+    :return: POC, temporal layer, slice type and QP
+    """
+    pattern = r"POC\s+(\d+)\s+LId:\s+\d+\s+TId:\s+(\d+)\s+\(\s+.*,\s+([I|P|B])-\w+,\s+QP\s+(\d+)"
+    match = get_regex_groups_for_pattern(text, pattern, 4)
+    return int(match.group(1)), int(match.group(2)), match.group(3), int(match.group(4))
+
+
+def get_data_from_encoder_log(
+    text: str,
+) -> Tuple[str, int, int, float, float, float, float, float, float]:
+    """
+    It extracts the following data from a line of the encoder log:
+    slice type, qp, bits, psnr yuv (x3), mssim yuv (x3)
+    :param text: file line
+    :return: Slice type, slice QP, bits, PSNR Y, PSNR U, PSNR V, MSSIM Y, MSSIM U, MSSIM V
+    """
+    pattern = r"([I|B])-SLICE\W\s+QP\s+(\d+)\s+\)\s+(\d+)\s+bits\s+\[Y\s+(\d+.\d+)\s+dB\s+U\s+(\d+.\d+)\s+dB\s+V\s+(\d+.\d+)\s+dB\]\s+\[\w+\s+\w+\s+\w+\s+\w+\s+\w+\s+\w+\]\s+\[MS-SSIM\s+Y\s+(\d\.\d+)\s+U\s+(\d+\.\d+)\s+V\s+(\d+\.\d+)"
+    result = get_regex_groups_for_pattern(text, pattern, 9)
+    return (
+        result.group(1),
+        int(result.group(2)),
+        int(result.group(3)),
+        float(result.group(4)),
+        float(result.group(5)),
+        float(result.group(6)),
+        float(result.group(7)),
+        float(result.group(8)),
+        float(result.group(9)),
+    )
+
+
+def get_regex_groups_from_org_zip_file(file_stem: str) -> re.Match:
+    file_pattern = r"((\w+_\w+_*\w*)_(\d+)_part(\d+))"
+    return get_regex_groups_for_pattern(file_stem, file_pattern, 4)
+
+
+def get_bitrate_and_mse_from_encoder_log(
+    text: str,
+) -> Tuple[float, float, float, float, float]:
+    """
+    Extracts the bitrate and MSEs from the summary line in the encoder log
+    :param text: file line
+    :return: Bit-rate followed by four MSE values
+    """
+    pattern = r"(\d+\.\d+)\s+\d+\.\d+\s+\d+\.\d+\s+\d+\.\d+\s+\d+\.\d+\s+\w+\s+\w+\s+\w+\s+\d+\.\d+\s+\d+\.\d+\s+\d+\.\d+\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)"
+    result = get_regex_groups_for_pattern(text, pattern, 5)
+    return (
+        float(result.group(1)),
+        float(result.group(2)),
+        float(result.group(3)),
+        float(result.group(4)),
+        float(result.group(5)),
+    )
diff --git a/training/training_scripts/NN_Adaptive_Filtering/wu_decoding.py b/training/training_scripts/NN_Adaptive_Filtering/wu_decoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..12afb7130bf9dff3acaa7675d23c6d8cc4658d13
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/wu_decoding.py
@@ -0,0 +1,131 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+
+import nnc_core
+import torch
+from framework.mpeg_applications.utils import icnn_tools
+
+from conversion.sadl2torch import create_model, map_params
+from conversion.torch2sadl import torch_to_onnx
+from util.file_system import read_json_file, write_json_file
+from wu_encoding import get_model_data_info
+
+
+def parse_arguments() -> argparse.Namespace:
+    parser = argparse.ArgumentParser("Decoding weight-update")
+    parser.add_argument("--arch", type=str, required=True, help="Model architecture")
+    parser.add_argument(
+        "--base_model", type=str, required=True, help="Path to base model"
+    )
+    parser.add_argument(
+        "--model_arch", type=str, default="lop2", help="Name of the model architecture"
+    )
+    parser.add_argument(
+        "--nnr_bitstream", type=str, required=True, help="Path to NNR bitstream"
+    )
+    parser.add_argument(
+        "--reco_model", type=str, required=True, help="Path to reconstructed model"
+    )
+    return parser.parse_args()
+
+
+def decode_weight_update_and_2sadl():
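+    """
+    Decodes an NNR weight-update bitstream, applies it to the base model parameters and converts
+    the reconstructed model back to an int16 SADL model (torch -> ONNX -> SADL float -> SADL int16)
+    """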
+    if "lop2" != args.arch:
+        print(f"{args.arch} not supported")
+
+    base_model_config = read_json_file("models/models.json")["lop2"]
+    nnr_bitstream_file = Path(args.nnr_bitstream)
+
+    with NamedTemporaryFile() as dq_torch_file, NamedTemporaryFile() as reco_torch_file, NamedTemporaryFile() as onnx_file, NamedTemporaryFile() as sadl_f_file, NamedTemporaryFile() as quantiser_file, NamedTemporaryFile() as base_quantizers_file, NamedTemporaryFile() as map_orig_params_to_new_file:
+        dq_model, base_model_quantizers = create_model(
+            args.base_model, base_model_config
+        )
+        torch.save(dq_model.state_dict(), dq_torch_file.name)
+        write_json_file(base_model_quantizers, base_quantizers_file.name)
+
+        base_model, base_model_params = get_model_data_info(
+            dq_torch_file.name, base_model_config, False
+        )
+
+        map_orig_params_to_new = map_params(base_model_config)
+        write_json_file(map_orig_params_to_new, map_orig_params_to_new_file.name)
+
+        # Decode
+        diff_dec_approx_params = {"parameters": {}, "put_node_depth": {}}
+        diff_rec_approx_data, bs_size, dec_model_info, res = nnc_core.dec_and_rec(
+            nnr_bitstream_file,
+            True,
+            tml=None,
+            epoch=1,
+            client=0,
+            parameter_index=base_model.model_info["parameter_index"],
+            parameter_dimensions=base_model.model_info["parameter_dimensions"],
+            dec_approx_param_base=diff_dec_approx_params,
+            update_base_param=True,
+        )
+
+        # Restore torch model
+        restored_params = icnn_tools.add(
+            base_model_params, diff_rec_approx_data["parameters"]
+        )
+        base_model.restore_and_save(restored_params, reco_torch_file.name)
+
+        # Torch to Onnx
+        torch_to_onnx(reco_torch_file.name, onnx_file.name, base_model_config)
+
+        # onnx to SADL float
+        command = (
+            f"python conversion/onnx2sadl_modified.py --input_onnx {onnx_file.name} --output {sadl_f_file.name} "
+            f"--out_quant {quantiser_file.name} --base_quantizers {base_quantizers_file.name} "
+            f"--orig_params_to_new {map_orig_params_to_new_file.name}"
+        )
+        if os.system(command):
+            sys.exit(-1)
+
+        # convert float SADL to int16
+        command_int = f"tail -n 1 {quantiser_file.name} | ./naive_quantization {sadl_f_file.name} {args.reco_model}"
+        if os.system(command_int):
+            sys.exit(-1)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    decode_weight_update_and_2sadl()
+    sys.exit(0)
diff --git a/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py b/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b107ef1695205ad2f2ee4788f9547d978f885806
--- /dev/null
+++ b/training/training_scripts/NN_Adaptive_Filtering/wu_encoding.py
@@ -0,0 +1,253 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+*  * Redistributions of source code must retain the above copyright notice,
+*    this list of conditions and the following disclaimer.
+*  * Redistributions in binary form must reproduce the above copyright notice,
+*    this list of conditions and the following disclaimer in the documentation
+*    and/or other materials provided with the distribution.
+*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+*    be used to endorse or promote products derived from this software without
+*    specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from typing import Any, Dict, Tuple, Union
+
+import nnc_core
+import numpy as np
+from framework.mpeg_applications.utils import icnn_tools
+
+import config
+from conversion.torch2sadl import onnx_to_sadl, torch_to_onnx
+from models.model_lop2_nnr import PytorchModel
+from util.regex import get_regex_groups_from_org_zip_file
+
+
+def create_enc_info() -> Dict:
+    param_opt = True
+    temporal_context = False
+
+    info = {
+        "cabac_unary_length_minus1": 10,
+        "param_opt_flag": param_opt,
+        "partial_data_counter": 0,
+    }
+
+    if config.PUT_SYNTAX():
+        info["node_id_present_flag"] = 1
+        info["device_id"] = 0
+        info["parent_node_id_present_flag"] = 1
+        info["parent_node_id_type"] = nnc_core.hls.ParentNodeIdType.ICNN_NDU_ID
+        info["parent_device_id"] = 0
+
+    if config.TEMPORAL_CONTEXT():
+        info["temporal_context_modeling_flag"] = 1 if temporal_context else 0
+
+    return info
+
+
+def get_model_data_info(
+    model_path: Union[Path, str],
+    model_config: Dict,
+    is_overf_model: bool,
+    block_size: int = 0,
+    pad_size: int = 0,
+) -> Tuple[PytorchModel, Dict]:
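+    """
+    Instantiates the NNR-wrapped Torch model and loads its parameters
+    :param model_path: Path to the Torch model file
+    :param model_config: Model configuration dictionary
+    :param is_overf_model: Whether the file holds an overfitted model
+    :param block_size: Block size passed to the model wrapper
+    :param pad_size: Padding passed to the model wrapper
+    :return: The PytorchModel wrapper and its parameter dictionary
+    """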
+    model = PytorchModel(
+        model_config,
+        block_size,
+        pad_size,
+    )
+    path = model_path if isinstance(model_path, str) else str(model_path)
+    model_params = model.load_model(path, is_overf_model)
+    return model, model_params
+
+
+def compress_one_model_and_2sadl(
+    overfitted_model_dir: Path,
+    base_model_file: Path,
+    model_config: Dict,
+    base_quantizers: Path,
+    map_orig_params_to_new,
+    nnr_models_dir: Path,
+    onnx_models_dir: Path,
+    sadl_float_dir: Path,
+    sadl_int_dir: Path,
+    quantizers_dir: Path,
+    train_loader: Any,
+    block_size: int,
+    pad_size: int,
+):
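+    """
+    Encodes the weight update of one overfitted model with NNC and converts the result to SADL:
+    the update (overfitted minus base parameters) is quantised and encoded, decoded back,
+    added to the base parameters, and the restored model is exported via ONNX to SADL
+    (float and int, the latter quantised with the base model's quantisers)
+    """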
+    enc_info = create_enc_info()
+
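+    # NNC/NCTM approximation and encoding settings used for the weight update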
+    qp_density = 2
+
+    approx_method = "uniform"
+    nnr_qp = -40
+    opt_qp = True
+    disable_dq = True
+    lambda_scale = 0.0
+    cb_size_ratio = 5000
+    q_mse = 0.00001
+    inc_bn_folding = False
+
+    qp_step = 1
+    bias = 0.005
+
+    overfitted_model_dir = Path(overfitted_model_dir)
+    nnr_models_dir = Path(nnr_models_dir)
+    onnx_models_dir = Path(onnx_models_dir)
+    sadl_float_dir = Path(sadl_float_dir)
+    sadl_int_dir = Path(sadl_int_dir)
+    quantizers_dir = Path(quantizers_dir)
+
+    overfitted_model_file = overfitted_model_dir / "models" / "best_model.pt"
+
+    matched_groups = get_regex_groups_from_org_zip_file(overfitted_model_dir.name)
+    output_stem = matched_groups.group(1)
+
+    print("-----------------------------")
+    print(output_stem)
+    print("-----------------------------")
+
+    _, overfitted_model_params = get_model_data_info(
+        overfitted_model_file, model_config, True, block_size, pad_size
+    )
+    base_model, base_model_params = get_model_data_info(
+        base_model_file, model_config, False, block_size, pad_size
+    )
+
+    approx_data = base_model.init_approx_data(
+        base_model_params, qp_density, scan_order=0
+    )
+
+    approx_info = nnc_core.nnr_model.ApproxInfo(
+        approx_data,
+        base_model.model_info,
+        approx_method,
+        nnr_qp,
+        opt_qp,
+        disable_dq,
+        lambda_scale,
+        cb_size_ratio,
+        q_mse,
+    )
+
+    nnr_qp = np.int32(nnr_qp)
+    diff_qp = nnr_qp
+
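+    # the weight update is the difference between the overfitted and the base parameters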
+    model_param_diff = icnn_tools.model_diff(overfitted_model_params, base_model_params)
+
+    diff_dec_approx_params = {"parameters": {}, "put_node_depth": {}}
+
+    # Iterative QP
+    if config.OPT_QP() and opt_qp:
+        base_model.set_dataset(train_loader)
+        ref_perf, _, _ = base_model.eval_model(base_model_params)
+        with NamedTemporaryFile(
+            dir="/tmp", suffix=".nnr", prefix="bitstream_", delete=False
+        ) as tmp_nnr_bs:
+            diff_qp = icnn_tools.opt_qp(
+                diff_qp,
+                model_param_diff,
+                base_model_params,
+                diff_dec_approx_params,
+                base_model,
+                base_model.model_info,
+                ref_perf,
+                approx_data,
+                approx_info,
+                enc_info,
+                inc_bn_folding,
+                tmp_nnr_bs.name,
+                1,
+                save_bitstreams=False,
+                tml=None,
+                sbt_args=None,
+                bias=bias,
+                qp_step=qp_step,
+                crit="acc",
+            )
+
+    approx_data["parameters"] = model_param_diff
+
+    approx_info.apply_qp(approx_data, base_model.model_info, diff_qp)
+
+    bitstream_path = nnr_models_dir / "nnr" / f"{output_stem}.nnr"
+
+    # Approximate and encode
+    len_of_bs, enc_diff_rec_approx_data = nnc_core.approx_and_enc(
+        base_model.model_info,
+        approx_data,
+        diff_dec_approx_params,
+        approx_info,
+        enc_info,
+        num_workers=4,
+        bs_filename=bitstream_path,
+        tml=None,
+        n_epochs=1,
+        epoch=1,
+        client=0,
+        sbt_args=None,
+    )
+    assert (
+        len(enc_diff_rec_approx_data["parameters"]) > 0
+    ), "The weight update was fully sparsified"
+
+    # Decode
+    diff_rec_approx_data, bs_size, dec_model_info, res = nnc_core.dec_and_rec(
+        bitstream_path,
+        True,
+        tml=None,
+        epoch=1,
+        client=0,
+        parameter_index=base_model.model_info["parameter_index"],
+        parameter_dimensions=base_model.model_info["parameter_dimensions"],
+        dec_approx_param_base=diff_dec_approx_params,
+        update_base_param=True,
+    )
+
+    # Restoration
+    restored_params = icnn_tools.add(
+        base_model_params, diff_rec_approx_data["parameters"]
+    )
+
+    saved_model_path = nnr_models_dir / f"nnr_{output_stem}.pt"
+    base_model.restore_and_save(restored_params, str(saved_model_path))
+
+    # Torch to Onnx
+    out_onnx_path = onnx_models_dir / saved_model_path.stem
+    torch_to_onnx(saved_model_path, out_onnx_path, model_config)
+
+    # Onnx to SADL and quantized by using base model's quantizers
+    onnx_to_sadl(
+        sadl_float_dir,
+        out_onnx_path,
+        quantizers_dir,
+        sadl_int_dir,
+        base_quantizers,
+        map_orig_params_to_new,
+    )