diff --git a/.gitattributes b/.gitattributes
index 54fda3c6473c15566049239068d7d725d58b5d56..3ffacb90d438646e5b09d2feedd4d5ea43fb4e2e 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -499,3 +499,7 @@ models/super_resolution/Nnsr_LumaCNNSR_Inter_float.sadl filter=lfs diff=lfs merg
 models/super_resolution/Nnsr_LumaCNNSR_Inter_int16.sadl filter=lfs diff=lfs merge=lfs -text
 models/super_resolution/Nnsr_LumaCNNSR_Intra_float.sadl filter=lfs diff=lfs merge=lfs -text
 models/super_resolution/Nnsr_LumaCNNSR_Intra_int16.sadl filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
+*.index filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.data-* filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 381fcaabb6133939fa123cf73ea24797c282ed34..7d74fb105d04e376477a4f7ec8adfe6d41713156 100644
--- a/README.md
+++ b/README.md
@@ -428,18 +428,16 @@ Content-adaptive post-filter
 ------------------------------------------------------------------------
 
 To activate the content-adaptive post-filter use one of the two following config file:
-
-* int16: [cfg/nn-based/nnpf_int16.cfg](cfg/nn-based/nnpf_int16.cfg)
+[cfg/nn-based/nnpf_int16.cfg](cfg/nn-based/nnpf_int16.cfg)
 
 `--SEINNPostFilterCharacteristicsPayloadFilename4` should be replaced with the absolute path of the
 corresponding NNR bitstream:
-
-* float: [models/post_filter/float/nnr_bitstreams_float](models/post_filter/int16/nnr_bitstreams_int16)
+[models/post_filter/int16/nnr_bitstreams_int16](models/post_filter/int16/nnr_bitstreams_int16)
 
 For both encoding and decoding, the models are specified with `--NnpfModelPath`. For each test, 3 models are
 pre-trained models and one is the over-fitted model.
 
-The following example applies for BlowingBubbles with QP 44:
+The following example applies to BlowingBubbles with QP 42:
 
 ```shell
 MODEL0=models/post_filter/int16/base_models_int16/model0.sadl
diff --git a/models/post_filter/overfitted_models.json b/models/post_filter/overfitted_models.json
index 29ce68c03d0a7ffd8ae2c14bf185c85dbd4db2d8..5e0c5c8b238d42ebc3144319f9ff53216ec21c13 100644
--- a/models/post_filter/overfitted_models.json
+++ b/models/post_filter/overfitted_models.json
@@ -1,163 +1,3 @@
-{
-  "A1_CampfireParty": {
-    "22": 2,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "A1_FoodMarket": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "A1_Tango": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "A2_CatRobot": {
-    "22": 2,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "A2_DaylightRoad": {
-    "22": 0,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "A2_ParkRunning": {
-    "22": 2,
-    "27": 2,
-    "32": 2,
-    "37": 2,
-    "42": 3
-  },
-  "B_BQTerrace": {
-    "22": 2,
-    "27": 2,
-    "32": 2,
-    "37": 3,
-    "42": 3
-  },
-  "B_BasketBallDrive": {
-    "22": 2,
-    "27": 0,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "B_Cactus": {
-    "22": 0,
-    "27": 0,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "B_MarketPlace": {
-    "22": 0,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "B_RitualDance": {
-    "22": 2,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "C_BQMall": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "C_BasketballDrill": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "C_PartyScene": {
-    "22": 2,
-    "27": 2,
-    "32": 2,
-    "37": 3,
-    "42": 3
-  },
-  "C_RaceHorses_big": {
-    "22": 0,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "D_BQSquare": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "D_BasketBallPass": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "D_BlowingBubbles": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "D_RaceHorses_s": {
-    "22": 2,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "F_ArenaOfValor": {
-    "22": 3,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "F_BBDrillText": {
-    "22": 2,
-    "27": 2,
-    "32": 2,
-    "37": 3,
-    "42": 3
-  },
-  "F_SlideEditing": {
-    "22": 2,
-    "27": 2,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  },
-  "F_SlideShow": {
-    "22": 3,
-    "27": 3,
-    "32": 3,
-    "37": 3,
-    "42": 3
-  }
-}
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3d1f834a56194fba72e241ede2cb1e073c424f9e5d3c020f6f8c31ae1df7a96
+size 2074
diff --git a/training/training_scripts/NN_Post_Filtering/.gitignore b/training/training_scripts/NN_Post_Filtering/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2ee33b4f0b8c366794d0c148b338d5ec05a17a5b
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/.gitignore
@@ -0,0 +1,6 @@
+NCTM/
+org_data/
+vtm_data/
+post_filter_dataset
+finetuning/
+overfitting/
diff --git a/training/training_scripts/NN_Post_Filtering/Readme.md b/training/training_scripts/NN_Post_Filtering/Readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..ccca1f55d53c6df3fc7c12f51f48be324481adb7
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/Readme.md
@@ -0,0 +1,46 @@
+# NN post-filter
+
+The NN post-filter is trained in two stages:
+
+1. Four models are fine-tuned offline on the BVI-DVC and DIV2K datasets.
+2. For each test sequence (JVET) and sequence QP, one of the models in (1) is over-fitted.
+
+The inputs to the NN post-filter are the reconstructed luma and chroma samples along with the picture QP.
+The video compression is done with 
+[VTM 11.0 NNVC 1.0](https://vcgit.hhi.fraunhofer.de/jvet-ahg-nnvc/VVCSoftware_VTM/-/tree/VTM-11.0_nnvc-1.0).
+
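+The exact input assembly is implemented by the dataset pipeline added in `nctm_post_filter.patch` (`dataset.py`):
+the luma plane is interleaved into four half-resolution sub-planes and concatenated with the two chroma planes and a
+constant QP-step plane computed as 2^((QP - 42) / 6). The NumPy sketch below only illustrates that channel assembly;
+the helper name `build_nn_input` is illustrative and not part of the scripts.
+
+```python
+import numpy as np
+
+
+def build_nn_input(rec_y, rec_u, rec_v, frame_qp):
+    """Illustrative only: assemble the post-filter input channels.
+
+    rec_y: (H, W) reconstructed luma; rec_u/rec_v: (H/2, W/2) reconstructed
+    chroma, all normalised to [0, 1]; frame_qp: picture QP.
+    """
+    # Interleave luma into four half-resolution phase planes (TL, TR, BL, BR)
+    y_tl, y_tr = rec_y[0::2, 0::2], rec_y[0::2, 1::2]
+    y_bl, y_br = rec_y[1::2, 0::2], rec_y[1::2, 1::2]
+    # Constant plane carrying the QP step, as in the dataset pipeline
+    qp_step = np.full_like(rec_u, 2.0 ** ((frame_qp - 42) / 6.0))
+    # Channel order matches the tf.concat in dataset.py
+    return np.stack([y_tl, y_tr, y_bl, y_br, rec_u, rec_v, qp_step], axis=-1)
+```
+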
+## Requirements
+
+* [Conda environment](env.yml)
+* [MPEG NCTM repository](http://mpegx.int-evry.fr/software/MPEG/NNCoding/NCTM)
+
+To get access to the NCTM repository, please follow these steps:
+
+1. Create an account at [http://mpegx.int-evry.fr/software](http://mpegx.int-evry.fr/software)
+2. Email Werner Bailer (werner.bailer@joanneum.at) requesting access to the NCTM project
+3. Wait for your account to be activated
+
+## Setup
+
+```shell
+# Create the BASE env var that points to the absolute path of NN_Post_Filtering, e.g.
+export BASE=/opt/VVCSoftware_VTM/training/training_scripts/NN_Post_Filtering
+
+# Clone NCTM repository and apply the patch provided
+cd $BASE
+git clone http://mpeg.expert/software/MPEG/NNCoding/NCTM.git
+cd NCTM
+git checkout v1-v2-harmonization
+git apply ../nctm_post_filter.patch
+
+# Activate Conda environment and install DeepCABAC
+conda activate nn-post-filter
+python setup.py build
+python setup.py install
+```
+
+## Process
+
+1. [Data preparation document](data_preparation.md)
+2. [Fine-tuning](finetuning.md)
+3. [Over-fitting, NNR/NNC encoding/decoding, and quantisation](overfitting_and_quantisation.md)
diff --git a/training/training_scripts/NN_Post_Filtering/data_preparation.md b/training/training_scripts/NN_Post_Filtering/data_preparation.md
new file mode 100644
index 0000000000000000000000000000000000000000..0c77092da81749cff3be6dd410cf34449126b81b
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/data_preparation.md
@@ -0,0 +1,173 @@
+# Data preparation
+
+The NN post-filter consists of two stages: fine-tuning and over-fitting. Three datasets are used: (1) BVI-DVC, (2) DIV2K
+and (3) the JVET RA CTC mandatory sequences.
+
+In the fine-tuning stage the training data comes from datasets (1) and (2), whereas the validation data comes
+from dataset (3).
+
+In the over-fitting stage the training data is dataset (3).
+
+## Get training data
+
+1. Create three directories to store the original training data: one for the BVI-DVC dataset, one for the DIV2K dataset
+   and one for the JVET dataset:
+
+```shell
+cd $BASE
+mkdir -p org_data/bvi_dvc_mp4
+mkdir -p org_data/div2k_png
+mkdir -p org_data/jvet_yuv
+```
+
+2. Download the BVI-DVC dataset from [here](https://data.bris.ac.uk/data/dataset/3hj4t64fkbrgn2ghwp9en4vhtn).
+   Make sure the mp4 videos are in `$BASE/org_data/bvi_dvc_mp4`.
+
+3. Download all train data tracks under **(NTIRE 2017) Low Res Images** (six in total) and the train data under
+   **High Resolution Images**, from [here](https://data.vision.ee.ethz.ch/cvl/DIV2K/).
+   Add the zip files to the `$BASE/org_data/div2k_png` directory and uncompress them.
+
+```shell
+cd $BASE/org_data/div2k_png
+
+unzip DIV2K_train_HR.zip
+unzip DIV2K_train_LR_bicubic_X2.zip
+unzip DIV2K_train_LR_bicubic_X3.zip
+unzip DIV2K_train_LR_bicubic_X4.zip
+unzip DIV2K_train_LR_unknown_X2.zip
+unzip DIV2K_train_LR_unknown_X3.zip
+unzip DIV2K_train_LR_unknown_X4.zip
+```
+
+4. Copy the JVET RA CTC mandatory sequences into `$BASE/org_data/jvet_yuv`.
+
+## Prepare data for video encoding
+
+1. Convert the BVI-DVC mp4 files to YUV. Use the script [prepare_to_code.py](scripts/prepare_to_code.py)
+   to convert all videos in a directory and generate the configuration files for the video encoding.
+
+```shell
+mkdir $BASE/org_data/bvi_dvc_yuv
+cd $BASE/scripts
+
+python prepare_to_code.py --input_dir $BASE/org_data/bvi_dvc_mp4 --output_dir $BASE/org_data/bvi_dvc_yuv --dataset BVI-DVC
+```
+
+2. Convert the PNG files to YUV. Use the script [prepare_to_code.py](scripts/prepare_to_code.py) to convert
+   all images in a directory and generate the configuration files for the video encoding.
+
+```shell
+mkdir $BASE/org_data/div2k_yuv
+cd $BASE/scripts
+
+python prepare_to_code.py --input_dir $BASE/org_data/div2k_png --output_dir $BASE/org_data/div2k_yuv --dataset DIV2K
+```
+
+## Encode/decode original training data
+
+Use [VTM 11.0 NNVC 1.0](https://vcgit.hhi.fraunhofer.de/jvet-ahg-nnvc/VVCSoftware_VTM/-/tree/VTM-11.0_nnvc-1.0) and the
+RA CTC for NNVC: all in-loop filters enabled and sequence QPs {22, 27, 32, 37, 42}.
+
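+How the encode/decode jobs are launched is up to the user. A minimal sketch of one way to drive VTM over all QPs for
+a single sequence is shown below; the binary names, the random-access configuration file and the use of the
+per-sequence `.cfg` written by `prepare_to_code.py` are assumptions to adapt to your build and to the NNVC CTC
+configuration.
+
+```python
+import subprocess
+from pathlib import Path
+
+# Hypothetical paths; adjust to your VTM build and dataset layout.
+ENCODER = "EncoderAppStatic"
+DECODER = "DecoderAppStatic"
+RA_CFG = "cfg/encoder_randomaccess_vtm.cfg"         # plus any NNVC-specific cfg
+SEQ_CFG = "org_data/bvi_dvc_yuv/sequence_name.cfg"  # written by prepare_to_code.py
+OUT_ROOT = Path("vtm_data/bvi_dvc/sequence_name")
+
+for qp in (22, 27, 32, 37, 42):
+    out_dir = OUT_ROOT / str(qp)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    bitstream = out_dir / "str.bin"
+    with open(out_dir / "log_enc.txt", "w") as log:
+        subprocess.run([ENCODER, "-c", RA_CFG, "-c", SEQ_CFG, "-q", str(qp),
+                        "-b", str(bitstream)], stdout=log, check=True)
+    with open(out_dir / "log_dec.txt", "w") as log:
+        subprocess.run([DECODER, "-b", str(bitstream),
+                        "-o", str(out_dir / "reco.yuv")], stdout=log, check=True)
+```
+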
+1. Create three directories to store the coded datasets:
+
+```shell
+mkdir -p $BASE/vtm_data/bvi_dvc
+mkdir -p $BASE/vtm_data/div2k
+mkdir -p $BASE/vtm_data/jvet
+```
+
+2. Run the video encoder and decoder for each video sequence, or re-organise the files afterwards, so that each
+   dataset directory contains one directory per video sequence, with one subdirectory per QP inside it.
+   The directory tree of a dataset should look like this (a small verification sketch follows the tree):
+
+```shell
+sequence_name
+├── 22
+│   ├── log_enc.txt
+│   ├── log_dec.txt
+│   └── reco.yuv
+├── 27
+│   └── ...
+├── 32
+│   └── ...
+├── 37
+│   └── ...
+└── 42
+    ├── log_enc.txt
+    ├── log_dec.txt
+    └── reco.yuv
+```
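+
+A small, purely illustrative check of this layout (the helper name and QP list are assumptions) can catch misplaced
+files before generating the datasets:
+
+```python
+from pathlib import Path
+
+
+def check_dataset_layout(dataset_dir, qps=(22, 27, 32, 37, 42)):
+    """Illustrative helper: report sequences missing a QP subdirectory or reco.yuv."""
+    for seq_dir in sorted(p for p in Path(dataset_dir).iterdir() if p.is_dir()):
+        for qp in qps:
+            reco = seq_dir / str(qp) / "reco.yuv"
+            if not reco.is_file():
+                print(f"missing: {reco}")
+
+
+check_dataset_layout("vtm_data/bvi_dvc")
+```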
+
+## Generate datasets
+
+Each YUV frame is converted to three 16-bit PNG images (one per channel). These images are then used to train the NN
+post-filter.
+
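+A rough sketch of the per-frame conversion idea (assuming yuv420p10le input, `imageio` from the environment, and a
+hypothetical `<poc>_<plane>.png` naming) is:
+
+```python
+import numpy as np
+import imageio
+
+
+def yuv420p10le_frame_to_pngs(yuv_path, width, height, poc, out_dir):
+    """Illustrative only: dump one 10-bit 4:2:0 frame as three 16-bit PNGs."""
+    luma = width * height
+    chroma = luma // 4
+    frame_bytes = 2 * (luma + 2 * chroma)  # two bytes per 10-bit sample
+    with open(yuv_path, "rb") as f:
+        f.seek(poc * frame_bytes)
+        data = np.frombuffer(f.read(frame_bytes), dtype=np.uint16)
+    y = data[:luma].reshape(height, width)
+    u = data[luma:luma + chroma].reshape(height // 2, width // 2)
+    v = data[luma + chroma:].reshape(height // 2, width // 2)
+    for name, plane in (("y", y), ("u", u), ("v", v)):
+        # Scale the 10-bit samples to the 16-bit PNG range (matches the
+        # right-shift by 16 - bit_depth applied when the PNGs are read back)
+        imageio.imwrite(f"{out_dir}/{poc:04d}_{name}.png", plane.astype(np.uint16) << 6)
+```
+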
+Use the script [prepare_to_train.py](scripts/prepare_to_train.py) and run it once for each dataset:
+BVI-DVC, DIV2K and JVET.
+
+**NOTE**. JVET includes 8-bit sequences, which need to be converted to yuv420p10le with ffmpeg before being converted
+to 16-bit PNGs (keep the output file name identical to the input file name), e.g.:
+
+```shell
+cd $BASE/org_data/jvet_yuv
+
+ffmpeg -s <width>x<height> -pix_fmt yuv420p -f rawvideo -r <fps> -i <input.yuv> -c:v rawvideo -pix_fmt yuv420p10le <output.yuv>
+```
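+
+To convert every 8-bit sequence in one go, a loop along the following lines can be used; the sequence names and
+dimensions below are placeholders, and only the 8-bit sequences need this step:
+
+```python
+import os
+import subprocess
+
+# Hypothetical subset of 8-bit JVET sequences: name -> (width, height, fps)
+SEQS_8BIT = {
+    "D_BlowingBubbles.yuv": (416, 240, 50),
+    "C_BQMall.yuv": (832, 480, 60),
+}
+
+for name, (w, h, fps) in SEQS_8BIT.items():
+    tmp = name.replace(".yuv", "_10b.yuv")
+    subprocess.run(["ffmpeg", "-s", f"{w}x{h}", "-pix_fmt", "yuv420p", "-f", "rawvideo",
+                    "-r", str(fps), "-i", name, "-c:v", "rawvideo",
+                    "-pix_fmt", "yuv420p10le", tmp], check=True)
+    os.replace(tmp, name)  # keep the original file name, as required above
+```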
+
+```shell
+mkdir $BASE/post_filter_dataset
+
+cd $BASE/scripts
+
+python prepare_to_train.py --orig_dir $BASE/org_data/bvi_dvc_yuv --deco_dir $BASE/vtm_data/bvi_dvc --output_dir $BASE/post_filter_dataset --dataset BVI-DVC
+python prepare_to_train.py --orig_dir $BASE/org_data/div2k_yuv --deco_dir $BASE/vtm_data/div2k --output_dir $BASE/post_filter_dataset --dataset DIV2K
+python prepare_to_train.py --orig_dir $BASE/org_data/jvet_yuv --deco_dir $BASE/vtm_data/jvet --output_dir $BASE/post_filter_dataset --dataset JVET
+```
+
+Afterwards, the output directory tree should look like:
+
+```shell
+post_filter_dataset
+├── train_data
+│   ├── deco
+│   │   ├── 0001
+│   │   │   ├── 22
+│   │   │   │   ├── images
+│   │   │   │   └── frames_info.json
+│   │   │   └── ...
+│   │   └── ...
+│   └── orig
+│       ├── 0001
+│       │   └── images
+│       └── ...
+└── valid_data
+    ├── deco
+    │   ├── A1_CampfireParty
+    │   │   ├── 22
+    │   │   │   ├── images
+    │   │   │   └── frames_info.json
+    │   │   └── ...
+    │   └── ...
+    └── orig
+        ├── A1_CampfireParty
+        │   └── images
+        └── ...
+```
+
+`train_data` will contain the BVI-DVC and DIV2K data, and `valid_data` will contain the JVET data.
+
+Each `images` directory contains three PNGs for each frame in the decoded YUV.
+
+Additionally, `frames_info.json` maps each POC to its temporal layer, frame (single-slice) QP and frame type.
+
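+The dataset pipeline in `nctm_post_filter.patch` looks entries up by POC (as a string key) and reads the `QP` and
+`frame_type` fields; the temporal-layer key name is not shown there, so the layout below is indicative only:
+
+```python
+import json
+
+with open("post_filter_dataset/valid_data/deco/A1_CampfireParty/22/frames_info.json") as f:
+    frames_info = json.load(f)
+
+# Indicative entry layout; "QP" and "frame_type" are the keys read by dataset.py,
+# the temporal-layer key name is an assumption:
+# frames_info["0"] -> {"frame_type": "I", "QP": 22, "temporal_layer": 0}
+print(frames_info["0"]["frame_type"], frames_info["0"]["QP"])
+```
+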
+Now, the intermediate original and coded data can be deleted:
+
+```shell
+rm -rf $BASE/org_data
+rm -rf $BASE/vtm_data
+```
+
+[Go back](Readme.md)
diff --git a/training/training_scripts/NN_Post_Filtering/env.yml b/training/training_scripts/NN_Post_Filtering/env.yml
new file mode 100644
index 0000000000000000000000000000000000000000..8d147e2260a960f143bfb29de99e33634c54bef1
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/env.yml
@@ -0,0 +1,32 @@
+name: nn-post-filter
+channels:
+  - anaconda
+  - conda-forge
+  - defaults
+dependencies:
+  - cudatoolkit-dev
+  - cudnn=8
+  - cupti=10.1
+  - pip=21.0.1=py37h06a4308_0
+  - python=3.7.9=h7579374_0
+  - pip:
+    - Click==7.0
+    - dcase-util==0.2.10
+    - imageio
+    - h5py==2.10.0
+    - matplotlib
+    - memory_profiler==0.55.0
+    - numba==0.48.0
+    - onnx==1.11.0
+    - opencv-python==4.4.0.46
+    - pandas==1.0.5
+    - Pillow==6.2.2
+    - pybind11==2.5.0
+    - scikit-image==0.17.2
+    - scikit-learn==0.23.1
+    - tensorflow-gpu==2.4.1
+    - tf2onnx==1.11.0
+    - torch==1.8.0
+    - torchvision==0.9.0
+    - tqdm==4.32.2
+prefix: ~/.conda/envs/nn-post-filter
diff --git a/training/training_scripts/NN_Post_Filtering/finetuning.md b/training/training_scripts/NN_Post_Filtering/finetuning.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a613693bd16c750bf4f36cc76c9f47e6a84c54d
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/finetuning.md
@@ -0,0 +1,116 @@
+# Fine-tuning
+
+The NN post-filter uses models pretrained as in-loop filters (JVET-W0131). These models are referred to as
+**original base models**.
+
+The original base models can be downloaded from 
+[JVET-W-EE1-1.4](https://vcgit.hhi.fraunhofer.de/jvet-w-ee1/VVCSoftware_VTM/-/tree/EE1-1.4/models_1_4_1).
+They are also included within this repository in
+[resources/orig_base_models](scripts/resources/orig_base_models).
+
+The fine-tuning consists of two steps:
+
+1. Four models are fine-tuned to serve as post-processing filters using the script 
+[training.py](scripts/training.py).
+2. The multiplier parameters are added to the fine-tuned models via the script
+[create_filter_with_multiplier.py](scripts/create_filter_with_multiplier.py).
+
+The resultant models are called **base models**.
+
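+The multiplier parameters themselves are defined by
+[create_filter_with_multiplier.py](scripts/create_filter_with_multiplier.py); as a purely conceptual sketch (an
+assumption about the idea, not that script's code), a trainable per-channel scaling of the fine-tuned model's output
+could be wrapped around it like this:
+
+```python
+import tensorflow as tf
+
+
+class ChannelMultiplier(tf.keras.layers.Layer):
+    """Conceptual sketch: one trainable multiplier per output channel, initialised to 1."""
+
+    def build(self, input_shape):
+        self.scale = self.add_weight(name="multiplier", shape=(input_shape[-1],),
+                                     initializer="ones", trainable=True)
+
+    def call(self, inputs):
+        return inputs * self.scale
+
+
+def add_output_multiplier(base_model):
+    # Wrap the fine-tuned model so its output is scaled by trainable multipliers
+    inputs = tf.keras.Input(shape=base_model.input_shape[1:])
+    outputs = ChannelMultiplier()(base_model(inputs))
+    return tf.keras.Model(inputs, outputs)
+```
+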
+The following subsections describe how the four base models are generated.
+
+## Base models 0 and 2
+The post-filter base models 0 and 2 are generated by jointly training the original base models 2 and 3 on
+inter-coded frames:
+
+```shell
+mkdir -p $BASE/finetuning/models_0_2
+
+cd $BASE/scripts
+
+python training.py --stage fine-tuning --joint_training --epochs 18 \
+--train_dir $BASE/post_filter_dataset/train_data --train_prop_file resources/properties/train_data_properties.json \
+--valid_dir $BASE/post_filter_dataset/valid_data --valid_prop_file resources/properties/valid_data_properties.json \
+--use_frame_type --frame_type B --use_random_patches \
+--base_model_dir resources/orig_base_models/model2 --base_model_dir resources/orig_base_models/model3 \
+--output_dir $BASE/finetuning/models_0_2
+```
+
+Afterwards, add the multiplier parameters. Note the resulting models are named `model2` and `model0`.
+
+```shell
+mkdir -p $BASE/finetuning/base_models
+
+python create_filter_with_multiplier.py --base_model_dir $BASE/finetuning/models_0_2/finetuned_model2/OutputModel \
+--output_dir $BASE/finetuning/base_models/model2
+
+python create_filter_with_multiplier.py --base_model_dir $BASE/finetuning/models_0_2/finetuned_model3/OutputModel \
+--output_dir $BASE/finetuning/base_models/model0
+```
+
+## Base model 1
+
+The original base model 1 is fine-tuned on intra-coded frames with high QP values:
+
+```shell
+mkdir -p $BASE/finetuning/model1
+
+cd $BASE/scripts
+
+python training.py --stage fine-tuning --epochs 50 --lr 1e-4 \
+--train_dir $BASE/post_filter_dataset/train_data --train_prop_file resources/properties/train_data_properties.json \
+--valid_dir $BASE/post_filter_dataset/valid_data --valid_prop_file resources/properties/valid_data_properties.json \
+--use_frame_type --frame_type I --use_frame_qp --min_frame_qp 32 --max_frame_qp 42 --use_random_patches \
+--base_model_dir resources/orig_base_models/model1 \
+--output_dir $BASE/finetuning/model1
+```
+
+Then, add the multiplier parameters. Note the resulting model is named `model1`.
+
+```shell
+mkdir -p $BASE/finetuning/base_models
+
+python create_filter_with_multiplier.py --base_model_dir $BASE/finetuning/model1/finetuned_model1/OutputModel \
+--output_dir $BASE/finetuning/base_models/model1
+```
+
+## Base model 3
+
+The original base model 3 is fine-tuned on inter-coded frames with high QP values, in two steps.
+
+First, fine-tune the original base model 3:
+
+```shell
+mkdir -p $BASE/finetuning/model3_stage1
+
+cd $BASE/scripts
+
+python training.py --stage fine-tuning --epochs 14 --lr 1e-3 \
+--train_dir $BASE/post_filter_dataset/train_data --train_prop_file resources/properties/train_data_properties.json \
+--valid_dir $BASE/post_filter_dataset/valid_data --valid_prop_file resources/properties/valid_data_properties.json \
+--use_frame_type --frame_type B --use_frame_qp --min_frame_qp 38 --max_frame_qp 51 --use_random_patches \
+--base_model_dir resources/orig_base_models/model3 --output_dir $BASE/finetuning/model3_stage1
+```
+
+Second, take the resulting model and fine-tune it again:
+
+```shell
+mkdir -p $BASE/finetuning/model3_stage2
+
+python training.py --stage fine-tuning --epochs 15 --lr 1e-4 \
+--train_dir $BASE/post_filter_dataset/train_data --train_prop_file resources/properties/train_data_properties.json \
+--valid_dir $BASE/post_filter_dataset/valid_data --valid_prop_file resources/properties/valid_data_properties.json \
+--use_frame_type --frame_type B --use_frame_qp --min_frame_qp 38 --max_frame_qp 51 --use_random_patches \
+--base_model_dir $BASE/finetuning/model3_stage1/finetuned_model3/OutputModel \
+--output_dir $BASE/finetuning/model3_stage2
+```
+
+Finally, add the multiplier parameters. Note the resulting model is named `model3`:
+
+```shell
+python create_filter_with_multiplier.py \
+--base_model_dir $BASE/finetuning/model3_stage2/finetuned_OutputModel/OutputModel \
+--output_dir $BASE/finetuning/base_models/model3
+```
+
+[Go back](Readme.md)
diff --git a/training/training_scripts/NN_Post_Filtering/nctm_post_filter.patch b/training/training_scripts/NN_Post_Filtering/nctm_post_filter.patch
new file mode 100644
index 0000000000000000000000000000000000000000..be5f66cc880ee17728b34a9149976d3065e3103e
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/nctm_post_filter.patch
@@ -0,0 +1,5004 @@
+diff --git a/config.py b/config.py
+index 720fc1f..7b68b92 100755
+--- a/config.py
++++ b/config.py
+@@ -52,7 +52,7 @@ def TEMPORAL_CONTEXT(): return True
+ #def OPT_QP(): return False
+ def OPT_QP(): return True
+ 
+-#def SPARSE(): return False
++# def SPARSE(): return False
+ def SPARSE(): return True
+ 
+ # Use center PUT for temporal contexts in UC14A
+diff --git a/framework/mpeg_applications/tf_custom/__init__.py b/framework/mpeg_applications/tf_custom/__init__.py
+new file mode 100644
+index 0000000..79faef8
+--- /dev/null
++++ b/framework/mpeg_applications/tf_custom/__init__.py
+@@ -0,0 +1,57 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++
++class Colour:
++    Y = 0
++    Cb = 1
++    Cr = 2
++    YCbCr = 3
++    NUM_COLOURS = 4
++
++
++class Metric:
++    LOSS = 0
++    PSNR = 1
++    DELTA_PSNR_WRT_VTM = 2
++    DELTA_PSNR_WRT_BASE = 3
++    NUM_METRICS = 4
++
++
++COLOUR_LABEL = {Colour.Y: "Y", Colour.Cb: "Cb", Colour.Cr: "Cr", Colour.YCbCr: "YCbCr"}
++COLOUR_WEIGHTS = {Colour.Y: 4.0 / 6.0, Colour.Cb: 1.0 / 6.0, Colour.Cr: 1.0 / 6.0}
++METRIC_LABEL = {
++    Metric.LOSS: "Loss",
++    Metric.PSNR: "PSNR",
++    Metric.DELTA_PSNR_WRT_VTM: "dPSNR_wrt_VTM",
++    Metric.DELTA_PSNR_WRT_BASE: "dPSNR_wrt_base",
++}
+diff --git a/framework/mpeg_applications/tf_custom/dataset.py b/framework/mpeg_applications/tf_custom/dataset.py
+new file mode 100644
+index 0000000..0780f81
+--- /dev/null
++++ b/framework/mpeg_applications/tf_custom/dataset.py
+@@ -0,0 +1,600 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++import random
++from pathlib import Path
++from typing import List, Optional, Tuple, Union
++
++import numpy as np
++import tensorflow as tf
++
++from framework.mpeg_applications.tf_custom.file_system import (
++    check_directory,
++    list_dirs,
++    list_selected_dirs,
++    read_json_file,
++)
++from framework.mpeg_applications.tf_custom.image_ops import (
++    extract_patches,
++    extract_random_patches,
++    interleave_image,
++    pad_image,
++    read_image,
++)
++
++
++class Dataset:
++    def __init__(
++        self,
++        root_dir: str,
++        properties_file: str,
++        sequence_qps: Union[Tuple[str], List[str]],
++        bit_depth: int,
++        block_size: int,
++        pad_size: int,
++        use_frame_type: bool,
++        frame_type: str,
++        use_frame_qps: bool,
++        min_frame_qp: int,
++        max_frame_qp: int,
++        use_random_patches: bool,
++        num_patches: int,
++        is_validation: bool,
++        cache_dataset: bool,
++    ):
++        """
++        Constructor
++        :param root_dir: Root directory for the dataset
++        :param properties_file: JSON file with sequence properties for the dataset
++        :param sequence_qps: List of sequence QPs
++        :param bit_depth: Bit-depth for the data
++        :param block_size: Block size, without padding
++        :param pad_size: Padding size
++        :param use_frame_type: Enable/disable the frame selection based on its type
++        :param frame_type: Frame type (i.e. I or B)
++        :param use_frame_qps: Enable/disable the frame selection based on the frame QP
++        :param min_frame_qp: Minimum frame QP (inclusive)
++        :param max_frame_qp: Maximum frame QP (inclusive)
++        :param use_random_patches: Enable/disable the extraction of random patches
++        :param num_patches: Number of random patches to be extracted
++        :param is_validation: Is this the validation dataset (i.e. JVET)?
++        :param cache_dataset: Enable/disable dataset caching in memory
++        """
++        check_directory(root_dir)
++        self._deco_dir = Path(root_dir) / "deco"
++        self._orig_dir = Path(root_dir) / "orig"
++
++        self._sequence_qps = sequence_qps
++        self._bit_depth = bit_depth
++        self._block_size = block_size
++        self._pad_size = pad_size
++
++        self._seq_config = None
++        self._seq_width = None
++        self._seq_height = None
++        self._seq_num_blocks = None
++        self._load_seq_properties(properties_file)
++
++        self._use_frame_type = use_frame_type
++        self._frame_type = frame_type
++
++        self._use_frame_qps = use_frame_qps
++        self._min_frame_qp = min_frame_qp
++        self._max_frame_qp = max_frame_qp
++
++        self._use_random_patches = use_random_patches
++        if self._use_random_patches:
++            assert num_patches > 0, "At least one patch must be extracted per frame"
++        self._num_patches = num_patches
++
++        self._is_validation = is_validation
++        self._cache_dataset = cache_dataset
++
++    def _load_seq_properties(self, properties_file: str) -> None:
++        """
++        Loads the properties file that contains the sequence info, such as tags and dimensions. In addition, three
++        tables are created to map the sequence name to: (1) width, (2) height and (3) number of non-overlapping
++        blocks available
++        :param properties_file: Absolute path to the dataset properties file
++        """
++        self._seq_config = read_json_file(Path(properties_file))
++
++        tags = []
++        widths = []
++        heights = []
++        num_blocks = []
++
++        for seq_tag in self._seq_config.keys():
++            tags.append(seq_tag)
++            w = self._seq_config[seq_tag]["width"]
++            h = self._seq_config[seq_tag]["height"]
++
++            num_h_blocks = w // (self._block_size * 2)
++            if w % (self._block_size * 2) > 0:
++                num_h_blocks += 1
++
++            num_v_blocks = h // (self._block_size * 2)
++            if h % (self._block_size * 2) > 0:
++                num_v_blocks += 1
++
++            widths.append(w)
++            heights.append(h)
++            num_blocks.append(num_h_blocks * num_v_blocks)
++
++        tags = tf.constant(tags)
++        widths = tf.constant(widths)
++        heights = tf.constant(heights)
++        num_blocks = tf.constant(num_blocks)
++
++        self._seq_width = tf.lookup.StaticHashTable(
++            tf.lookup.KeyValueTensorInitializer(tags, widths, tf.string, tf.int32),
++            default_value=tf.constant(-1),
++        )
++        self._seq_height = tf.lookup.StaticHashTable(
++            tf.lookup.KeyValueTensorInitializer(tags, heights, tf.string, tf.int32),
++            default_value=tf.constant(-1),
++        )
++        self._seq_num_blocks = tf.lookup.StaticHashTable(
++            tf.lookup.KeyValueTensorInitializer(tags, num_blocks, tf.string, tf.int32),
++            default_value=tf.constant(-1),
++        )
++
++    def _get_file_list(
++        self, deco_seq_dirs: List[Path]
++    ) -> Tuple[
++        List[str],
++        List[str],
++        List[str],
++        List[str],
++        List[str],
++        List[str],
++        List[float],
++        List[np.array],
++        List[np.array],
++    ]:
++        """
++        Gets the filenames of the images to be processed (reconstruction and original), the frame QP and the random
++        top-left corner positions to extract the patches
++        :param deco_seq_dirs: List of decoded sequence directories
++        :return Lists of reconstruction and original images (one luma and two chroma each), list of frame QPs and list
++        of top-left corner positions
++        """
++        orig_y = []
++        orig_u = []
++        orig_v = []
++
++        reco_y = []
++        reco_u = []
++        reco_v = []
++
++        qp = []
++
++        pos_x = []
++        pos_y = []
++
++        luma_bs = self._block_size * 2
++
++        # The number of frames to extract per sequence is hard-coded for each class,
++        # based on the total number of frames in that class (roughly 500 extracted frames per class), e.g.
++        # A class: 1494 frames, 6 seqs, extract 84 frames from each seq
++        # B class: 2800 frames, 5 seqs, extract 100 frames from each seq
++        # C class: 1900 frames, 4 seqs, extract 125 frames from each seq
++        # D class: 1900 frames, 4 seqs, extract 125 frames from each seq
++        # F class: 1900 frames, 4 seqs, extract 125 frames from each seq
++        if self._is_validation:
++            frame_dict = {"A": 84, "B": 100, "C": 125, "D": 125, "F": 125}
++
++        for deco_seq_dir in deco_seq_dirs:
++            seq_name = deco_seq_dir.name
++
++            if self._is_validation:
++                num_frames = frame_dict[seq_name[0]]
++
++            w = self._seq_config[seq_name]["width"]
++            h = self._seq_config[seq_name]["height"]
++
++            if self._use_random_patches:
++                max_x = (
++                    luma_bs * (w // luma_bs)
++                    if w % luma_bs > 0
++                    else luma_bs * (w // luma_bs) - luma_bs
++                )
++                x_range = range(self._pad_size, max_x + 1, 2)
++
++                max_y = (
++                    luma_bs * (h // luma_bs)
++                    if h % luma_bs > 0
++                    else luma_bs * (h // luma_bs) - luma_bs
++                )
++                y_range = range(self._pad_size, max_y + 1, 2)
++
++            if len(self._sequence_qps) > 0:
++                qp_dirs = list_selected_dirs(deco_seq_dir, self._sequence_qps)
++            else:
++                qp_dirs = list_dirs(deco_seq_dir)
++
++            for qp_dir in qp_dirs:
++                frames_info = read_json_file(qp_dir / "frames_info.json")
++
++                reco_img_dir = qp_dir / "images"
++                orig_img_dir = self._orig_dir / seq_name / "images"
++
++                reco_files = reco_img_dir.glob("*_y.png")
++
++                if self._is_validation:
++                    reco_files_cp = reco_img_dir.glob("*_y.png")
++                    total_frames = len(list(reco_files_cp))
++                    random_pocs = np.random.choice(
++                        total_frames - 1, num_frames, replace=False
++                    )
++
++                for reco_file in reco_files:
++                    curr_poc = str(int(reco_file.stem.split("_")[0]))
++
++                    if self._is_validation and int(curr_poc) not in random_pocs:
++                        continue
++
++                    curr_frame_type = frames_info[curr_poc]["frame_type"]
++                    curr_frame_qp = frames_info[curr_poc]["QP"]
++
++                    if self._use_frame_type and self._frame_type != curr_frame_type:
++                        continue
++
++                    if self._use_frame_qps and (
++                        curr_frame_qp < self._min_frame_qp
++                        or curr_frame_qp > self._max_frame_qp
++                    ):
++                        continue
++
++                    orig_y.append(str(orig_img_dir / reco_file.name))
++                    orig_u.append(
++                        str(orig_img_dir / f'{reco_file.name.replace("y", "u")}')
++                    )
++                    orig_v.append(
++                        str(orig_img_dir / f'{reco_file.name.replace("y", "v")}')
++                    )
++
++                    reco_y.append(str(reco_file))
++                    reco_u.append(
++                        str(reco_img_dir / f'{reco_file.name.replace("y", "u")}')
++                    )
++                    reco_v.append(
++                        str(reco_img_dir / f'{reco_file.name.replace("y", "v")}')
++                    )
++
++                    qp.append(float(curr_frame_qp))
++
++                    if self._use_random_patches:
++                        pos_x.append(
++                            np.array(random.sample(x_range, self._num_patches))
++                        )
++                        pos_y.append(
++                            np.array(random.sample(y_range, self._num_patches))
++                        )
++                    else:
++                        pos_x.append(0)
++                        pos_y.append(0)
++
++        return orig_y, orig_u, orig_v, reco_y, reco_u, reco_v, qp, pos_x, pos_y
++
++    @tf.function
++    def read_images(
++        self, orig_y, orig_u, orig_v, reco_y, reco_u, reco_v, qp, pos_x, pos_y
++    ):
++        seq_tag = tf.strings.split(orig_y, "/")[-3]
++        width = self._seq_width.lookup(seq_tag)
++        height = self._seq_height.lookup(seq_tag)
++
++        orig_y = read_image(orig_y, self._bit_depth, width, height)
++        orig_u = read_image(orig_u, self._bit_depth, width // 2, height // 2)
++        orig_v = read_image(orig_v, self._bit_depth, width // 2, height // 2)
++
++        reco_y = read_image(reco_y, self._bit_depth, width, height)
++        reco_u = read_image(reco_u, self._bit_depth, width // 2, height // 2)
++        reco_v = read_image(reco_v, self._bit_depth, width // 2, height // 2)
++
++        qp_step = tf.math.pow(2.0, (qp - 42) / 6.0)
++        pos_x = tf.cast(pos_x, tf.int32)
++        pos_y = tf.cast(pos_y, tf.int32)
++
++        return (
++            seq_tag,
++            orig_y,
++            orig_u,
++            orig_v,
++            reco_y,
++            reco_u,
++            reco_v,
++            qp_step,
++            pos_x,
++            pos_y,
++        )
++
++    @tf.function
++    def pre_process_input(self, seq_tag, y, u, v, qp_step, pos_x, pos_y):
++        """
++        Creates input patches
++        :param seq_tag: Sequence tag/name
++        :param y: luma image
++        :param u: cb image
++        :param v: cr image
++        :param qp_step: QP step
++        :param pos_x: left corner positions
++        :param pos_y: top corner positions
++        :return: Input patches
++        """
++        with tf.device("/gpu:0"):
++            width = self._seq_width.lookup(seq_tag)
++            height = self._seq_height.lookup(seq_tag)
++
++            pos_xx = (pos_x - self._pad_size) // 2
++            pos_yy = (pos_y - self._pad_size) // 2
++
++            y = pad_image(y, width, height, self._block_size * 2, self._pad_size)
++            u = pad_image(
++                u, width // 2, height // 2, self._block_size, self._pad_size // 2
++            )
++            v = pad_image(
++                v, width // 2, height // 2, self._block_size, self._pad_size // 2
++            )
++
++            y_tl, y_tr, y_bl, y_br = interleave_image(y)
++
++            if self._use_random_patches:
++                y_tl = extract_random_patches(
++                    y_tl,
++                    self._block_size + self._pad_size,
++                    pos_xx,
++                    pos_yy,
++                    self._num_patches,
++                )
++                y_tr = extract_random_patches(
++                    y_tr,
++                    self._block_size + self._pad_size,
++                    pos_xx,
++                    pos_yy,
++                    self._num_patches,
++                )
++                y_bl = extract_random_patches(
++                    y_bl,
++                    self._block_size + self._pad_size,
++                    pos_xx,
++                    pos_yy,
++                    self._num_patches,
++                )
++                y_br = extract_random_patches(
++                    y_br,
++                    self._block_size + self._pad_size,
++                    pos_xx,
++                    pos_yy,
++                    self._num_patches,
++                )
++                u = extract_random_patches(
++                    u,
++                    self._block_size + self._pad_size,
++                    pos_xx,
++                    pos_yy,
++                    self._num_patches,
++                )
++                v = extract_random_patches(
++                    v,
++                    self._block_size + self._pad_size,
++                    pos_xx,
++                    pos_yy,
++                    self._num_patches,
++                )
++
++                qp_step = tf.fill(
++                    [
++                        self._num_patches,
++                        self._block_size + self._pad_size,
++                        self._block_size + self._pad_size,
++                        1,
++                    ],
++                    qp_step,
++                )
++            else:
++                y_tl = extract_patches(
++                    y_tl, self._block_size + self._pad_size, self._block_size
++                )
++                y_tr = extract_patches(
++                    y_tr, self._block_size + self._pad_size, self._block_size
++                )
++                y_bl = extract_patches(
++                    y_bl, self._block_size + self._pad_size, self._block_size
++                )
++                y_br = extract_patches(
++                    y_br, self._block_size + self._pad_size, self._block_size
++                )
++                u = extract_patches(
++                    u, self._block_size + self._pad_size, self._block_size
++                )
++                v = extract_patches(
++                    v, self._block_size + self._pad_size, self._block_size
++                )
++
++                qp_step = tf.fill(
++                    [
++                        self._seq_num_blocks[seq_tag],
++                        self._block_size + self._pad_size,
++                        self._block_size + self._pad_size,
++                        1,
++                    ],
++                    qp_step,
++                )
++
++            return tf.concat([y_tl, y_tr, y_bl, y_br, u, v, qp_step], axis=3)
++
++    @tf.function
++    def pre_process_label(self, seq_tag, y, u, v, pos_x, pos_y):
++        """
++        Creates label patches
++        :param seq_tag: Sequence tag/name
++        :param y: luma image
++        :param u: cb image
++        :param v: cr image
++        :param pos_x: left corner positions
++        :param pos_y: top corner positions
++        :return: Label patches
++        """
++        with tf.device("/gpu:0"):
++            pos_x = pos_x - self._pad_size
++            pos_y = pos_y - self._pad_size
++
++            width = self._seq_width.lookup(seq_tag)
++            height = self._seq_height.lookup(seq_tag)
++
++            mask = tf.ones_like(y)
++
++            block_size = self._block_size * 2
++
++            mod = tf.math.floormod(height, block_size)
++            out_height = tf.cond(
++                tf.greater(mod, 0), lambda: height + block_size - mod, lambda: height
++            )
++
++            mod = tf.math.floormod(width, block_size)
++            out_width = tf.cond(
++                tf.greater(mod, 0), lambda: width + block_size - mod, lambda: width
++            )
++
++            y = tf.image.pad_to_bounding_box(y, 0, 0, out_height, out_width)
++            mask = tf.image.pad_to_bounding_box(mask, 0, 0, out_height, out_width)
++            u = tf.image.pad_to_bounding_box(u, 0, 0, out_height // 2, out_width // 2)
++            v = tf.image.pad_to_bounding_box(v, 0, 0, out_height // 2, out_width // 2)
++
++            if self._use_random_patches:
++                y = extract_random_patches(
++                    y, block_size, pos_x, pos_y, self._num_patches
++                )
++                mask = extract_random_patches(
++                    mask, block_size, pos_x, pos_y, self._num_patches
++                )
++                u = extract_random_patches(
++                    u, self._block_size, pos_x // 2, pos_y // 2, self._num_patches
++                )
++                v = extract_random_patches(
++                    v, self._block_size, pos_x // 2, pos_y // 2, self._num_patches
++                )
++            else:
++                y = extract_patches(y, block_size, block_size)
++                mask = extract_patches(mask, block_size, block_size)
++                u = extract_patches(u, self._block_size, self._block_size)
++                v = extract_patches(v, self._block_size, self._block_size)
++
++            y_tl, y_tr, y_bl, y_br = interleave_image(y)
++            mask_tl, mask_tr, mask_bl, mask_br = interleave_image(mask)
++
++            return tf.concat(
++                [
++                    y_tl,
++                    y_tr,
++                    y_bl,
++                    y_br,
++                    u,
++                    v,
++                    mask_tl,
++                    mask_tr,
++                    mask_bl,
++                    mask_br,
++                    mask_tl,
++                    mask_tl,
++                ],
++                axis=3,
++            )
++
++    def _apply_pipeline(
++        self, deco_seq_dirs: List[Path], batch_size: int, seed: int
++    ) -> tf.data.Dataset:
++        """
++        Applies the data pipeline
++        :param deco_seq_dirs: List of decoded sequence directories
++        :param batch_size: Batch size
++        :param seed: Seed for "random" operations
++        :return: dataset
++        """
++        file_list = self._get_file_list(deco_seq_dirs)
++
++        dataset = tf.data.Dataset.from_tensor_slices(file_list)
++        dataset = dataset.shuffle(
++            buffer_size=len(dataset), seed=seed, reshuffle_each_iteration=False
++        )
++
++        dataset = dataset.interleave(
++            lambda *args: tf.data.Dataset.from_tensors(self.read_images(*args)),
++            num_parallel_calls=tf.data.experimental.AUTOTUNE,
++        )
++
++        dataset = dataset.interleave(
++            lambda seq_tag, orig_y, orig_u, orig_v, reco_y, reco_u, reco_v, qp, pos_x, pos_y: tf.data.Dataset.zip(
++                (
++                    tf.data.Dataset.from_tensor_slices(
++                        self.pre_process_input(
++                            seq_tag, reco_y, reco_u, reco_v, qp, pos_x, pos_y
++                        )
++                    ),
++                    tf.data.Dataset.from_tensor_slices(
++                        self.pre_process_label(
++                            seq_tag, orig_y, orig_u, orig_v, pos_x, pos_y
++                        )
++                    ),
++                )
++            ),
++            num_parallel_calls=tf.data.experimental.AUTOTUNE,
++        )
++
++        dataset = dataset.batch(batch_size, drop_remainder=True)
++        if self._cache_dataset:
++            dataset = dataset.cache()
++
++        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
++
++        return dataset
++
++    def create(
++        self, seq_name: Optional[str], batch_size: int, seed: int = 1234
++    ) -> tf.data.Dataset:
++        """
++        Creates the dataset
++        :param seq_name: Sequence name
++        :param batch_size: Batch size
++        :param seed: Seed for "random" operations
++        :return: Dataset
++        """
++        if seq_name is None:
++            deco_seq_dirs = list_dirs(self._deco_dir)
++            random.shuffle(deco_seq_dirs)
++        else:
++            deco_seq_dirs = [self._deco_dir / seq_name]
++
++        dataset = self._apply_pipeline(deco_seq_dirs, batch_size, seed)
++        return dataset
+diff --git a/framework/mpeg_applications/tf_custom/file_system.py b/framework/mpeg_applications/tf_custom/file_system.py
+new file mode 100644
+index 0000000..b31019d
+--- /dev/null
++++ b/framework/mpeg_applications/tf_custom/file_system.py
+@@ -0,0 +1,127 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++import json
++import os
++from pathlib import Path
++from typing import Dict, List, Tuple, Union
++
++
++def check_directory(input_path: str) -> None:
++    """
++    Checks whether the given path exists and corresponds to a directory
++    :param input_path: Absolute path
++    :return:
++    """
++    assert os.path.exists(input_path) and os.path.isdir(
++        input_path
++    ), f"{input_path} is not a directory"
++
++
++def check_file(input_path: str) -> None:
++    """
++    Checks whether the given path exists and corresponds to a file
++    :param input_path: Absolute path
++    """
++    assert os.path.exists(input_path) and os.path.isfile(
++        input_path
++    ), f"{input_path} is not a file"
++
++
++def list_dirs(input_dir: Path) -> List[Path]:
++    """
++    Lists the subdirectories of the given directory
++    :param input_dir: Input directory
++    :return: List of directories
++    """
++    return [f for f in input_dir.iterdir() if f.is_dir()]
++
++
++def list_selected_dirs(
++    input_dir: Path, pattern: Union[Tuple[str], List[str]]
++) -> List[Path]:
++    """
++    Lists the subdirectories that contain a given text in their names
++    :param input_dir: Input directory
++    :param pattern: text to match
++    :return: List of directories
++    """
++    return [f for f in input_dir.iterdir() if f.is_dir() and f.name in pattern]
++
++
++def read_json_file(json_path: Path) -> Dict:
++    """
++    Reads JSON file
++    :param json_path: Absolute path to the JSON file
++    :return: Dictionary containing JSON file data
++    """
++    assert json_path.exists() and json_path.is_file(), f"{json_path} is not a file"
++    with open(json_path, "r") as stream:
++        config = json.load(stream)
++    return config
++
++
++def write_json_file(content: Dict, output_file: Path) -> None:
++    """
++    Writes a dictionary to a JSON file
++    :param content: Dictionary to be saved
++    :param output_file: Absolute path to the JSON file
++    """
++    assert (
++        output_file.parent.exists() and output_file.parent.is_dir()
++    ), f"The parent directory {output_file.parent} does not exist"
++    with open(output_file, "w") as stream:
++        json.dump(content, stream, sort_keys=True, indent=4)
++
++
++def create_vtm_config_file(
++    cfg_file: Path, filename: Path, width: int, height: int, fps: int, num_frames: int
++) -> None:
++    """
++    Creates the sequence config file for VTM encoding
++    :param cfg_file: Output file name
++    :param filename: YUV file name
++    :param width: Width of the YUV
++    :param height: Height of the YUV
++    :param fps: Frame rate of the YUV
++    :param num_frames: Number of frames to be encoded
++    """
++    with open(cfg_file, "w") as stream:
++        stream.write(f"InputFile:           {filename}\n")
++        stream.write(f"SourceWidth:         {width}\n")
++        stream.write(f"SourceHeight:        {height}\n")
++        stream.write(f"InputBitDepth:       10\n")
++        stream.write(f"InputChromaFormat:   420\n")
++        stream.write(f"FrameRate:           {fps}\n")
++        stream.write(f"FrameSkip:           0\n")
++        stream.write(f"FramesToBeEncoded:   {num_frames}\n")
++        stream.write(f"Level:               5.1\n")
+diff --git a/framework/mpeg_applications/tf_custom/image_ops.py b/framework/mpeg_applications/tf_custom/image_ops.py
+new file mode 100644
+index 0000000..ef7bce8
+--- /dev/null
++++ b/framework/mpeg_applications/tf_custom/image_ops.py
+@@ -0,0 +1,204 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++from typing import Tuple
++
++import tensorflow as tf
++
++
++@tf.function
++def read_image(filename: str, bit_depth: int, width: int, height: int) -> tf.Tensor:
++    """
++    Reads an image
++    :param filename: Absolute path to the image
++    :param bit_depth: Target bit-depth
++    :param width: Width of the image
++    :param height: Height of the image
++    :return: 4D tensor BHWC
++    """
++    image = tf.io.read_file(filename)
++    image = tf.image.decode_png(image, 1, tf.uint16)
++    image = tf.cast(tf.image.resize(image, [height, width]), tf.uint16)
++    image = tf.bitwise.right_shift(image, 16 - bit_depth)
++    image = tf.expand_dims(image, axis=0)
++    image = normalise_image(image, bit_depth)
++    return image
++
++
++@tf.function
++def normalise_image(image: tf.Tensor, bit_depth: int) -> tf.Tensor:
++    """
++    Normalises an image to the range [0, 1]
++    :param image: Input image
++    :param bit_depth: Bit-depth of the image
++    :return: Normalised image, 4D tensor BHWC
++    """
++    image = tf.cast(image, tf.float32) / (2**bit_depth - 1)
++    return image
++
++
++@tf.function
++def pad_image(
++    in_image: tf.Tensor, width: int, height: int, block_size: int, pad_size: int
++) -> tf.Tensor:
++    """
++    Applies padding to the input image
++    :param in_image: Input image
++    :param width: Width of the image
++    :param height: Height of the image
++    :param block_size: Size of the actual block (final output size)
++    :param pad_size: Number of samples added to each side of the block size
++    :return: Padded image
++    """
++    left = tf.expand_dims(in_image[:, :, 0, :], axis=2)
++    left = tf.tile(left, [1, 1, pad_size, 1])
++
++    right = tf.expand_dims(in_image[:, :, -1, :], axis=2)
++
++    mod = tf.math.floormod(width, block_size)
++    right = tf.cond(
++        tf.greater(mod, 0),
++        lambda: tf.tile(right, [1, 1, pad_size + block_size - mod, 1]),
++        lambda: tf.tile(right, [1, 1, pad_size, 1]),
++    )
++
++    out_image = tf.concat([left, in_image, right], axis=2)
++
++    top = tf.expand_dims(out_image[:, 0, :, :], axis=1)
++    top = tf.tile(top, [1, pad_size, 1, 1])
++
++    bottom = tf.expand_dims(out_image[:, -1, :, :], axis=1)
++
++    mod = tf.math.floormod(height, block_size)
++    bottom = tf.cond(
++        tf.greater(mod, 0),
++        lambda: tf.tile(bottom, [1, pad_size + block_size - mod, 1, 1]),
++        lambda: tf.tile(bottom, [1, pad_size, 1, 1]),
++    )
++
++    out_image = tf.concat([top, out_image, bottom], axis=1)
++    return out_image
++
++
++@tf.function
++def interleave_image(
++    image: tf.Tensor,
++) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
++    """
++    Interleaves an image into four partitions.
++    Example: http://casu.ast.cam.ac.uk/surveys-projects/wfcam/technical/interleaving
++    :param image: Input image
++    :return: Four image partitions
++    """
++    tl = image[:, 0::2, 0::2, :]
++    tr = image[:, 0::2, 1::2, :]
++    bl = image[:, 1::2, 0::2, :]
++    br = image[:, 1::2, 1::2, :]
++    return tl, tr, bl, br
++
++
++@tf.function
++def de_interleave_luma(block: tf.Tensor, block_size: int = 2) -> tf.Tensor:
++    """
++    De-interleaves four image partitions into a single image
++    :param block: Image partitions stacked in the channel dimension (BHWC with C = 4)
++    :param block_size: Spatial block size used by depth_to_space (2 for the 2x2 interleaving above)
++    :return: Full image
++    """
++    return tf.nn.depth_to_space(block, block_size)
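++
++# Illustrative round-trip sketch (shapes assumed for the example): the four
++# phases returned by interleave_image, stacked in the channel dimension in the
++# order (tl, tr, bl, br), are undone by de_interleave_luma:
++#
++#     tl, tr, bl, br = interleave_image(luma)       # luma: [1, H, W, 1]
++#     block = tf.concat([tl, tr, bl, br], axis=-1)  # [1, H/2, W/2, 4]
++#     restored = de_interleave_luma(block)          # [1, H, W, 1]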
++
++
++@tf.function
++def extract_patches(image: tf.Tensor, block_size: int, step: int) -> tf.Tensor:
++    """
++    Extracts patches of an image in Z-scan order
++    :param image: Input image
++    :param block_size: Block/patch size
++    :param step: Step size
++    :return: Patches concatenated in the batch dimension
++    """
++    patches = tf.image.extract_patches(
++        image, [1, block_size, block_size, 1], [1, step, step, 1], [1, 1, 1, 1], "VALID"
++    )
++    patches = tf.reshape(patches, [-1, block_size, block_size, image.shape[-1]])
++
++    return patches
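++
++# Illustrative patch count (example values only): for a 144x80 single-channel
++# input with block_size=72 and step=64, "VALID" extraction yields
++# floor((80 - 72) / 64) + 1 = 1 row and floor((144 - 72) / 64) + 1 = 2 columns,
++# i.e. a tensor of shape [2, 72, 72, 1].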
++
++
++@tf.function
++def extract_random_patches(
++    image: tf.Tensor, block_size: int, pos_x, pos_y, num_patches: int
++) -> tf.Tensor:
++    """
++    Extracts random patches out of the input image
++    :param image: Input image 4D tensor
++    :param block_size: Patch size
++    :param pos_x: Left (x) coordinates of the patches, one entry per patch
++    :param pos_y: Top (y) coordinates of the patches, one entry per patch
++    :param num_patches: Number of patches to be extracted
++    :return: Patches concatenated in the batch dimension
++    """
++    patches = []
++
++    for i in range(num_patches):
++        patch = tf.image.crop_to_bounding_box(
++            image, pos_y[i], pos_x[i], block_size, block_size
++        )
++        patch = tf.squeeze(patch, axis=0)
++        patches.append(patch)
++
++    patches = tf.stack(patches, axis=0)
++    return patches
++
++
++@tf.function
++def merge_images(first: tf.Tensor, second: tf.Tensor) -> tf.Tensor:
++    """
++    Merges two images in the channel dimension
++    :param first: First image
++    :param second: Second image
++    :return: Merged images
++    """
++    return tf.concat([first, second], axis=3)
++
++
++@tf.function
++def add_zeros_to_image(image: tf.Tensor, channels: int = 3) -> tf.Tensor:
++    """
++    Adds zero-filled channels to the input image
++    :param image: Input image
++    :param channels: Number of zero-filled channels to add
++    :return: Image with zero-padded channels
++    """
++    zero_bs = tf.zeros_like(image)[:, :, :, :channels]
++    image = merge_images(image, zero_bs)
++    return image
+diff --git a/framework/mpeg_applications/tf_custom/metrics.py b/framework/mpeg_applications/tf_custom/metrics.py
+new file mode 100644
+index 0000000..f4e875f
+--- /dev/null
++++ b/framework/mpeg_applications/tf_custom/metrics.py
+@@ -0,0 +1,139 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++from typing import Dict, Tuple
++
++import tensorflow as tf
++
++from framework.mpeg_applications.tf_custom import Colour, COLOUR_WEIGHTS, Metric
++from framework.mpeg_applications.tf_custom.image_ops import de_interleave_luma
++
++
++@tf.function
++def compute_metrics(
++    ground_truth: tf.Tensor, prediction: tf.Tensor
++) -> Tuple[tf.Tensor, tf.Tensor]:
++    """
++    Computes the MSE and the PSNR between two 4D tensors (single channel); both are computed per item along the batch dimension
++    :param ground_truth: Ground-truth tensor
++    :param prediction: Test tensor
++    :return: MSE and PSNR
++    """
++    mse = tf.reduce_mean(
++        tf.math.squared_difference(ground_truth, prediction), axis=[1, 2, 3]
++    )
++    psnr = 10.0 * (tf.math.log(1.0 / mse) / tf.math.log(10.0))
++    return mse, psnr
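++
++# Quick sanity check of the formula above (the peak value is 1.0 because the
++# inputs are normalised): an MSE of 1e-4 gives 10 * log10(1 / 1e-4) = 40 dB.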
++
++
++@tf.function
++def compute_loss(
++    ground_truth: tf.Tensor,
++    prediction: tf.Tensor,
++    loss_weights: Dict[int, float] = COLOUR_WEIGHTS,
++) -> Tuple[Dict[int, tf.Tensor], Dict[int, tf.Tensor]]:
++    """
++    Computes the loss function and the associated PSNR for all channels
++    :param ground_truth: Ground-truth tensor
++    :param prediction: Test tensor
++    :param loss_weights: Weights used to compute the average across all channels
++    :return: Channel-wise loss and channel-wise PSNR
++    """
++    mask = ground_truth[:, :, :, 6:]
++    ground_truth = ground_truth[:, :, :, :6]
++    ground_truth = tf.multiply(ground_truth, mask)
++    prediction = tf.multiply(prediction, mask)
++
++    y_orig = de_interleave_luma(ground_truth[:, :, :, :4])
++    y_pred = de_interleave_luma(prediction[:, :, :, :4])
++    y_mse, y_psnr = compute_metrics(y_orig, y_pred)
++
++    cb_mse, cb_psnr = compute_metrics(
++        ground_truth[:, :, :, 4:5], prediction[:, :, :, 4:5]
++    )
++    cr_mse, cr_psnr = compute_metrics(
++        ground_truth[:, :, :, 5:6], prediction[:, :, :, 5:6]
++    )
++
++    y_weight = loss_weights[Colour.Y]
++    cb_weight = loss_weights[Colour.Cb]
++    cr_weight = loss_weights[Colour.Cr]
++
++    mse = y_weight * y_mse + cb_weight * cb_mse + cr_weight * cr_mse
++    psnr = y_weight * y_psnr + cb_weight * cb_psnr + cr_weight * cr_psnr
++
++    mse = {Colour.Y: y_mse, Colour.Cb: cb_mse, Colour.Cr: cr_mse, Colour.YCbCr: mse}
++    psnr = {
++        Colour.Y: y_psnr,
++        Colour.Cb: cb_psnr,
++        Colour.Cr: cr_psnr,
++        Colour.YCbCr: psnr,
++    }
++
++    return mse, psnr
++
++
++@tf.function
++def compute_psnr_gain(
++    test_psnr: Dict[int, tf.Tensor], base_psnr: Dict[int, tf.Tensor]
++) -> Dict[int, tf.Tensor]:
++    """
++    Computes the PSNR gain (delta PSNR = filtered reconstruction PSNR - VTM reconstruction PSNR).
++    Note that if any input PSNR is infinite, the PSNR gain is zero
++    :param test_psnr: PSNR of the filtered reconstruction
++    :param base_psnr: PSNR of the VTM reconstruction
++    :return: PSNR gain
++    """
++    psnr_gain = {}
++    for colour in test_psnr.keys():
++        diff = test_psnr[colour] - base_psnr[colour]
++        is_inf = tf.reduce_any(
++            [tf.math.is_inf(test_psnr[colour]), tf.math.is_inf(base_psnr[colour])],
++            axis=0,
++        )
++        psnr_gain[colour] = tf.where(is_inf, 0.0, diff)
++    return psnr_gain
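++
++# Illustrative usage sketch (shapes are assumptions for the example): the
++# ground truth carries four interleaved luma channels, Cb, Cr and a mask
++# channel, while the reconstruction/prediction carry the six picture channels:
++#
++#     _, vtm_psnr = compute_loss(ground_truth, reconstruction)  # [B,H,W,7] vs [B,H,W,6]
++#     _, cnn_psnr = compute_loss(ground_truth, prediction)
++#     gain = compute_psnr_gain(cnn_psnr, vtm_psnr)              # dict keyed by Colour
++#     y_gain = gain[Colour.Y]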
++
++
++def compute_epoch_metrics(
++    batch_metrics: Dict[int, Dict[int, tf.metrics.Mean]],
++    epoch_metrics: Dict[int, Dict[int, int]],
++) -> None:
++    """
++    Computes the epoch metrics
++    :param batch_metrics: Batch metrics before accumulation
++    :param epoch_metrics: Epoch metrics
++    """
++    for colour in range(Colour.NUM_COLOURS):
++        for metric in range(Metric.NUM_METRICS):
++            epoch_metrics[colour][metric] = batch_metrics[colour][metric].result()
++            batch_metrics[colour][metric].reset_states()
+diff --git a/framework/mpeg_applications/tf_custom/quantisation.py b/framework/mpeg_applications/tf_custom/quantisation.py
+new file mode 100644
+index 0000000..644550e
+--- /dev/null
++++ b/framework/mpeg_applications/tf_custom/quantisation.py
+@@ -0,0 +1,78 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++import deepCABAC
++
++import numpy as np
++from pathlib import Path
++from typing import Dict, Union
++
++from framework.multiplier_model import TensorFlowModel
++
++
++def quantise_overfitted_model(
++    base_model: TensorFlowModel,
++    parameters: Dict[str, np.array],
++    quant_level: int,
++    input_file: Union[str, Path],
++    output_file: Union[str, Path],
++) -> None:
++    """
++    Quantises a floating-point (over-fitted) model to fixed-point
++    :param base_model: Base model
++    :param parameters: Model parameters used (replace the base model weights)
++    :param quant_level: Quantisation level, e.g. 16
++    :param input_file: Original NNR bitstream
++    :param output_file: Output NNR bitstream
++    """
++    quantizers = base_model.restore_and_quantize(parameters, quant_level)
++
++    # Encode quantize info array
++    encoder = deepCABAC.Encoder()
++    encoder.initCtxModels(10, 1)
++    chan_skip_list = np.zeros(quantizers.shape[0], dtype=np.int16)
++    encoder.encodeLayer(quantizers, 0, 0, 1, chan_skip_list)
++    quantizers_bs = bytearray(encoder.finish().tobytes())
++
++    # open nnr bitstream
++    with open(input_file, "rb") as file:
++        nnr_bs = bytearray(file.read())
++
++    # append quantizers_bs to nnr bitstream
++    # first byte is size of quantizers_bs
++    bs = bytearray(len(quantizers_bs).to_bytes(length=1, byteorder="big"))
++    bs.extend(quantizers_bs)
++    bs.extend(nnr_bs)
++
++    # write to file
++    with open(output_file, "wb") as file:
++        file.write(bs)
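++
++
++# Resulting bitstream layout, as written above:
++#
++#     byte 0         : length of the deepCABAC-coded quantiser array
++#     bytes 1..N     : deepCABAC-coded quantiser array (N = byte 0)
++#     bytes N+1..end : original NNR bitstream
++#
++# A hypothetical reader could therefore split the file as follows (sketch only):
++#
++#     with open(output_file, "rb") as f:
++#         data = f.read()
++#     q_len = data[0]
++#     quantiser_payload = data[1 : 1 + q_len]
++#     nnr_payload = data[1 + q_len :]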
+diff --git a/framework/mpeg_applications/utils/icnn_tools.py b/framework/mpeg_applications/utils/icnn_tools.py
+index 61bacd9..aab39a0 100644
+--- a/framework/mpeg_applications/utils/icnn_tools.py
++++ b/framework/mpeg_applications/utils/icnn_tools.py
+@@ -363,13 +363,14 @@ def iter_opt(ref_perf,
+         agg_center_model_it = add(center_model_diff_it_i, center_model_old)
+ 
+         out_val_perf_it_i = model_env.eval_model(agg_center_model_it, bn_folding=inc_bn_folding)
++        # 0 -> MSE, 1 -> PSNR gain, 2 -> PSNR
+ 
+         if crit == 'acc':
+-            val_perf_it_i = out_val_perf_it_i[0]
++            val_perf_it_i = out_val_perf_it_i[1]
+         elif crit == 'loss':
+-            val_perf_it_i = out_val_perf_it_i[2]
++            val_perf_it_i = out_val_perf_it_i[0]
+         elif crit == 'f1':
+-            val_perf_it_i = out_val_perf_it_i[5]
++            val_perf_it_i = out_val_perf_it_i[2]
+ 
+         log_perf += [val_perf_it_i]
+ 
+@@ -391,11 +392,10 @@ def iter_opt(ref_perf,
+                              (hyperparam_it >= lower_bound) and (hyperparam_it <= upper_bound)) or \
+                             (val_perf_opt < ref_perf * 0.997 and (hyperparam_it >= lower_bound) and
+                              (hyperparam_it <= upper_bound))
+-
+         if opt_condition:
+             if direction == 'ascend': # desirable: maximize sparsity and QP
+                 if crit == 'acc' or crit == 'f1':
+-                    opt_update_condition = (val_perf_it_i >= val_perf_opt) or (val_perf_it_i >= ref_perf - 0.1)
++                    opt_update_condition = (val_perf_it_i >= val_perf_opt) or (val_perf_it_i >= ref_perf - 0.005)
+                 elif crit == 'loss':
+                     opt_update_condition = (val_perf_it_i <= val_perf_opt)
+             elif direction == 'descend': # which is usually not desirable
+@@ -432,12 +432,13 @@ if config.OPT_QP():
+     
+         model_cabac = add(prior_model_params, model_param_cabac)
+         out_val_perf = icnn_mdl.eval_model(model_cabac, bn_folding=inc_bn_folding)
++        # 0 -> MSE, 1 -> PSNR gain, 2 -> PSNR
+         if crit == 'acc':
+-            val_perf = out_val_perf[0]
++            val_perf = out_val_perf[1]
+         elif crit == 'loss':
+-            val_perf = out_val_perf[2]
++            val_perf = out_val_perf[0]
+         elif crit == 'f1':
+-            val_perf = out_val_perf[5]
++            val_perf = out_val_perf[2]
+ 
+         print('Initial NCTM difference-updated model perf: {:.3f} with qp={}'.format(val_perf, qp))
+         print('-----------------------------------------------------------------------------------------------')
+diff --git a/framework/multiplier_model/__init__.py b/framework/multiplier_model/__init__.py
+new file mode 100644
+index 0000000..7301ae4
+--- /dev/null
++++ b/framework/multiplier_model/__init__.py
+@@ -0,0 +1,436 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++
++import logging
++from typing import Any, Dict, Tuple, List
++
++from framework.mpeg_applications.tf_custom import Colour
++from framework.mpeg_applications.tf_custom.file_system import check_directory
++from framework.mpeg_applications.tf_custom.image_ops import add_zeros_to_image
++from framework.mpeg_applications.tf_custom.metrics import (
++    compute_loss,
++    compute_psnr_gain,
++)
++
++LOGGER = logging.getLogger()
++import nnc_core
++import tensorflow as tf
++import onnx
++import tf2onnx
++import numpy as np
++from framework.multiplier_model.nnfilter import FilterWithMultipliers
++
++
++class TensorFlowModel(nnc_core.nnr_model.NNRModel):
++    def __init__(self):
++        self.model = None
++        self.__model_info = None
++        self.__model_parameters = None
++        self.dataset = None
++        self.onnx_model = None
++
++    def load_model(self, model_path: str) -> Dict[str, np.array]:
++        """
++        Loads the model
++        :param model_path: Absolute path to the model
++        :return: Parameters of the model
++        """
++        check_directory(model_path)
++        self.model = tf.keras.models.load_model(model_path)
++        org_model, self.__model_info = self.tensorflow_to_nctm()
++        return org_model["parameters"]
++
++    def set_dataset(self, dataset: tf.data.Dataset) -> None:
++        """
++        Assigns the given dataset to be used with the model
++        :param dataset: Dataset
++        """
++        self.dataset = dataset
++
++    def load_from_zoo(
++        self, model_name, dataset_path, batch_size, num_workers, max_iterations
++    ):
++        pass
++
++    @property
++    def model_info(self):
++        return self.__model_info
++
++    @property
++    def model_parameters(self):
++        return self.__model_parameters
++
++    def tensorflow_to_nctm(self) -> Tuple[Dict[str, Any], Dict[str, int]]:
++        """
++        Converts the TensorFlow model to NCTM format
++        """
++        model_data = {"parameters": {}, "reduction_method": "uniform"}
++
++        model_info = {
++            "parameter_type": {},
++            "parameter_dimensions": {},
++            "parameter_index": {},
++            "block_identifier": {},
++            "topology_storage_format": None,
++            "topology_compression_format": nnc_core.nnr_model.TopologyCompressionFormat.NNR_PT_RAW,
++        }
++
++        for idx, weight in enumerate(self.model.variables):
++            if "multiplier" in weight.name:
++                if "_" not in weight.name:
++                    var_name = "m0"
++                else:
++                    var_name = f'm{weight.name.split("/")[0].split("_")[1]}'
++            else:
++                var_name = weight.name
++
++            model_data["parameters"][var_name] = weight.numpy()
++            model_info["parameter_dimensions"][var_name] = weight.shape
++            model_info["parameter_index"][var_name] = idx
++            model_info["parameter_type"][var_name] = (
++                "conv.weight" if "kernel" in var_name else "conv.bias"
++            )
++
++        model_info[
++            "topology_storage_format"
++        ] = nnc_core.nnr_model.TopologyStorageFormat.NNR_TPL_TEF
++
++        return model_data, model_info
++
++    def save_state_dict(self, path, model_data):
++        pass
++
++    def preprocess(self, image, label):
++        pass
++
++    def eval_model(
++        self,
++        parameters: Dict[str, np.array],
++        bn_folding: bool = False,
++        verbose: bool = False,
++    ) -> Tuple[float, float, float]:
++        """
++        Evaluates the model with the given parameters
++        :param parameters: Parameters used to update the model (new weights)
++        :param bn_folding: Enable/disable batch normalisation folding
++        :param verbose: Enable/disable verbosity
++        :return: MSE, PSNR gain and PSNR, averaged over the dataset
++        """
++        Model = self.model
++
++        if len(parameters.keys()) > 0:
++            for weight in Model.variables:
++                if "multiplier" in weight.name:
++                    if "_" not in weight.name:
++                        var_name = "m0"
++                    else:
++                        var_name = f'm{weight.name.split("/")[0].split("_")[1]}'
++                else:
++                    var_name = weight.name
++                if var_name in parameters:
++                    weight.assign(parameters[var_name])
++
++        pad_size = 8
++        block_size = 64
++
++        mse = tf.metrics.Mean()
++        psnr = tf.metrics.Mean()
++        psnr_gain = tf.metrics.Mean()
++
++        for input_data, label_data in self.dataset:
++            input_data = add_zeros_to_image(input_data)
++
++            _, vtm_psnr = compute_loss(
++                label_data,
++                input_data[
++                    :,
++                    pad_size // 2 : pad_size // 2 + block_size,
++                    pad_size // 2 : pad_size // 2 + block_size,
++                    :6,
++                ],
++            )
++
++            prediction = Model(input_data)
++            pred_mse, pred_psnr = compute_loss(label_data, prediction)
++            pred_psnr_gain = compute_psnr_gain(pred_psnr, vtm_psnr)
++
++            mse.update_state(tf.reduce_mean(pred_mse[Colour.YCbCr]))
++            psnr.update_state(tf.reduce_mean(pred_psnr[Colour.YCbCr]))
++            psnr_gain.update_state(tf.reduce_mean(pred_psnr_gain[Colour.YCbCr]))
++
++        mse = mse.result().numpy()
++        psnr = psnr.result().numpy()
++        psnr_gain = psnr_gain.result().numpy()
++
++        del Model
++        return mse, psnr_gain, psnr
++
++    def train_model(self, parameters):
++        pass
++
++    def restore_and_save(
++        self, parameters: Dict[str, np.array], output_dir: str
++    ) -> None:
++        """
++        Restores and saves the model. The input parameters are used to update the model weights
++        :param parameters: weights to be restored
++        :param output_dir: Directory to save the restored model
++        """
++        Model = self.model
++
++        for weight in Model.variables:
++            if "multiplier" in weight.name:
++                if "_" not in weight.name:
++                    var_name = "m0"
++                else:
++                    var_name = f'm{weight.name.split("/")[0].split("_")[1]}'
++            else:
++                var_name = weight.name
++            if var_name in parameters:
++                weight.assign(parameters[var_name])
++
++        tf.saved_model.save(Model, output_dir)
++        del Model
++
++    def restore_model_without_slice(
++        self, parameters: Dict[str, np.array]
++    ) -> FilterWithMultipliers:
++        """
++        Restores a model without the slicing operations
++        :param parameters: weights to be restored
++        :return: Restored model
++        """
++        model = FilterWithMultipliers()
++        for weight in model.variables:
++            if "multiplier" in weight.name:
++                if "_" not in weight.name:
++                    var_name = "m0"
++                else:
++                    var_name = f'm{weight.name.split("/")[0].split("_")[1]}'
++            else:
++                var_name = weight.name
++            if var_name in parameters:
++                weight.assign(parameters[var_name])
++        tf.keras.backend.clear_session()
++        return model
++
++    def restore_and_convert_to_onnx(
++        self, parameters: Dict[str, np.array]
++    ) -> onnx.ModelProto:
++        """
++        Restores and converts the model to ONNX format
++        :param parameters: weights to update model parameters
++        :return: ONNX model
++        """
++        model = self.restore_model_without_slice(parameters)
++        model_onnx, _ = tf2onnx.convert.from_keras(
++            model, [tf.TensorSpec(shape=(1, 72, 72, 10))], opset=13
++        )
++        del model
++        return model_onnx
++
++    def restore_and_quantize(
++        self, parameters: Dict[str, np.array], quant_level: int
++    ) -> np.ndarray:
++        """
++        Restores and quantises an over-fitted model
++        :param parameters: weight-update parameters
++        :param quant_level: Quantisation level
++        :return: Quantisers
++        """
++        model = self.restore_model_without_slice(parameters)
++        quantizers_dict = model.quantize(quant_level, self.dataset, 0)
++        q_list = []
++        internal_list = []
++        for param in quantizers_dict:
++            q_list.append(quant_level - quantizers_dict[param]["quantizer"])
++            if "kernel" in param or "multiplier" in param:
++                internal_list.append(quantizers_dict[param]["internal"])
++        return np.array(q_list + internal_list).astype(np.int32)
++
++
++def parse_quantizers(
++    quantizer_list: np.ndarray, num_layer: int, max_q: int
++) -> Dict[str, Any]:
++    """
++    Parses a list of quantisers
++    :param quantizer_list: Quantisers
++    :param num_layer: Number of layers in the model
++    :param max_q: Maximum quantiser supported
++    :return: Quantisation info
++    """
++    quant_info = {}
++    internal_start = 3 * num_layer
++    input_q = quantizer_list[0]
++    # 0 -> input quantizer, 1:105 -> layer quantizer, 106: -> layer internal
++    quant_info["input0"] = {"quantizer": max_q - input_q, "internal": 0}
++    quantizer_list = quantizer_list[1:]
++    for i in range(num_layer):
++        k_internal = quantizer_list[internal_start + i * 2]
++        m_internal = quantizer_list[internal_start + i * 2 + 1]
++        quant_info[f"conv2d_{i}/kernel"] = {
++            "quantizer": max_q - quantizer_list[i * 3],
++            "internal": k_internal,
++        }
++        quant_info[f"conv2d_{i}/bias"] = {
++            "quantizer": max_q - quantizer_list[i * 3 + 1],
++            "internal": 0,
++        }
++        quant_info[f"multiplier_{i}"] = {
++            "quantizer": max_q - quantizer_list[i * 3 + 2],
++            "internal": m_internal,
++        }
++    quant_info["conv2d/kernel"] = quant_info["conv2d_0/kernel"]
++    quant_info["conv2d/bias"] = quant_info["conv2d_0/bias"]
++    quant_info["multiplier"] = quant_info["multiplier_0"]
++    return quant_info
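++
++# Illustrative layout of quantizer_list for num_layer = 35, matching
++# restore_and_quantize above (indices are an example, not a specification):
++#
++#     quantizer_list[0]       -> input quantiser
++#     quantizer_list[1:106]   -> kernel/bias/multiplier quantisers, 3 per layer
++#     quantizer_list[106:176] -> kernel and multiplier "internal" shifts, 2 per layer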
++
++
++def quantize_array(
++    array: np.array, qtype: np.dtype, quantizer: int, low: int, high: int
++) -> np.ndarray:
++    """
++    Quantises a floating-point array to fixed point: clip(array * 2**quantizer, low, high)
++    """
++    quantized = np.multiply(array, 2**quantizer)
++    quantized = np.clip(quantized, low, high)
++    return quantized.astype(qtype)
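++
++# Example with assumed values: quantize_array(np.array([0.5, -0.25]), np.int16, 10, -32767, 32767)
++# scales by 2**10 and returns array([512, -256], dtype=int16).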
++
++
++def quantize_onnx(
++    onnx_model: onnx.ModelProto, quantizers: List[int], quant_level: int
++) -> onnx.ModelProto:
++    """
++    Quantises an ONNX model
++    :param onnx_model: Model
++    :param quantizers: Quantisers
++    :param quant_level: Quantisation level
++    :return: quantised ONNX model
++    """
++    quant_info = parse_quantizers(quantizers, 35, quant_level)
++    onnx.checker.check_model(onnx_model)
++    model_graph = onnx_model.graph
++    initializers = model_graph.initializer
++
++    HIGHEST = 2 ** (quant_level - 1) - 1
++    LOWEST = -HIGHEST
++
++    if quant_level == 16:
++        data_type = np.int16
++        onnx_type = onnx.TensorProto.INT16
++    elif quant_level == 32:
++        data_type = np.int32
++        onnx_type = onnx.TensorProto.INT32
++    else:
++        raise ValueError("The type is not supported for quantisation")
++
++    # tensor data is in the initializer
++    for initializer in initializers:
++        initializer.data_type = onnx_type
++        layer_name = initializer.name.split("/")[1]
++        if "multiplier" in layer_name:
++            var_name = layer_name
++        elif "Conv2D" in initializer.name:
++            var_name = layer_name + "/kernel"
++        elif "BiasAdd" in initializer.name:
++            var_name = layer_name + "/bias"
++
++        org_array = np.frombuffer(initializer.raw_data, dtype=np.float32)
++        quantizer = quant_info[var_name]["quantizer"]
++
++        initializer.raw_data = quantize_array(
++            org_array, data_type, quantizer, LOWEST, HIGHEST
++        ).tobytes()
++
++    # Add additional info in node (eg. quantizer and internal integer)
++    nodes = model_graph.node
++    for node in nodes:
++        if node.input[0] == "args_0":
++            quantizer_attribute = onnx.helper.make_attribute(
++                "quantizer", quant_info["input0"]["quantizer"]
++            )
++            node.attribute.append(quantizer_attribute)
++            continue
++
++        if node.op_type == "Conv":
++            layer_name = node.output[0].split("/")[1]
++            kernel_name = layer_name + "/kernel"
++            bias_name = layer_name + "/bias"
++            q_attr = onnx.helper.make_attribute(
++                "quantizer",
++                [
++                    quant_info[kernel_name]["quantizer"],
++                    quant_info[bias_name]["quantizer"],
++                ],
++            )
++            i_attr = onnx.helper.make_attribute(
++                "internal", quant_info[kernel_name]["internal"]
++            )
++            node.attribute.append(q_attr)
++            node.attribute.append(i_attr)
++            continue
++
++        if node.op_type == "Mul":
++            layer_name = node.output[0].split("/")[1]
++            var_name = layer_name
++            q_attr = onnx.helper.make_attribute(
++                "quantizer", quant_info[var_name]["quantizer"]
++            )
++            i_attr = onnx.helper.make_attribute(
++                "internal", quant_info[var_name]["internal"]
++            )
++            node.attribute.append(q_attr)
++            node.attribute.append(i_attr)
++            continue
++
++        if node.op_type == "LeakyRelu":
++            layer_name = node.output[0].split("/")[1]
++            var_name = layer_name
++            for attri in node.attribute:
++                if attri.name == "alpha":
++                    alpha_f = attri.f
++                    scale = float(HIGHEST) / abs(alpha_f)
++                    quantizer = int(np.log2(scale))
++                    alpha_q = data_type(alpha_f * 2**quantizer)
++                    attri.f = alpha_q
++            q_attr = onnx.helper.make_attribute("quantizer", quantizer)
++            node.attribute.append(q_attr)
++            continue
++
++    # change data type of input and output
++    inputs = model_graph.input
++    for graph_input in inputs:
++        graph_input.type.tensor_type.elem_type = onnx_type
++
++    outputs = model_graph.output
++    for graph_output in outputs:
++        graph_output.type.tensor_type.elem_type = onnx_type
++
++    return onnx_model
+diff --git a/framework/multiplier_model/nnfilter.py b/framework/multiplier_model/nnfilter.py
+new file mode 100644
+index 0000000..0db3b02
+--- /dev/null
++++ b/framework/multiplier_model/nnfilter.py
+@@ -0,0 +1,1321 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++from typing import Any, Dict, List, Union
++
++import numpy as np
++import tensorflow as tf
++from tensorflow.keras import Model
++from tensorflow.keras.layers import Conv2D, InputLayer, Layer, LeakyReLU
++
++from framework.mpeg_applications.tf_custom.image_ops import add_zeros_to_image
++
++
++class Multiplier(Layer):
++    """
++    Multiplier layer. It is initialised with ones
++    """
++
++    def __init__(self, units=1, dtype=tf.float32):
++        super(Multiplier, self).__init__()
++        self.units = units
++        self._multiplier = None
++        self.type = dtype
++
++    def build(self, input_shape):
++        self._multiplier = tf.Variable(
++            initial_value=tf.ones_initializer()(shape=(self.units,), dtype=self.type),
++            trainable=False,
++        )
++
++    def call(self, inputs, **kwargs):
++        return tf.math.multiply(inputs, self._multiplier)
++
++    def get_config(self):
++        config = super(Multiplier, self).get_config()
++        config.update({"units": self.units})
++        return config
++
++
++class FilterWithMultipliers(Model):
++    """
++    Qualcomm's NN filter with multiplier layers and with the slicing layers removed.
++    Let W be the conv kernel, b the bias terms and c the multipliers. Each layer is applied in the following order:
++    (W * x + b) * c
++    """
++
++    def __init__(self, dtype=tf.float32, **kwargs):
++        super(FilterWithMultipliers, self).__init__(dtype=dtype)
++        self.built = True
++        self._input = InputLayer(input_shape=[72, 72, 10], dtype=dtype)
++        self._conv1 = Conv2D(
++            filters=72,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier1 = Multiplier(units=72, dtype=dtype)
++        self._leaky1 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv2 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier2 = Multiplier(units=72, dtype=dtype)
++        self._leaky2 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv3 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier3 = Multiplier(units=24, dtype=dtype)
++        self._conv4 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier4 = Multiplier(units=24, dtype=dtype)
++        self._conv5 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier5 = Multiplier(units=72, dtype=dtype)
++        self._leaky3 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv6 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier6 = Multiplier(units=24, dtype=dtype)
++        self._conv7 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier7 = Multiplier(units=24, dtype=dtype)
++        self._conv8 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier8 = Multiplier(units=72, dtype=dtype)
++        self._leaky4 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv9 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier9 = Multiplier(units=24, dtype=dtype)
++        self._conv10 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier10 = Multiplier(units=24, dtype=dtype)
++        self._conv11 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier11 = Multiplier(units=72, dtype=dtype)
++        self._leaky5 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv12 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier12 = Multiplier(units=24, dtype=dtype)
++        self._conv13 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier13 = Multiplier(units=24, dtype=dtype)
++        self._conv14 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier14 = Multiplier(units=72, dtype=dtype)
++        self._leaky6 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv15 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier15 = Multiplier(units=24, dtype=dtype)
++        self._conv16 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier16 = Multiplier(units=24, dtype=dtype)
++        self._conv17 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier17 = Multiplier(units=72, dtype=dtype)
++        self._leaky7 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv18 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier18 = Multiplier(units=24, dtype=dtype)
++        self._conv19 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier19 = Multiplier(units=24, dtype=dtype)
++        self._conv20 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier20 = Multiplier(units=72, dtype=dtype)
++        self._leaky8 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv21 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier21 = Multiplier(units=24, dtype=dtype)
++        self._conv22 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier22 = Multiplier(units=24, dtype=dtype)
++        self._conv23 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier23 = Multiplier(units=72, dtype=dtype)
++        self._leaky9 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv24 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier24 = Multiplier(units=24, dtype=dtype)
++        self._conv25 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier25 = Multiplier(units=24, dtype=dtype)
++        self._conv26 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier26 = Multiplier(units=72, dtype=dtype)
++        self._leaky10 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv27 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier27 = Multiplier(units=24, dtype=dtype)
++        self._conv28 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier28 = Multiplier(units=24, dtype=dtype)
++        self._conv29 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier29 = Multiplier(units=72, dtype=dtype)
++        self._leaky11 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv30 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier30 = Multiplier(units=24, dtype=dtype)
++        self._conv31 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier31 = Multiplier(units=24, dtype=dtype)
++        self._conv32 = Conv2D(
++            filters=72,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier32 = Multiplier(units=72, dtype=dtype)
++        self._leaky12 = LeakyReLU(alpha=tf.cast(0.2, dtype).numpy(), dtype=dtype)
++        self._conv33 = Conv2D(
++            filters=24,
++            kernel_size=(1, 1),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier33 = Multiplier(units=24, dtype=dtype)
++        self._conv34 = Conv2D(
++            filters=24,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier34 = Multiplier(units=24, dtype=dtype)
++        self._conv35 = Conv2D(
++            filters=6,
++            kernel_size=(3, 3),
++            strides=(1, 1),
++            padding="same",
++            dilation_rate=(1, 1),
++            dtype=dtype,
++        )
++        self._multiplier35 = Multiplier(units=6, dtype=dtype)
++        if dtype == tf.float32:
++            self.build([None, 72, 72, 10])
++
++    def load_weights(self, base_model_dir: str = None) -> None:
++        """
++        Loads weights from the pretrained model
++        :param base_model_dir: Absolute path to the base model
++        """
++
++        trained_model = tf.keras.models.load_model(base_model_dir)
++        for org_var, new_var in zip(trained_model.variables, self.variables):
++            assert org_var.name == new_var.name, "The models have different variables"
++            new_var.assign(org_var)
++
++    def call(self, input_1):
++        """
++        Applies CNN
++        :param input_1: Patches of size 72x72x10
++        :return: filtered patches
++        """
++        x = self._input(input_1)
++
++        y = self._conv1(x)
++        y = self._multiplier1(y)
++        y = self._leaky1(y)
++
++        y = self._conv2(y)
++        y = self._multiplier2(y)
++        y = self._leaky2(y)
++
++        y = self._conv3(y)
++        y = self._multiplier3(y)
++
++        y = self._conv4(y)
++        y = self._multiplier4(y)
++
++        y = self._conv5(y)
++        y = self._multiplier5(y)
++        y = self._leaky3(y)
++
++        y = self._conv6(y)
++        y = self._multiplier6(y)
++
++        y = self._conv7(y)
++        y = self._multiplier7(y)
++
++        y = self._conv8(y)
++        y = self._multiplier8(y)
++        y = self._leaky4(y)
++
++        y = self._conv9(y)
++        y = self._multiplier9(y)
++
++        y = self._conv10(y)
++        y = self._multiplier10(y)
++
++        y = self._conv11(y)
++        y = self._multiplier11(y)
++        y = self._leaky5(y)
++
++        y = self._conv12(y)
++        y = self._multiplier12(y)
++
++        y = self._conv13(y)
++        y = self._multiplier13(y)
++
++        y = self._conv14(y)
++        y = self._multiplier14(y)
++        y = self._leaky6(y)
++
++        y = self._conv15(y)
++        y = self._multiplier15(y)
++
++        y = self._conv16(y)
++        y = self._multiplier16(y)
++
++        y = self._conv17(y)
++        y = self._multiplier17(y)
++        y = self._leaky7(y)
++
++        y = self._conv18(y)
++        y = self._multiplier18(y)
++
++        y = self._conv19(y)
++        y = self._multiplier19(y)
++
++        y = self._conv20(y)
++        y = self._multiplier20(y)
++        y = self._leaky8(y)
++
++        y = self._conv21(y)
++        y = self._multiplier21(y)
++
++        y = self._conv22(y)
++        y = self._multiplier22(y)
++
++        y = self._conv23(y)
++        y = self._multiplier23(y)
++        y = self._leaky9(y)
++
++        y = self._conv24(y)
++        y = self._multiplier24(y)
++
++        y = self._conv25(y)
++        y = self._multiplier25(y)
++
++        y = self._conv26(y)
++        y = self._multiplier26(y)
++        y = self._leaky10(y)
++
++        y = self._conv27(y)
++        y = self._multiplier27(y)
++
++        y = self._conv28(y)
++        y = self._multiplier28(y)
++
++        y = self._conv29(y)
++        y = self._multiplier29(y)
++        y = self._leaky11(y)
++
++        y = self._conv30(y)
++        y = self._multiplier30(y)
++
++        y = self._conv31(y)
++        y = self._multiplier31(y)
++
++        y = self._conv32(y)
++        y = self._multiplier32(y)
++        y = self._leaky12(y)
++
++        y = self._conv33(y)
++        y = self._multiplier33(y)
++
++        y = self._conv34(y)
++        y = self._multiplier34(y)
++
++        y = self._conv35(y)
++        y = self._multiplier35(y)
++        return y
++
++    def get_config(self):
++        config = super(FilterWithMultipliers, self).get_config()
++        config.update({"name": self.name})
++        return config
++
++    @classmethod
++    def from_config(cls, config, custom_objects=None):
++        return cls(**config)
++
++    def compute_quantizer(self, tensor_max: tf.Tensor, quantize_max: int) -> int:
++        """
++        Computes the quantiser based on the maximum value
++        :param tensor_max: Maximum absolute value in the tensor
++        :param quantize_max: Maximum absolute value supported
++        :return: quantiser
++        """
++        scale = float(quantize_max) / tensor_max
++        quantizer = tf.math.log(scale) / tf.math.log(2.0)
++        return int(quantizer)
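++
++    # Example with assumed values: for tensor_max = 3.0 and quantize_max = 32767,
++    # scale = 32767 / 3.0 (about 10922), so the returned quantiser is
++    # int(log2(10922)) = 13, the largest shift keeping 3.0 * 2**13 within 16 bits.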
++
++    def quantize(
++        self, quant_level: int, dataset: tf.data.Dataset = None, scale: int = 0
++    ) -> Dict[str, Any]:
++        """
++        Quantises the model
++        :param quant_level: Quantisation level
++        :param dataset: Dataset
++        :param scale: Scale
++        :return: Quantisation information
++        """
++        highest_val = 2 ** (quant_level - 1) - 1
++        computation_highest_val = 2 ** (2 * quant_level - 1) - 1
++        max_q = quant_level - 2
++
++        # Simple Quantization
++        quantize_info = {
++            "input0": {"org_array": None, "quantizer": max_q, "internal": 0}
++        }
++
++        for var in self.variables:
++            org_array = tf.convert_to_tensor(var.numpy(), dtype=var.dtype)
++            abs_highest = tf.reduce_max(tf.abs(org_array))
++            quantizer = min(self.compute_quantizer(abs_highest, highest_val), max_q)
++            if "bias" in var.name or "kernel" in var.name:
++                # conv2d/bias:0 -> conv2d/bias, conv2d_1/kernel:0 -> conv2d_1/kernel, ...
++                var_name = var.name.split(":")[0]
++            if "multiplier" in var.name:
++                # multiplier/Variable:0 -> multiplier, multiplier_1/Variable:0 -> multiplier_1, ...
++                var_name = var.name.split("/")[0]
++
++            quantize_info[var_name] = {
++                "org_array": var.numpy(),
++                "quantizer": quantizer,
++                "internal": 0,
++            }
++
++        if dataset is not None:
++            quantize_info = self.quantize_calibration(
++                quantize_info,
++                computation_highest_val,
++                highest_val,
++                max_q,
++                scale,
++                dataset,
++            )
++
++        return quantize_info
++
++    def quantize_calibration(
++        self,
++        quantize_info: Dict,
++        computation_highest: int,
++        value_highest: int,
++        max_q: int,
++        scale: int,
++        dataset: tf.data.Dataset,
++    ) -> Dict[str, Any]:
++        """
++        Calibrates the quantisation
++        :param quantize_info: Quantisation information
++        :param computation_highest: Highest accumulated value
++        :param value_highest: Highest representable value at the given quantisation level
++        :param max_q: Maximum supported quantiser
++        :param scale: Scale
++        :param dataset: Calibration dataset
++        :return: Quantisation info
++        """
++        max_values_per_layer = []
++        for input_data, _ in dataset:
++            input_data = add_zeros_to_image(input_data)
++            max_abs_outputs = self.get_max_out(input_data, quantize_info)
++            max_values_per_layer.append(max_abs_outputs)
++
++        max_values = np.amax(max_values_per_layer, axis=0)
++        # quantizer for input
++        max_input = max_values[0]
++        input_s = float(value_highest) / max_input
++        input_q = int(np.log2(input_s))
++        input_q = min(input_q, max_q)
++        quantize_info["input0"]["quantizer"] = input_q
++
++        current_q = input_q
++        layer_idx = 0
++        # quantizer for kernel, bias and multiplier
++        for idx in range(1, len(max_values), 4):
++            if layer_idx == 0:
++                kernel_name = "conv2d/kernel"
++                bias_name = "conv2d/bias"
++                mul_name = "multiplier"
++            else:
++                kernel_name = f"conv2d_{layer_idx}/kernel"
++                bias_name = f"conv2d_{layer_idx}/bias"
++                mul_name = f"multiplier_{layer_idx}"
++
++            # conv2d
++            max_conv_middle = max_values[idx]
++            kernel_max_q = (
++                int(np.log2(float(computation_highest) / max_conv_middle))
++                - current_q
++                - scale
++            )
++            kernel_max_q = min(kernel_max_q, quantize_info[kernel_name]["quantizer"])
++            max_conv = max_values[idx + 1]
++            output_q_conv = min(
++                int(np.log2(float(value_highest) / max_conv)) - scale, max_q
++            )
++            quantize_info[kernel_name]["quantizer"] = kernel_max_q
++            quantize_info[kernel_name]["internal"] = max(current_q - output_q_conv, 0)
++            current_q = current_q - quantize_info[kernel_name]["internal"]
++
++            # bias
++            max_bias = max_values[idx + 2]
++            output_q = min(
++                int(np.log2(float(value_highest) / max_bias)) - scale,
++                max_q,
++                quantize_info[bias_name]["quantizer"],
++            )
++            quantize_info[bias_name]["quantizer"] = output_q
++            current_q = output_q
++
++            # multiplier
++            max_mul = max_values[idx + 3]
++            output_q = min(int(np.log2(float(value_highest) / max_mul)) - scale, max_q)
++            quantize_info[mul_name]["internal"] = max(current_q - output_q, 0)
++            current_q = current_q - quantize_info[mul_name]["internal"]
++
++            layer_idx += 1
++
++        return quantize_info
++
++    def get_max_out(
++        self, input_data: tf.Tensor, quantize_info: Dict[str, Any]
++    ) -> List[Union[float, tf.Tensor]]:
++        """
++        Computes the maximum absolute value output for each layer in the NN
++        :param input_data: Input data (batch)
++        :param quantize_info: Quantisation information
++        :return: List of maximum absolute values per layer
++        """
++        max_outputs = []
++        x = self._input(input_data)
++
++        max_outputs.append(tf.reduce_max(tf.abs(x)).numpy())
++
++        abs_conv_mean = tf.reduce_mean(
++            tf.abs(quantize_info["conv2d/kernel"]["org_array"])
++        )
++        input_c = float(tf.shape(x)[3])
++        kernel_w = float(tf.shape(quantize_info["conv2d/kernel"]["org_array"])[0])
++        # max value for computation
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++
++        y = self._conv1(x)
++        # max value after conv
++        max_outputs.append(
++            tf.reduce_max(tf.abs(y - quantize_info["conv2d/bias"]["org_array"])).numpy()
++        )
++        # max value after bias
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier1(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky1(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_1/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv2(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_1/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier2(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky2(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_2/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv3(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_2/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier3(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_3/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv4(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_3/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier4(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_4/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv5(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_4/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier5(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky3(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_5/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv6(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_5/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier6(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_6/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv7(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_6/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier7(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_7/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv8(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_7/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier8(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky4(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_8/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv9(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_8/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier9(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_9/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv10(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_9/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier10(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_10/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv11(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_10/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier11(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky5(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_11/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv12(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_11/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier12(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_12/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv13(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_12/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier13(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_13/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv14(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_13/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier14(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky6(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_14/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv15(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_14/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier15(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_15/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv16(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_15/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier16(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_16/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv17(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_16/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier17(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky7(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_17/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv18(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_17/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier18(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_18/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv19(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_18/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier19(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_19/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv20(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_19/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier20(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky8(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_20/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv21(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_20/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier21(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_21/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv22(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_21/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier22(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_22/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv23(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_22/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier23(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky9(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_23/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv24(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_23/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier24(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_24/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv25(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_24/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier25(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_25/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv26(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_25/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier26(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky10(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_26/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv27(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_26/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier27(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_27/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv28(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_27/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier28(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_28/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv29(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_28/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier29(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky11(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_29/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv30(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_29/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier30(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_30/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv31(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_30/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier31(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_31/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv32(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_31/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier32(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._leaky12(y)
++
++        abs_conv = tf.abs(quantize_info["conv2d_32/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv33(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_32/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier33(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_33/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv34(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_33/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier34(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        abs_conv = tf.abs(quantize_info["conv2d_34/kernel"]["org_array"])
++        abs_conv_mean = tf.reduce_mean(abs_conv)
++        input_c = float(tf.shape(y)[3])
++        kernel_w = float(tf.shape(abs_conv)[0])
++        max_outputs.append(
++            abs_conv_mean * max_outputs[-1] * input_c * kernel_w * kernel_w
++        )
++        y = self._conv35(y)
++        max_outputs.append(
++            tf.reduce_max(
++                tf.abs(y - quantize_info["conv2d_34/bias"]["org_array"])
++            ).numpy()
++        )
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++        y = self._multiplier35(y)
++        max_outputs.append(tf.reduce_max(tf.abs(y)).numpy())
++
++        return max_outputs
+diff --git a/nnc_core/__init__.py b/nnc_core/__init__.py
+index 63c3348..e6f1fc3 100644
+--- a/nnc_core/__init__.py
++++ b/nnc_core/__init__.py
+@@ -70,21 +70,21 @@ def approx_and_enc(
+     ):
+ 
+     # 5 - approximating
+-    with TimerMemoryICNN(tml, epoch, client, "app", "[APP] {} approximator runtime (s)".format(ap_info.approx_info["approx_method"]), scope="{}_approximator".format(ap_info.approx_info["approx_method"])) as t:
+-
+-        approx_data_enc = approximator.approx(
+-            ap_info.approx_info,
+-            model_info,
+-            approx_data,
+-            n_epochs,
+-            epoch,
+-            sbt_args,
+-            num_workers=num_workers
+-        )
++    # with TimerMemoryICNN(tml, epoch, client, "app", "[APP] {} approximator runtime (s)".format(ap_info.approx_info["approx_method"]), scope="{}_approximator".format(ap_info.approx_info["approx_method"])) as t:
++
++    approx_data_enc = approximator.approx(
++        ap_info.approx_info,
++        model_info,
++        approx_data,
++        n_epochs,
++        epoch,
++        sbt_args,
++        num_workers=num_workers
++    )
+ 
+     # 6 - encode the model
+-    with TimerMemoryICNN(tml, epoch, client, "enc", "[ENC] baseline encoder runtime (s)", scope="baseline_encoder") as t:
+-        bitstream      = coder.encode(enc_info, model_info, approx_data_enc, approx_param_base)
++    # with TimerMemoryICNN(tml, epoch, client, "enc", "[ENC] baseline encoder runtime (s)", scope="baseline_encoder") as t:
++    bitstream      = coder.encode(enc_info, model_info, approx_data_enc, approx_param_base)
+ 
+     # Do reconstruction at encoder side!
+     enc_rec_approx_data = copy.deepcopy(approx_data_enc)
+@@ -117,19 +117,19 @@ def dec_and_rec(
+                       'topology_storage_format' : None }
+ 
+     # 7 - decode the model
+-    with TimerMemoryICNN(tml, epoch, client, "dec", "[DEC] baseline decoder runtime (s)", scope="baseline_decoder") as t:
+-        hls_bytes = {}
+-        with open(bitstream_path, "rb") as _file:
+-            bitstream = bytearray(_file.read())
+-        bs_size = len(bitstream)
+-        dec_approx_data = coder.decode(bitstream, dec_model_info, hls_bytes, dec_approx_param_base, update_base_param)
++    # with TimerMemoryICNN(tml, epoch, client, "dec", "[DEC] baseline decoder runtime (s)", scope="baseline_decoder") as t:
++    hls_bytes = {}
++    with open(bitstream_path, "rb") as _file:
++        bitstream = bytearray(_file.read())
++    bs_size = len(bitstream)
++    dec_approx_data = coder.decode(bitstream, dec_model_info, hls_bytes, dec_approx_param_base, update_base_param)
+ 
+     hls_bytes = hls_bytes["mps_bytes"] + sum(hls_bytes["ndu_bytes"])
+ 
+     rec_approx_data = dec_approx_data
+     # 8 - reconstruction
+-    with TimerMemoryICNN(tml, epoch, client, "rec", "[REC] reconstructor runtime (s)", scope="reconstructor") as t:
+-        approximator.rec(rec_approx_data, dec_model_info )
++    # with TimerMemoryICNN(tml, epoch, client, "rec", "[REC] reconstructor runtime (s)", scope="reconstructor") as t:
++    approximator.rec(rec_approx_data, dec_model_info )
+ 
+     # size of the reconstructed model as number of parameters times 32 bits
+     org_size = 0
+diff --git a/run/convert_and_quantise_base_model.py b/run/convert_and_quantise_base_model.py
+new file mode 100644
+index 0000000..e27efad
+--- /dev/null
++++ b/run/convert_and_quantise_base_model.py
+@@ -0,0 +1,121 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++
++import sys
++
++sys.path.insert(0, "")
++
++from framework.multiplier_model import TensorFlowModel
++import click
++from pathlib import Path
++import sys
++import tensorflow as tf
++from framework.multiplier_model import quantize_onnx
++from onnx2sadl import convert_to_sadl
++from framework.mpeg_applications.tf_custom.dataset import Dataset
++
++
++@click.command()
++@click.option("--base_model_dir", default=None, type=click.Path(), help="Base model")
++@click.option(
++    "--dataset_dir", default=None, type=click.Path(), help="Dataset directory"
++)
++@click.option(
++    "--properties_file",
++    default=None,
++    type=click.Path(),
++    help="JSON file with sequence properties",
++)
++@click.option("--bit_depth", default=10, type=int, help="Sequence bit depth")
++@click.option(
++    "--block_size", default=64, type=int, help="Block/patch size of the label"
++)
++@click.option(
++    "--pad_size", default=8, type=int, help="Padding size (sum of left and right)"
++)
++@click.option("--batch_size", default=64, type=int, help="Batch size")
++@click.option(
++    "--output_dir", default=None, type=click.Path(), help="Path to the output directory"
++)
++def quantise_and_convert_to_sadl(
++    base_model_dir,
++    dataset_dir,
++    properties_file,
++    bit_depth,
++    block_size,
++    pad_size,
++    batch_size,
++    output_dir,
++):
++    base_model = TensorFlowModel()
++    base_model_params = base_model.load_model(str(base_model_dir))
++    model_onnx = base_model.restore_and_convert_to_onnx(base_model_params)
++
++    output_dir = Path(output_dir)
++    output_dir.mkdir(exist_ok=True)
++    output_file = output_dir / f"{Path(base_model_dir).name}_float.sadl"
++    convert_to_sadl(model_onnx, output_file)
++
++    quant_level = 16
++    dataset = Dataset(
++        dataset_dir,
++        properties_file,
++        list(),
++        bit_depth,
++        block_size,
++        pad_size,
++        False,
++        None,
++        False,
++        0,
++        0,
++        True,
++        10,
++        False,
++        False,
++    )
++    dataset = dataset.create(None, batch_size)
++    base_model.set_dataset(dataset)
++    quantizers = base_model.restore_and_quantize(base_model_params, quant_level)
++    model_onnx = quantize_onnx(model_onnx, quantizers, quant_level)
++    output_file = output_dir / f"{Path(base_model_dir).name}_int16.sadl"
++    convert_to_sadl(model_onnx, output_file)
++
++
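++# Example invocation (illustrative only; paths are placeholders, not from the repository):
++#   python run/convert_and_quantise_base_model.py \
++#       --base_model_dir <saved_base_model_dir> \
++#       --dataset_dir <dataset_dir> --properties_file <properties.json> \
++#       --output_dir <output_dir>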
++if __name__ == "__main__":
++    gpus = tf.config.list_physical_devices("GPU")
++    for gpu in gpus:
++        tf.config.experimental.set_memory_growth(gpu, True)
++
++    quantise_and_convert_to_sadl()
++    sys.exit()
+diff --git a/run/decode_multipliers.py b/run/decode_multipliers.py
+new file mode 100644
+index 0000000..8e095ff
+--- /dev/null
++++ b/run/decode_multipliers.py
+@@ -0,0 +1,135 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++
++import sys
++
++sys.path.insert(0, "")
++
++from framework.mpeg_applications.tf_custom.file_system import check_file
++from framework.mpeg_applications.utils import icnn_tools
++from framework.multiplier_model import TensorFlowModel, quantize_onnx
++from onnx2sadl import convert_to_sadl
++import nnc_core
++
++import click
++import sys
++import tensorflow as tf
++import deepCABAC
++import numpy as np
++import tempfile
++
++
++@click.command()
++@click.option("--base_model_dir", default=None, type=click.Path(), help="Base model")
++@click.option("--nnr_bitstream", default=None, type=click.Path(), help="NNR bitstream")
++@click.option(
++    "--quantise",
++    default=False,
++    type=bool,
++    is_flag=True,
++    help="Quantise the model",
++)
++@click.option(
++    "--output_file", default=None, type=click.Path(), help="Path of Output model"
++)
++def decompress_weight_update(base_model_dir, nnr_bitstream, quantise, output_file):
++    diff_dec_approx_params = {"parameters": {}, "put_node_depth": {}}
++
++    base_model = TensorFlowModel()
++    base_model_params = base_model.load_model(base_model_dir)
++
++    check_file(nnr_bitstream)
++
++    if quantise:
++        # decode quantisation parameters
++        with open(nnr_bitstream, "rb") as file:
++            bs = bytearray(file.read())
++        quantize_info_bs = bs[1 : 1 + bs[0]]
++        decoder = deepCABAC.Decoder()
++        decoder.initCtxModels(10)
++        decoder.setStream(quantize_info_bs)
++        # 176 is the number of quantisation parameters
++        quantizers = np.zeros(176, dtype=np.int32)
++        decoder.decodeLayer(quantizers, 0, 0, 1)
++
++        nnr_bs = bs[1 + bs[0] :]
++        with tempfile.NamedTemporaryFile("wb") as tmp_file:
++            tmp_name = tmp_file.name
++            tmp_file.write(nnr_bs)
++            tmp_file.flush()
++
++            # decode weight updates
++            diff_rec_approx_data, bs_size, dec_model_info, res = nnc_core.dec_and_rec(
++                tmp_name,
++                True,
++                tml=None,
++                epoch=1,
++                client=0,
++                parameter_index=base_model.model_info["parameter_index"],
++                parameter_dimensions=base_model.model_info["parameter_dimensions"],
++                dec_approx_param_base=diff_dec_approx_params,
++                update_base_param=True,
++            )
++
++        restored_params = icnn_tools.add(
++            base_model_params, diff_rec_approx_data["parameters"]
++        )
++        model_onnx = base_model.restore_and_convert_to_onnx(restored_params)
++
++        quant_level = 16
++        model_onnx = quantize_onnx(model_onnx, quantizers, quant_level)
++        convert_to_sadl(model_onnx, output_file)
++    else:
++        # decode weight updates
++        diff_rec_approx_data, bs_size, dec_model_info, res = nnc_core.dec_and_rec(
++            nnr_bitstream,
++            True,
++            tml=None,
++            epoch=1,
++            client=0,
++            parameter_index=base_model.model_info["parameter_index"],
++            parameter_dimensions=base_model.model_info["parameter_dimensions"],
++            dec_approx_param_base=diff_dec_approx_params,
++            update_base_param=True,
++        )
++        restored_params = icnn_tools.add(
++            base_model_params, diff_rec_approx_data["parameters"]
++        )
++        model_onnx = base_model.restore_and_convert_to_onnx(restored_params)
++        convert_to_sadl(model_onnx, output_file)
++
++
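++# Example invocation (illustrative only; paths are placeholders, not from the repository);
++# --quantise additionally decodes the quantisation parameters and emits an int16 model:
++#   python run/decode_multipliers.py --base_model_dir <saved_base_model_dir> \
++#       --nnr_bitstream <bitstream.nnr> --quantise --output_file <model_int16.sadl>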
++if __name__ == "__main__":
++    gpus = tf.config.list_physical_devices("GPU")
++    for gpu in gpus:
++        tf.config.experimental.set_memory_growth(gpu, True)
++
++    decompress_weight_update()
++    sys.exit()
+diff --git a/run/encode_multipliers.py b/run/encode_multipliers.py
+new file mode 100644
+index 0000000..430e455
+--- /dev/null
++++ b/run/encode_multipliers.py
+@@ -0,0 +1,298 @@
++# The copyright in this software is being made available under the BSD
++# License, included below. This software may be subject to other third party
++# and contributor rights, including patent rights, and no such rights are
++# granted under this license.
++#
++# Copyright (c) 2010-2022, ITU/ISO/IEC
++# All rights reserved.
++#
++# Redistribution and use in source and binary forms, with or without
++# modification, are permitted provided that the following conditions are met:
++#  * Redistributions of source code must retain the above copyright notice,
++#    this list of conditions and the following disclaimer.
++#  * Redistributions in binary form must reproduce the above copyright notice,
++#    this list of conditions and the following disclaimer in the documentation
++#    and/or other materials provided with the distribution.
++#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++#    be used to endorse or promote products derived from this software without
++#    specific prior written permission.
++#
++# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++# THE POSSIBILITY OF SUCH DAMAGE.
++#
++
++import sys
++from pathlib import Path
++
++sys.path.insert(0, "")
++
++import config
++from framework.mpeg_applications.utils import icnn_tools
++from framework.multiplier_model import TensorFlowModel
++import nnc_core
++
++from framework.mpeg_applications.tf_custom.dataset import Dataset
++from framework.mpeg_applications.tf_custom.quantisation import quantise_overfitted_model
++
++import click
++import numpy as np
++import tensorflow as tf
++from typing import Dict, Tuple
++
++
++def create_enc_info() -> Dict[str, int]:
++    """
++    Creates encoding configuration
++    :return: Dictionary that contains the DeepCABAC config parameters
++    """
++    param_opt = True
++    temporal_context = False
++
++    info = {
++        "cabac_unary_length_minus1": 10,
++        "param_opt_flag": param_opt,
++        "partial_data_counter": 0,
++    }
++
++    if config.PUT_SYNTAX():
++        info["node_id_present_flag"] = 1
++        info["device_id"] = 0
++        info["parent_node_id_present_flag"] = 1
++        info["parent_node_id_type"] = nnc_core.hls.ParentNodeIdType.ICNN_NDU_ID
++        info["parent_device_id"] = 0
++
++    if config.TEMPORAL_CONTEXT():
++        info["temporal_context_modeling_flag"] = 1 if temporal_context else 0
++
++    return info
++
++
++def get_model_data_info(
++    base_model_dir: str, overfitted_model_dir: str
++) -> Tuple[
++    TensorFlowModel,
++    Dict[str, tf.Tensor],
++    TensorFlowModel,
++    Dict[str, tf.Tensor],
++]:
++    """
++    Gets the model and its parameters for both the base model and the over-fitted model
++    :param base_model_dir: Path to the base model
++    :param overfitted_model_dir: Path to the over-fitted model
++    :return: Base model, over-fitted model and their parameters
++    """
++    base_model = TensorFlowModel()
++    base_model_params = base_model.load_model(str(base_model_dir))
++
++    overfitted_model = TensorFlowModel()
++    overfitted_model_params = overfitted_model.load_model(str(overfitted_model_dir))
++
++    return base_model, base_model_params, overfitted_model, overfitted_model_params
++
++
++@click.command()
++@click.option(
++    "--dataset_dir", default=None, type=click.Path(), help="Dataset directory"
++)
++@click.option(
++    "--properties_file",
++    default=None,
++    type=click.Path(),
++    help="JSON file with sequence properties",
++)
++@click.option("--seq_name", default=None, type=str, help="Sequence name")
++@click.option("--seq_qp", default=list(), multiple=True, help="Sequence QP [VTM]")
++@click.option("--bit_depth", default=10, type=int, help="Sequence bit depth")
++@click.option(
++    "--block_size", default=64, type=int, help="Block/patch size of the label"
++)
++@click.option(
++    "--pad_size", default=8, type=int, help="Padding size (sum of left and right)"
++)
++@click.option("--batch_size", default=64, type=int, help="Batch size")
++@click.option("--base_model_dir", default=None, type=click.Path(), help="Base model")
++@click.option(
++    "--overfitted_model_dir", default=None, type=click.Path(), help="Overfitted model"
++)
++@click.option(
++    "--output_dir",
++    default=None,
++    type=click.Path(),
++    help="Output directory",
++)
++@click.option(
++    "--cache_dataset",
++    default=False,
++    type=bool,
++    is_flag=True,
++    help="Cache the dataset in RAM",
++)
++def compress_weight_update(
++    dataset_dir,
++    properties_file,
++    seq_name,
++    seq_qp,
++    bit_depth,
++    block_size,
++    pad_size,
++    batch_size,
++    base_model_dir,
++    overfitted_model_dir,
++    output_dir,
++    cache_dataset,
++):
++    enc_info = create_enc_info()
++
++    qp_density = 2
++
++    approx_method = "uniform"
++    nnr_qp = -40
++    opt_qp = True
++    disable_dq = True
++    lambda_scale = 0.0
++    cb_size_ratio = 5000
++    q_mse = 0.00001
++    inc_bn_folding = False
++
++    qp_step = 1
++    bias = 0.005
++
++    (
++        base_model,
++        base_model_params,
++        overfitted_model,
++        overfitted_model_params,
++    ) = get_model_data_info(base_model_dir, overfitted_model_dir)
++
++    approx_data = overfitted_model.init_approx_data(
++        base_model_params, qp_density, scan_order=0
++    )
++
++    approx_info = nnc_core.nnr_model.ApproxInfo(
++        approx_data,
++        base_model.model_info,
++        approx_method,
++        nnr_qp,
++        opt_qp,
++        disable_dq,
++        lambda_scale,
++        cb_size_ratio,
++        q_mse,
++    )
++
++    nnr_qp = np.int32(nnr_qp)
++    diff_qp = nnr_qp
++
++    model_param_diff = icnn_tools.model_diff(overfitted_model_params, base_model_params)
++
++    diff_dec_approx_params = {"parameters": {}, "put_node_depth": {}}
++
++    # Dataset
++    dataset = Dataset(
++        dataset_dir,
++        properties_file,
++        seq_qp,
++        bit_depth,
++        block_size,
++        pad_size,
++        True,
++        "B",
++        False,
++        0,
++        0,
++        False,
++        0,
++        False,
++        cache_dataset,
++    )
++    dataset = dataset.create(seq_name, batch_size)
++
++    # Iterative QP
++    overfitted_model.set_dataset(dataset)
++    base_model.set_dataset(dataset)
++
++    a, ref_acc, b = overfitted_model.eval_model({})
++
++    diff_qp = icnn_tools.opt_qp(
++        diff_qp,
++        model_param_diff,
++        base_model_params,
++        diff_dec_approx_params,
++        base_model,
++        base_model.model_info,
++        ref_acc,
++        approx_data,
++        approx_info,
++        enc_info,
++        inc_bn_folding,
++        "/tmp/bitstream.nnr",
++        1,
++        save_bitstreams=False,
++        tml=None,
++        sbt_args=None,
++        bias=bias,
++        qp_step=qp_step,
++        crit="acc",
++    )
++
++    approx_data["parameters"] = model_param_diff
++
++    approx_info.apply_qp(approx_data, base_model.model_info, diff_qp)
++
++    output_dir = Path(output_dir)
++    output_dir.mkdir(exist_ok=True)
++    nnr_file = output_dir / f"{seq_name}_{seq_qp[0]}_float.nnr"
++
++    # Approximate and encode
++    nnc_core.approx_and_enc(
++        base_model.model_info,
++        approx_data,
++        diff_dec_approx_params,
++        approx_info,
++        enc_info,
++        num_workers=4,
++        bs_filename=nnr_file,
++        tml=None,
++        n_epochs=1,
++        epoch=1,
++        client=0,
++        sbt_args=None,
++    )
++
++    # Quantize the restored model
++    diff_rec_approx_data, bs_size, dec_model_info, res = nnc_core.dec_and_rec(
++        nnr_file,
++        True,
++        tml=None,
++        epoch=1,
++        client=0,
++        parameter_index=base_model.model_info["parameter_index"],
++        parameter_dimensions=base_model.model_info["parameter_dimensions"],
++        dec_approx_param_base=diff_dec_approx_params,
++        update_base_param=True,
++    )
++
++    restored_params = icnn_tools.add(
++        base_model_params, diff_rec_approx_data["parameters"]
++    )
++
++    nnr_file_i16 = output_dir / f"{seq_name}_{seq_qp[0]}_int16.nnr"
++    quantise_overfitted_model(base_model, restored_params, 16, nnr_file, nnr_file_i16)
++
++
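++# Example invocation (illustrative only; values are placeholders, not from the repository);
++# --seq_qp may be repeated, and the first value is used in the output file names:
++#   python run/encode_multipliers.py --dataset_dir <dataset_dir> \
++#       --properties_file <properties.json> --seq_name <SequenceName> --seq_qp 42 \
++#       --base_model_dir <base_model_dir> --overfitted_model_dir <overfitted_model_dir> \
++#       --output_dir <output_dir>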
++if __name__ == "__main__":
++    gpus = tf.config.list_physical_devices("GPU")
++    for gpu in gpus:
++        tf.config.experimental.set_memory_growth(gpu, True)
++
++    compress_weight_update()
++    sys.exit()
+diff --git a/run/onnx2sadl.py b/run/onnx2sadl.py
+new file mode 100644
+index 0000000..bc8cbad
+--- /dev/null
++++ b/run/onnx2sadl.py
+@@ -0,0 +1,1283 @@
++"""
++/* The copyright in this software is being made available under the BSD
++* License, included below. This software may be subject to other third party
++* and contributor rights, including patent rights, and no such rights are
++* granted under this license.
++*
++* Copyright (c) 2010-2022, ITU/ISO/IEC
++* All rights reserved.
++*
++* Redistribution and use in source and binary forms, with or without
++* modification, are permitted provided that the following conditions are met:
++*
++*  * Redistributions of source code must retain the above copyright notice,
++*    this list of conditions and the following disclaimer.
++*  * Redistributions in binary form must reproduce the above copyright notice,
++*    this list of conditions and the following disclaimer in the documentation
++*    and/or other materials provided with the distribution.
++*  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
++*    be used to endorse or promote products derived from this software without
++*    specific prior written permission.
++*
++* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
++* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
++* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
++* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
++* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
++* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
++* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
++* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
++* THE POSSIBILITY OF SUCH DAMAGE.
++"""
++
++from __future__ import print_function
++import argparse
++import onnx
++import copy
++import struct
++from collections import OrderedDict
++from enum import IntEnum
++import numpy as np
++
++# file format:
++# MAGIC: SADL0001 [char[8]]
++# type_model [int32_t] 0:int32, 1:float, 2:int16
++# nb_layers [int32_t]
++# nb_inputs [int32_t]
++# inputs_id [int32_t[nb_inputs]]
++# nb_outputs [int32_t]
++# outputs_id [int32_t[nb_outputs]]
++# (for all layers:)
++#  layer_id [int32_t]
++#  op_id    [int32_t]
++#  name_size [int32_t]
++#  name [char[name_size]]
++#  nb_inputs [int32_t]
++#  input_ids [int32_t[nb_inputs]]
++#
++# (additional information)
++#  Const_layer:
++#   length_dim [int32_t]
++#   dim [int32_t[length_dim]]
++#   type [int32_t] 0:int32, 1:float32, 2:int16
++#   (if integer: quantizer [int32_t])
++#   data [type[prod(dim)]]
++#
++#  Conv2D
++#    nb_dim_strides [int32_t]
++#    strides [int32_t[nb_dim_strides]]
++#    quantizer [int32_t]
++#
++#  MatMul
++#    quantizer [int32_t]
++#
++#  Mul
++#    quantizer [int32_t]
++#
++#  PlaceHolder
++#   length_dim [int32_t]
++#   dim [int32_t[length_dim]]
++#   quantizer [int32_t]
++#
++#  MaxPool
++#    nb_dim_strides [int32_t]
++#    strides [int32_t[nb_dim_strides]]
++#    nb_dim_kernel [int32_t]
++#    kernel_dim [int32_t[nb_dim_kernel]]
++
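++
++# A minimal sketch (illustrative only, not used by the converter in this file) of how
++# the header fields described above map to bytes; little-endian int32 is an assumption.
++def _sketch_write_header(f, type_model, nb_layers, inputs_id, outputs_id):
++    f.write(b"SADL0001")                          # MAGIC char[8]
++    f.write(struct.pack("<i", type_model))        # type_model: 0:int32, 1:float, 2:int16
++    f.write(struct.pack("<i", nb_layers))         # nb_layers
++    f.write(struct.pack("<i", len(inputs_id)))    # nb_inputs
++    f.write(struct.pack("<%di" % len(inputs_id), *inputs_id))    # inputs_id
++    f.write(struct.pack("<i", len(outputs_id)))   # nb_outputs
++    f.write(struct.pack("<%di" % len(outputs_id), *outputs_id))  # outputs_id
++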
++### Define model type and array type
++MODEL_TYPE = None
++NUMPY_TYPE = None
++INPUT_QUANTIZER = None
++
++
++class OPTYPE(IntEnum):
++    Const = 1,
++    Placeholder = 2,
++    Identity = 3,
++    BiasAdd = 4,
++    MaxPool = 5,
++    MatMul = 6,
++    Reshape = 7,
++    Relu = 8,
++    Conv2D = 9,
++    Add = 10,
++    ConcatV2 = 11,
++    Mul = 12,
++    Maximum = 13,
++    LeakyReLU = 14,
++    Transpose = 15,
++    Flatten = 16,
++    Shape = 17,
++    Expand = 18,
++    # In "tf2cpp", the same layer performs the matrix multiplication
++    # and the matrix multiplication by batches.
++    BatchMatMul = 6,
++
++    # "BatchMatMulV2" did not exist in Tensorflow 1.9. It exists in
++    # Tensorflow 1.15.
++    BatchMatMulV2 = 6
++
++    def __repr__(self):
++        return self.name
++
++    def __str__(self):
++        return self.name
++
++
++class DTYPE_SADL(IntEnum):
++    FLOAT = 1,  # float
++    INT8 = 3,  # int8_t
++    INT16 = 2,  # int16_t
++    INT32 = 0  # int32_t
++
++    def __repr__(self):
++        return self.name
++
++    def __str__(self):
++        return self.name
++
++
++class DTYPE_ONNX(IntEnum):
++    # https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto#L483-L485
++    FLOAT = 1,  # float
++    INT8 = 3,  # int8_t
++    INT16 = 5,  # int16_t
++    INT32 = 6,  # int32_t
++    INT64 = 7  # int64_t
++
++    def __repr__(self):
++        return self.name
++
++    def __str__(self):
++        return self.name
++
++
++class Node_Annotation:
++    to_remove = False
++    add_transpose_before = False
++    add_transpose_after = False
++    to_transpose = False
++    layout_onnx = None
++
++    def __repr__(self):
++        return "to_remove={}, to_transpose={}, layout_onnx={}, add_transpose_before={} add_transpose_after={}".format(
++            self.to_remove, self.to_transpose, self.layout_onnx, self.add_transpose_before, self.add_transpose_after)
++
++
++# get attribute name in node
++def getAttribute(node, attr):
++    for a in node.attribute:
++        if a.name == attr: return a
++    return None
++
++
++def transpose_tensor(raw_data, dims):
++    """
++        When converting TF2 models to ONNX, the ONNX weights are not represented in the same layout as the TF2 weights
++    """
++    # print(dims)
++    tmp = []
++    tmp.append(dims[2])
++    tmp.append(dims[3])
++    tmp.append(dims[1])
++    tmp.append(dims[0])
++
++    x = np.frombuffer(raw_data, dtype=NUMPY_TYPE)
++    x = x.reshape(tmp[3], tmp[2], tmp[0] * tmp[1]).transpose().flatten()
++    return x.tobytes(), tmp
++
++
++def transpose_matrix(raw_data, dims):
++    x = np.frombuffer(raw_data, dtype=NUMPY_TYPE)
++    tmp = []
++    tmp.append(dims[1])
++    tmp.append(dims[0])
++    x = x.reshape(dims[0], dims[1])
++    x = np.transpose(x)  # moveaxis(x, -2, -1)
++    return x.flatten().tobytes(), tmp
++
++
++def toList(ii):
++    d = []
++    for i in ii: d.append(i)
++    return d
++
++
++def is_constant(name, onnx_initializer):
++    for n in onnx_initializer:
++        if n.name == name: return True
++    return False
++
++
++def is_output(name, onnx_output):
++    for out in onnx_output:
++        if out.name == name:
++            return True
++    return False
++
++
++def parse_graph_input_node(input_node, map_onnx_to_myGraph, to_transpose):
++    map_onnx_to_myGraph[input_node.name] = input_node.name
++    struct = {}
++    struct["inputs"] = []
++    struct["additional"] = {}
++    if to_transpose:  # data_layout == 'nchw' and len(input_node.type.tensor_type.shape.dim)==4:
++        struct["additional"]["dims"] = [input_node.type.tensor_type.shape.dim[0].dim_value,
++                                        input_node.type.tensor_type.shape.dim[2].dim_value,
++                                        input_node.type.tensor_type.shape.dim[3].dim_value,
++                                        input_node.type.tensor_type.shape.dim[1].dim_value]
++    else:
++        struct["additional"]["dims"] = [d.dim_value for d in input_node.type.tensor_type.shape.dim]
++    struct["op_type"] = OPTYPE.Placeholder
++    return struct
++
++
++def extract_additional_data_from_node(data, to_transpose):
++    tmp = {}
++    if data.dims == []:
++        tmp["dims"] = [1]
++    else:
++        tmp["dims"] = [dim for dim in data.dims]
++
++    tmp["raw_data"] = data.raw_data
++
++    if data.data_type == DTYPE_ONNX.FLOAT:
++        tmp["dtype"] = DTYPE_SADL.FLOAT
++    elif data.data_type == DTYPE_ONNX.INT8:
++        tmp["dtype"] = DTYPE_SADL.INT8
++    elif data.data_type == DTYPE_ONNX.INT16:
++        tmp["dtype"] = DTYPE_SADL.INT16
++    elif data.data_type == DTYPE_ONNX.INT32:
++        tmp["dtype"] = DTYPE_SADL.INT32
++    elif data.data_type == DTYPE_ONNX.INT64:
++        def convert_int64_to_int32(binary_data):
++            x = np.frombuffer(binary_data, dtype=np.int64)
++            x = x.astype(np.int32)
++            return x.tobytes()
++
++        tmp["dtype"] = DTYPE_SADL.INT32
++        tmp["raw_data"] = convert_int64_to_int32(tmp["raw_data"])
++    else:
++        raise ValueError("extract_additional_data: Unknown dtype")
++
++    if to_transpose:
++        if len(tmp["dims"]) == 4:
++            tmp["raw_data"], tmp["dims"] = transpose_tensor(tmp["raw_data"], tmp["dims"])
++        elif len(tmp["dims"]) == 2:  # and data_layout == "nchw":
++            tmp["raw_data"], tmp["dims"] = transpose_matrix(tmp["raw_data"], tmp["dims"])
++
++    return tmp["dims"], tmp["raw_data"], tmp["dtype"]
++
++
++def extract_additional_data(name, to_transpose, onnx_graph):
++    for init in onnx_graph.initializer:
++        if name == init.name:      return extract_additional_data_from_node(init, to_transpose)
++    for node in onnx_graph.node:  # not found in initializer, search in Constant
++        if name == node.output[0]: return extract_additional_data_from_node(node.attribute[0].t, to_transpose)
++    quit("[ERROR] unable to extract data in {}".format(name))
++
++
++def extract_dims(name, onnx_graph):
++    for init in onnx_graph.initializer:
++        if name == init.name:      return init.dims
++    for node in onnx_graph.node:  # not found in initializer, search in Constant nodes
++        if name == node.output[0]:
++            a = getAttribute(node, "value")
++            if a is not None:
++                return a.t.dims
++            else:
++                return None
++    for node in onnx_graph.input:  # not found in initializer or Constant, search in graph inputs
++        if name == node.name: return node.type.tensor_type.shape.dim
++    quit("[ERROR] unable to extract dims in {}".format(name))
++
++
++# get the nodes with name as input
++def getNodesWithInput(name, model):
++    L = []
++    for node in model.graph.node:
++        for inp in node.input:
++            if inp == name:
++                L.append(node)
++    return L
++
++
++# get the nodes with name as output
++def getNodesWithOutput(name, model):
++    for node in model.graph.node:
++        for out in node.output:
++            if out == name:
++                return node
++    for node in model.graph.initializer:
++        if node.name == name:
++            return node
++    for node in model.graph.input:
++        if node.name == name:
++            return node
++    quit("[ERROR] not found:".format(name))
++
++
++# get the nodes with name as output
++def getNodesWithOutputNotConst(name, model):
++    for node in model.graph.node:
++        for out in node.output:
++            if out == name:
++                return node
++    for node in model.graph.input:
++        if node.name == name:
++            return node
++    return None
++
++
++# get dims from data
++def getDims(node):
++    if node.data_type != DTYPE_ONNX.INT64:
++        quit("[ERROR] bad node type fpr getDims {}".format(node))
++
++    x = np.frombuffer(node.raw_data, dtype=np.int64)
++    dims = x.tolist()
++    return dims
++
++
++def getInitializer(name, model_onnx):
++    for node in model_onnx.graph.initializer:
++        if node.name == name: return node
++    return None
++
++
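++# Insert a Const([0, 3, 1, 2]) + Transpose pair in front of `node`, converting the
++# tensor back from NHWC to NCHW before a layout-sensitive op (e.g. Reshape).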
++def add_transpose(node, myGraph, map_onnx_to_myGraph):
++    # Transpose inserted
++    # Const
++    reshape_coef_name = node.input[0] + "_COEF_TRANSPOSE_NOT_IN_GRAPH"
++    myGraph[reshape_coef_name] = {}
++    myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const
++    myGraph[reshape_coef_name]["inputs"] = []
++    additional = {}
++    additional["dims"] = [4]
++    additional["raw_data"] = np.array([0, 3, 1, 2], dtype=np.int32).tobytes()  # nhwc -> nchw
++    additional["dtype"] = DTYPE_SADL.INT32
++    additional["data"] = node
++    myGraph[reshape_coef_name]["additional"] = additional
++    map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name
++
++    nname = node.input[0] + "_TRANSPOSE_NOT_IN_GRAPH"
++    myGraph[nname] = {}
++    myGraph[nname]["op_type"] = OPTYPE.Transpose
++    myGraph[nname]["inputs"] = [map_onnx_to_myGraph[node.input[0]], reshape_coef_name]
++    map_onnx_to_myGraph[nname] = nname
++    return nname
++
++
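++# Insert a Const([0, 2, 3, 1]) + Transpose pair after `node`, converting its output
++# from NCHW back to NHWC and remapping the ONNX output name to the new node.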
++def add_transpose_after(node, myGraph, map_onnx_to_myGraph):
++    # Transpose inserted
++    # Const
++    reshape_coef_name = node.output[0] + "_COEF_TRANSPOSE_AFTER_NOT_IN_GRAPH"
++    myGraph[reshape_coef_name] = {}
++    myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const
++    myGraph[reshape_coef_name]["inputs"] = []
++    additional = {}
++    additional["dims"] = [4]
++    additional["raw_data"] = np.array([0, 2, 3, 1], dtype=np.int32).tobytes()  # nchw -> nhwc
++    additional["dtype"] = DTYPE_SADL.INT32
++    additional["data"] = node
++    myGraph[reshape_coef_name]["additional"] = additional
++    map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name
++
++    nname = node.output[0] + "_TRANSPOSE_AFTER_NOT_IN_GRAPH"
++    myGraph[nname] = {}
++    myGraph[nname]["op_type"] = OPTYPE.Transpose
++    myGraph[nname]["inputs"] = [map_onnx_to_myGraph[node.output[0]], reshape_coef_name]
++    map_onnx_to_myGraph[nname] = nname
++    map_onnx_to_myGraph[node.output[0]] = nname
++    return nname
++
++
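++# Map one ONNX node onto one or more SADL graph entries: weights become extra Const
++# nodes, a Conv/Gemm with bias is split into Conv2D/MatMul + BiasAdd, and transposes
++# are inserted before/after the node when the annotation requires it.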
++def parse_graph_node(node, model_onnx, myGraph, node_annotation, map_onnx_to_myGraph, verbose):
++    if verbose > 1: print("parse node", node.name)
++
++    # need to go back to the original layout before this node (e.g. ahead of a Reshape when the ONNX layout is NCHW)
++    if node_annotation[node.name].add_transpose_before:
++        n0name = add_transpose(node, myGraph, map_onnx_to_myGraph)
++    else:
++        if len(node.input) >= 1:
++            n0name = node.input[0]
++        else:
++            n0name = None
++
++    if node.op_type == "Conv" or node.op_type == "Gemm":
++        nb_inputs = len(node.input)
++        if (nb_inputs != 3) and (nb_inputs != 2): raise Exception("parse_graph_node: Error on node type")
++        additional = {}
++        # Const: weight
++        additional["data"] = node
++        n2 = getNodesWithOutput(node.input[1], model_onnx)
++        additional["dims"], additional["raw_data"], additional["dtype"] = extract_additional_data(node.input[1],
++                                                                                                  node_annotation[
++                                                                                                      n2.name].to_transpose,
++                                                                                                  model_onnx.graph)
++        map_onnx_to_myGraph[node.input[1]] = node.input[1]
++
++        myGraph[node.input[1]] = {}
++        myGraph[node.input[1]]["inputs"] = []
++        myGraph[node.input[1]]["additional"] = additional
++        myGraph[node.input[1]]["op_type"] = OPTYPE.Const
++
++        # Conv2d
++        inputs, additional = [], {}
++        inputs = [map_onnx_to_myGraph[n0name]] + [map_onnx_to_myGraph[node.input[1]]]
++
++        additional["data"] = node
++        if node.op_type == "Conv":
++            a = getAttribute(node, 'strides')
++            additional["strides"] = a.ints
++            a = getAttribute(node, 'pads')
++            additional["pads"] = a.ints
++
++        if nb_inputs == 2:
++            map_onnx_to_myGraph[node.output[0]] = node.output[0]
++        elif nb_inputs == 3:
++            map_onnx_to_myGraph[node.output[0]] = node.output[0] + "_NOT_IN_GRAPH"
++
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["inputs"] = inputs
++        myGraph[node.output[0]]["additional"] = additional
++        if node.op_type == "Conv":
++            myGraph[node.output[0]]["op_type"] = OPTYPE.Conv2D
++        elif node.op_type == "Gemm":
++            myGraph[node.output[0]]["op_type"] = OPTYPE.MatMul
++
++        if nb_inputs == 3:
++            additional = {}
++            # Const: bias
++            additional["data"] = node
++            additional["dims"], additional["raw_data"], additional["dtype"] = extract_additional_data(node.input[2],
++                                                                                                      False,
++                                                                                                      model_onnx.graph)
++            map_onnx_to_myGraph[node.input[2]] = node.input[2]
++            myGraph[node.input[2]] = {}
++            myGraph[node.input[2]]["inputs"] = []
++            myGraph[node.input[2]]["additional"] = additional
++            myGraph[node.input[2]]["op_type"] = OPTYPE.Const
++            # BiasAdd
++            inputs, additional = [], {}
++            inputs = [node.output[0]] + [map_onnx_to_myGraph[node.input[2]]]
++            additional["data"] = node
++            map_onnx_to_myGraph[node.output[0] + "_NOT_IN_GRAPH"] = None
++            myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {}
++            myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = inputs
++            myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional
++            myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.BiasAdd
++
++    elif node.op_type == "Relu":
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Relu
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Constant":  # ~ like an initializer
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Const
++        myGraph[node.output[0]]["inputs"] = []
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Add":
++        swap_inputs = False
++        if is_constant(n0name, model_onnx.graph.initializer):
++            additional = {}
++            additional["data"] = node
++            additional["dims"], additional["raw_data"], additional["dtype"] = extract_additional_data(n0name, False,
++                                                                                                      model_onnx.graph)
++            map_onnx_to_myGraph[n0name] = n0name
++            myGraph[n0name] = {}
++            myGraph[n0name]["inputs"] = []
++            myGraph[n0name]["additional"] = additional
++            myGraph[n0name]["op_type"] = OPTYPE.Const
++            swap_inputs = True
++        if is_constant(node.input[1], model_onnx.graph.initializer):
++            additional = {}
++            additional["data"] = node
++            additional["dims"], additional["raw_data"], additional["dtype"] = extract_additional_data(node.input[1],
++                                                                                                      False,
++                                                                                                      model_onnx.graph)
++            map_onnx_to_myGraph[node.input[1]] = node.input[1]
++            myGraph[node.input[1]] = {}
++            myGraph[node.input[1]]["inputs"] = []
++            myGraph[node.input[1]]["additional"] = additional
++            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Add
++        if not swap_inputs:
++            D1 = extract_dims(n0name, model_onnx.graph)
++            D2 = extract_dims(node.input[1], model_onnx.graph)
++            if D1 is not None and D2 is not None and len(D1) < len(D2): swap_inputs = True
++
++        if swap_inputs:
++            myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[node.input[1]], map_onnx_to_myGraph[n0name]]
++        else:
++            myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name], map_onnx_to_myGraph[node.input[1]]]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "MaxPool":
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.MaxPool
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
++        myGraph[node.output[0]]["additional"] = {}
++        a = getAttribute(node, 'strides')
++        myGraph[node.output[0]]["additional"]["strides"] = [1, a.ints[0], a.ints[1], 1]
++        a = getAttribute(node, 'pads')
++        if a is None:
++            pp = [0, 0, 0, 0]
++        else:
++            pp = a.ints
++        myGraph[node.output[0]]["additional"]["pads"] = pp
++        a = getAttribute(node, 'kernel_shape')
++        myGraph[node.output[0]]["additional"]["kernel_shape"] = [1, a.ints[0], a.ints[1], 1]
++        myGraph[node.output[0]]["additional"]["data"] = node
++        # todo: check pads?
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Mul":
++        # check the inputs
++        if is_constant(n0name, model_onnx.graph.initializer) and is_constant(node.input[1],
++                                                                             model_onnx.graph.initializer):
++            quit("[ERROR] unsupported double constants Mul", node)
++        swap_inputs = False
++        if is_constant(n0name, model_onnx.graph.initializer):
++            additional = {}
++            additional["data"] = node
++            n2 = getNodesWithOutput(n0name, model_onnx)
++            additional["dims"], additional["raw_data"], additional["dtype"] = extract_additional_data(n0name,
++                                                                                                      node_annotation[
++                                                                                                          n2.name].to_transpose,
++                                                                                                      model_onnx.graph)
++            map_onnx_to_myGraph[n0name] = n0name
++            myGraph[n0name] = {}
++            myGraph[n0name]["inputs"] = []
++            myGraph[n0name]["additional"] = additional
++            myGraph[n0name]["op_type"] = OPTYPE.Const
++            swap_inputs = True
++        if is_constant(node.input[1], model_onnx.graph.initializer):
++            additional = {}
++            additional["data"] = node
++            n2 = getNodesWithOutput(node.input[1], model_onnx)
++            additional["dims"], additional["raw_data"], additional["dtype"] = extract_additional_data(node.input[1],
++                                                                                                      node_annotation[
++                                                                                                          n2.name].to_transpose,
++                                                                                                      model_onnx.graph)
++            map_onnx_to_myGraph[node.input[1]] = node.input[1]
++            myGraph[node.input[1]] = {}
++            myGraph[node.input[1]]["inputs"] = []
++            myGraph[node.input[1]]["additional"] = additional
++            myGraph[node.input[1]]["op_type"] = OPTYPE.Const
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Mul
++        if swap_inputs:
++            myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[node.input[1]], map_onnx_to_myGraph[n0name]]
++        else:
++            myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name], map_onnx_to_myGraph[node.input[1]]]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Identity":
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Identity
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "LeakyRelu":
++        # leaky coef
++        additional = {}
++        additional["data"] = node
++        additional["dims"] = [1]
++        # TODO: Change quantizer and data type for LeakyRelu layer
++        additional["raw_data"] = np.array(node.attribute[0].f, dtype=NUMPY_TYPE).tobytes()
++        if NUMPY_TYPE == np.int32:
++            additional["dtype"] = DTYPE_SADL.INT32
++        elif NUMPY_TYPE == np.int16:
++            additional["dtype"] = DTYPE_SADL.INT16
++        elif NUMPY_TYPE == np.float32:
++            additional["dtype"] = DTYPE_SADL.FLOAT
++
++        map_onnx_to_myGraph[node.output[0] + "_COEF_NOT_IN_GRAPH"] = None
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {}
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = []
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.Const
++
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.LeakyReLU
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name], node.output[0] + "_NOT_IN_GRAPH"]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "PRelu":  # map to leakyrelu because no training
++        # coef
++        additional = {}
++        additional["data"] = node
++        additional["dims"] = [1]
++        dims, data, dtype = extract_additional_data(node.input[1], False, model_onnx.graph)
++        if np.prod(dims) != 1:
++            quit("[ERROR] PRelu slope not scalar:", dims)
++        f = np.frombuffer(data, dtype=np.float32)
++        additional["raw_data"] = np.array(float(f), dtype=np.float32).tobytes()
++        additional["dtype"] = DTYPE_SADL.FLOAT
++        map_onnx_to_myGraph[node.output[0] + "_COEF_NOT_IN_GRAPH"] = None
++
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {}
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = []
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.Const
++
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.LeakyReLU
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name], node.output[0] + "_NOT_IN_GRAPH"]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Flatten":
++        inputs, additional = [], {}
++        inputs = [map_onnx_to_myGraph[n0name]]
++        additional["data"] = node
++        a = getAttribute(node, "axis")
++        additional["axis"] = a.i
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["inputs"] = inputs
++        myGraph[node.output[0]]["additional"] = additional
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Flatten
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Shape":
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Shape
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name]]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Expand":
++        inputs, additional = [], {}
++        inputs = [map_onnx_to_myGraph[n0name], map_onnx_to_myGraph[node.input[1]]]
++        additional["data"] = node
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["inputs"] = inputs
++        myGraph[node.output[0]]["additional"] = additional
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Expand
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Reshape" or node.op_type == "MatMul":
++        # Const
++        myGraph[node.input[1]] = {}
++        myGraph[node.input[1]]["op_type"] = OPTYPE.Const
++        myGraph[node.input[1]]["inputs"] = []
++        additional = {}
++        additional["dims"], additional["raw_data"], additional["dtype"] = extract_additional_data(node.input[1], False,
++                                                                                                  model_onnx.graph)
++        additional["data"] = node
++        myGraph[node.input[1]]["additional"] = additional
++        map_onnx_to_myGraph[node.input[1]] = node.input[1]
++        n2 = getNodesWithOutput(node.input[0], model_onnx)
++        # Reshape
++        inputs, additional = [], {}
++        inputs = [map_onnx_to_myGraph[n0name], node.input[1]]
++        additional["data"] = node
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["inputs"] = inputs
++        myGraph[node.output[0]]["additional"] = additional
++
++        if node.op_type == "Reshape":
++            myGraph[node.output[0]]["op_type"] = OPTYPE.Reshape
++        elif node.op_type == "MatMul":
++            myGraph[node.output[0]]["op_type"] = OPTYPE.MatMul
++
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Concat":
++        # Const
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Const
++        myGraph[node.output[0]]["inputs"] = []
++        additional = {}
++        additional["dims"] = [1]
++        additional["raw_data"] = np.array(node.attribute[0].i, dtype=np.int32).tobytes()
++        additional["dtype"] = DTYPE_SADL.INT32
++        additional["data"] = node
++        myGraph[node.output[0]]["additional"] = additional
++        map_onnx_to_myGraph[node.output[0] + "_NOT_IN_GRAPH"] = None
++
++        # Concatenate
++        inputs, additional = [], {}
++        for inp in node.input:
++            inputs.append(map_onnx_to_myGraph[inp])
++        inputs.append(node.output[0])
++        additional["data"] = node
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"] = {}
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["inputs"] = inputs
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["additional"] = additional
++        myGraph[node.output[0] + "_NOT_IN_GRAPH"]["op_type"] = OPTYPE.ConcatV2
++        map_onnx_to_myGraph[node.output[0]] = node.output[0] + "_NOT_IN_GRAPH"
++
++
++    elif node.op_type == "Max":
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Maximum
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name], map_onnx_to_myGraph[node.input[1]]]
++        myGraph[node.output[0]]["additional"] = {}
++        myGraph[node.output[0]]["additional"]["data"] = node
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Unsqueeze":
++        # No need to parse Unsqueeze as SADL can handle it.
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    elif node.op_type == "Transpose":
++        # Const
++        reshape_coef_name = node.output[0] + "_COEF_TRANSPOSE"
++        myGraph[reshape_coef_name] = {}
++        myGraph[reshape_coef_name]["op_type"] = OPTYPE.Const
++        myGraph[reshape_coef_name]["inputs"] = []
++        additional = {}
++        d = toList(getAttribute(node, "perm").ints)
++        additional["dims"] = [len(d)]
++        additional["raw_data"] = np.array(d, dtype=np.int32).tobytes()
++        additional["dtype"] = DTYPE_SADL.INT32
++        additional["data"] = node
++        myGraph[reshape_coef_name]["additional"] = additional
++        map_onnx_to_myGraph[reshape_coef_name] = reshape_coef_name
++
++        myGraph[node.output[0]] = {}
++        myGraph[node.output[0]]["op_type"] = OPTYPE.Transpose
++        myGraph[node.output[0]]["inputs"] = [map_onnx_to_myGraph[n0name], reshape_coef_name]
++        map_onnx_to_myGraph[node.output[0]] = node.output[0]
++
++    else:
++        raise Exception("[ERROR] node not supported:\n{})".format(node))
++
++    if node_annotation[node.name].add_transpose_after:
++        n0name = add_transpose_after(node, myGraph, map_onnx_to_myGraph)
++
++
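++# Walk the ONNX graph in order: create Placeholders for the inputs, skip nodes marked
++# to_remove (rerouting their outputs to their inputs), and convert the rest.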
++def parse_onnx(model_onnx, node_annotation, verbose=False):
++    myGraph, map_onnx_to_myGraph = OrderedDict(), {}
++
++    # Inputs
++    for inp in model_onnx.graph.input:
++        myGraph[inp.name] = parse_graph_input_node(inp, map_onnx_to_myGraph, node_annotation[inp.name].to_transpose)
++
++    # Nodes removal
++    for node in model_onnx.graph.node:
++        if node.name in node_annotation and node_annotation[node.name].to_remove:
++            curr_key = node.input[0]
++            while map_onnx_to_myGraph[curr_key] is not None and map_onnx_to_myGraph[curr_key] != curr_key:
++                next_key = map_onnx_to_myGraph[curr_key]
++                curr_key = next_key
++                if curr_key not in map_onnx_to_myGraph:
++                    curr_key = node.input[0]
++                    break
++
++            map_onnx_to_myGraph[node.output[0]] = curr_key
++        else:
++            parse_graph_node(node, model_onnx, myGraph, node_annotation, map_onnx_to_myGraph, verbose)
++
++    myInputs = []
++    for inp in model_onnx.graph.input:
++        myInputs.append(inp.name)
++
++    myOutputs = []
++    for out in model_onnx.graph.output:
++        for key, value in map_onnx_to_myGraph.items():
++            if key == out.name:
++                myOutputs.append(value)
++
++    return myGraph, myInputs, myOutputs
++
++
++def dump_onnx(graph, my_inputs, my_outputs, output_filename, verbose=False):
++    # graph[my_name]={ op_type
++    #                  inputs: []
++    #                  dtype:
++    #                  onnx : model.graph.node[x]
++    #                  }
++
++    # my_input=[my_name, my_name..]
++    # outputs=[my_name, ...]
++    # print(graph)
++    map_name_to_idx = dict()
++    for idx, (key, value) in enumerate(graph.items()):
++        map_name_to_idx[key] = idx
++
++    # dbg print(map_name_to_idx)
++    with open(output_filename, "wb") as f:
++        f.write(str.encode('SADL0002'))
++        # output of the network type 0: int32 | 1: float | 2: int16 | default: float(1)
++        # TODO: change model type
++        f.write(struct.pack('i', int(MODEL_TYPE)))
++
++        if verbose: print(f"# Nb layers: {len(graph.keys())}")
++        f.write(struct.pack('i', int(len(graph.keys()))))
++
++        inputs = []
++        for name in my_inputs:
++            inputs.append(map_name_to_idx[name])
++        if verbose: print(f"# Nb inputs: {len(inputs)}")
++        f.write(struct.pack('i', int(len(inputs))))
++        for i in inputs:
++            if verbose: print(f'#  input', i)
++            f.write(struct.pack('i', int(i)))
++
++        outputs = []
++        for name in my_outputs:
++            outputs.append(map_name_to_idx[name])
++        if verbose: print(f"# Nb outputs: {len(outputs)}")
++        f.write(struct.pack('i', int(len(outputs))))
++        for i in outputs:
++            if verbose: print(f'#  output {i}')
++            f.write(struct.pack('i', int(i)))
++
++        for (name, node) in graph.items():
++            if verbose: print(f"# Layer id {map_name_to_idx[name]}")
++            f.write(struct.pack('i', int(map_name_to_idx[name])))
++
++            if verbose: print("#\t op " + str(node['op_type']))
++            f.write(struct.pack('i', int(node['op_type'].value)))
++
++            # Name size
++            if verbose: print(f"#\t name_size {len(name)}")
++            f.write(struct.pack('i', int(len(name))))
++
++            # Name
++            if verbose: print(f"#\t name {name}")
++            f.write(str.encode(str(name)))
++
++            # Nb inputs
++            if verbose: print(f"#\t nb_inputs {len(node['inputs'])}")
++            f.write(struct.pack('i', int(len(node['inputs']))))
++
++            for name_i in node['inputs']:
++                idx = map_name_to_idx[name_i]
++                if verbose: print(f"#\t\t {idx} ({name_i})")
++                f.write(struct.pack('i', int(idx)))
++
++            # Additional info depending on OPTYPE
++            if node['op_type'] == OPTYPE.Const:
++                if verbose: print(f"#\t nb_dim {len(node['additional']['dims'])}")
++                f.write(struct.pack('i', int(len(node['additional']['dims']))))
++
++                for dim in node['additional']['dims']:
++                    if verbose: print(f"#\t\t {dim}")
++                    f.write(struct.pack('i', int(dim)))
++
++                if verbose: print(f"#\t dtype {node['additional']['dtype']}")
++                f.write(struct.pack('i', int(node['additional']['dtype'])))
++
++                if node['additional']['dtype'] != DTYPE_SADL.FLOAT:  # not float
++                    op_name = name.split('/')[2]
++                    # TODO: Quantizer for weights of Conv, Bias, Mul
++                    q = 0
++                    for attri in node['additional']['data'].attribute:
++                        if attri.name == 'quantizer':
++                            if op_name == 'Mul' or 'LeakyRelu' in op_name:
++                                q = attri.i
++                            elif op_name == 'Conv2D':
++                                q = attri.ints[0]
++                            elif op_name == 'BiasAdd':
++                                q = attri.ints[1]
++                    if verbose: print(f"#\t quantizer {q}")
++                    f.write(struct.pack('i', int(q)))
++
++                f.write(node['additional']['raw_data'])
++            # ???    if "alpha" in layer['additional']:
++            #        f.write(struct.pack('f', float(layer['additional']['alpha'])))
++
++            elif node['op_type'] == OPTYPE.Conv2D:
++                if verbose: print("#\t  nb_dim_strides", len(node['additional']['strides']))
++                f.write(struct.pack('i', int(len(node['additional']['strides']))))
++
++                for stride in node['additional']['strides']:
++                    if verbose: print(f"#\t\t {stride}")
++                    f.write(struct.pack('i', int(stride)))
++
++                if verbose: print("#\t  nb_dim_pads", len(node['additional']['pads']))
++                f.write(struct.pack('i', int(len(node['additional']['pads']))))
++
++                for p in node['additional']['pads']:
++                    if verbose: print(f"#\t\t {p}")
++                    f.write(struct.pack('i', int(p)))
++
++            elif node['op_type'] == OPTYPE.Placeholder:
++                if verbose: print(f"#\t nb input dimension {len(node['additional']['dims'])}")
++                f.write(struct.pack('i', int(len(node['additional']['dims']))))
++
++                for dim in node['additional']['dims']:
++                    if verbose: print(f"#\t\t {dim}")
++                    f.write(struct.pack('i', int(dim)))
++
++                # output the quantizer of the input default: 0
++                # TODO: Change input quantizer
++                if verbose: print(f"#\t quantizer_of_input")
++                f.write(struct.pack('i', int(INPUT_QUANTIZER)))
++
++            elif node['op_type'] == OPTYPE.MaxPool:
++                if verbose: print("#\t  nb_dim_strides", len(node['additional']['strides']))
++                f.write(struct.pack('i', int(len(node['additional']['strides']))))
++
++                for stride in node['additional']['strides']:
++                    if verbose: print(f"#\t\t {stride}")
++                    f.write(struct.pack('i', int(stride)))
++
++                if verbose: print("#\t  nb_dim_kernel", len(node['additional']['kernel_shape']))
++                f.write(struct.pack('i', int(len(node['additional']['kernel_shape']))))
++
++                for ks in node['additional']['kernel_shape']:
++                    if verbose: print(f"#\t\t {ks}")
++                    f.write(struct.pack('i', int(ks)))
++
++                if verbose: print("#\t  nb_dim_pads", len(node['additional']['pads']))
++                f.write(struct.pack('i', int(len(node['additional']['pads']))))
++
++                for p in node['additional']['pads']:
++                    if verbose: print(f"#\t\t {p}")
++                    f.write(struct.pack('i', int(p)))
++
++            elif node['op_type'] == OPTYPE.Flatten:
++                if verbose: print("#\t axis", node['additional']['axis'])
++                f.write(struct.pack('i', int(node['additional']['axis'])))
++
++            if node['op_type'] == OPTYPE.Conv2D or node['op_type'] == OPTYPE.MatMul or node['op_type'] == OPTYPE.Mul:
++                # TODO: Change internal integer
++                internal = 0
++                for attri in node['additional']['data'].attribute:
++                    if attri.name == 'internal':
++                        internal = attri.i
++                        # output the internal quantizer default: 0
++                if verbose: print(f"#\t internal integer {internal}")
++                f.write(struct.pack('i', int(internal)))
++
++            if verbose: print("")
++
++
++# adapt (remove/add) the current node to the data_layout and
++# recurse into its outputs
++def annotate_node(node, model_onnx, node_annotation, global_data_layout, verbose):  # recursive
++    if node.name in node_annotation: return
++    if verbose > 1: print("[INFO] annotate {}".format(node.name))
++
++    data_layout = None
++
++    # inherit from input
++    for inp in node.input:
++        n2 = getNodesWithOutputNotConst(inp, model_onnx)
++        if n2 is not None:
++            if n2.name in node_annotation:
++                if data_layout is None:
++                    data_layout = node_annotation[n2.name].layout_onnx
++                elif node_annotation[n2.name].layout_onnx is not None and node_annotation[n2.name].layout_onnx != data_layout:
++                    quit("[ERROR] inputs with different layouts for\n{} Layouts: {}".format(node, node_annotation))
++            else:  # not ready yet
++                return
++
++    if verbose > 1 and data_layout is None: print(
++        "[WARNING] no data layout constraints for {}\n {}".format(node.name, node))
++
++    if node.name not in node_annotation: node_annotation[node.name] = Node_Annotation()
++    node_annotation[node.name].layout_onnx = data_layout  # default
++
++    if node.op_type == "Transpose":
++        a = getAttribute(node, "perm")
++        if data_layout == 'nhwc':
++            if a.ints[0] == 0 and a.ints[1] == 3 and a.ints[2] == 1 and a.ints[3] == 2:  # nhwc ->nchw
++                node_annotation[node.name].to_remove = True  # will be removed
++                node_annotation[node.name].layout_onnx = 'nchw'  # new layout at output
++            else:
++                if verbose > 1: print("[WARNING] transpose not for NCHW handling in\n", node)
++        elif data_layout == 'nchw':
++            if a.ints[0] == 0 and a.ints[1] == 2 and a.ints[2] == 3 and a.ints[3] == 1:  # nchw ->nhwc
++                node_annotation[node.name].to_remove = True  # will be removed
++                node_annotation[node.name].layout_onnx = 'nhwc'  # new layout at output
++            else:
++                if verbose > 1: print("[WARNING] transpose not for NCHW handling in\n", node)
++
++    elif node.op_type == "Reshape":
++        initializer = getInitializer(node.input[1], model_onnx)
++        # Case: In pytorch, Reshape is not in model_onnx.graph.initializer but in model_onnx.graph.node
++        if initializer is None:
++            attribute = getAttribute(getNodesWithOutput(node.input[1], model_onnx), "value")
++            initializer = attribute.t
++        dims = getDims(initializer)
++
++        # detect if this reshape is actually added by onnx to emulate a transpose
++        # more checks would be needed to be sure the reshape really emulates a transpose...
++        if len(dims) == 4 and (dims[0] == 1 or dims[0] == -1):
++            if data_layout == 'nhwc':
++                if dims[1] == 1:  # or dims2 * dims3 == 1 # nhwc ->nchw
++                    node_annotation[node.name].to_remove = True  # will be removed
++                    node_annotation[node.name].layout_onnx = 'nchw'  # new layout at output
++                else:
++                    if verbose > 1: print("[WARNING] reshape unknown for", node, " dims", dims)
++                    node_annotation[node.name].layout_onnx = None
++            elif data_layout == 'nchw':
++                if dims[3] == 1:  # # or dims2 * dims3 == 1 nchw ->nhwc
++                    node_annotation[node.name].to_remove = True  # will be removed
++                    node_annotation[node.name].layout_onnx = 'nhwc'  # new layout at output
++                else:
++                    if verbose > 1: print("[WARNING] reshape unknown for", node, " dims", dims)
++                    node_annotation[node.name].layout_onnx = None
++            elif data_layout == None:
++                node_annotation[node.name].layout_onnx = global_data_layout  # back to org
++                if global_data_layout == 'nchw':
++                    node_annotation[node.name].add_transpose_after = True  # a bit too aggressive
++        else:
++            node_annotation[node.name].layout_onnx = None
++
++        n2 = getNodesWithOutputNotConst(node.input[0], model_onnx)
++        if node_annotation[n2.name].layout_onnx == 'nchw':  # need to go back to original layout before reshape
++            node_annotation[node.name].add_transpose_before = True
++
++    elif node.op_type == "Flatten":
++        if node_annotation[node.name].layout_onnx == 'nchw':  # need to go back to original layout before reshape
++            node_annotation[node.name].add_transpose_before = True
++
++    elif node.op_type == 'Concat':
++        if data_layout == 'nchw':  # remap the Concat axis from NCHW to NHWC numbering
++            a = getAttribute(node, 'axis')
++            if a.i == 1:
++                a.i = 3
++            elif a.i == 2:
++                a.i = 1
++            elif a.i == 3:
++                a.i = 2
++
++    elif node.op_type == 'Unsqueeze':
++        node_annotation[node.name].to_remove = True
++
++    elif node.op_type == 'Conv':
++        n2 = getInitializer(node.input[1], model_onnx)
++        node_annotation[n2.name].to_transpose = True
++        node_annotation[n2.name].layout_onnx = 'nhwc'
++
++    elif node.op_type == 'Gemm':
++        n2 = getInitializer(node.input[1], model_onnx)
++        if global_data_layout == 'nchw':
++            node_annotation[n2.name].to_transpose = True
++        #    node_annotation[n2.name].layout_onnx = 'nhwc'
++
++    nexts = getNodesWithInput(node.output[0], model_onnx)
++    for n in nexts:
++        annotate_node(n, model_onnx, node_annotation, global_data_layout, verbose)  # rec
++
++
++def annotate_graph(model_onnx, node_annotation, data_layout, verbose):
++    # track the data layout in the graph and remove/add layers if necessary
++    for inp in model_onnx.graph.input:
++        node_annotation[inp.name] = Node_Annotation()
++        if len(inp.type.tensor_type.shape.dim) == 4:
++            node_annotation[inp.name].layout_onnx = data_layout
++            if data_layout == 'nchw':
++                node_annotation[inp.name].to_transpose = True
++        else:
++            node_annotation[inp.name].layout_onnx = None
++
++    for inp in model_onnx.graph.initializer:
++        node_annotation[inp.name] = Node_Annotation()
++        node_annotation[inp.name].layout_onnx = None
++
++    for inp in model_onnx.graph.node:
++        if inp.op_type == "Constant":
++            node_annotation[inp.name] = Node_Annotation()
++            node_annotation[inp.name].layout_onnx = None
++
++    for inp in model_onnx.graph.input:
++        nexts = getNodesWithInput(inp.name, model_onnx)
++        for n in nexts:
++            annotate_node(n, model_onnx, node_annotation, data_layout, verbose)  # recursive
++
++    if verbose > 1:
++        for node in model_onnx.graph.node:
++            if node.op_type == "Transpose" and (
++                    node.name not in node_annotation or not node_annotation[node.name].to_remove):
++                print(
++                    "[ERROR] preprocess_onnxGraph: all Transpose nodes should have been removed, but this one was not: {}\n{}".format(
++                        node.name, node))
++
++
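++# Guess the tensor layout from the exporter: tf2onnx models are assumed NHWC,
++# PyTorch exports NCHW.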
++def detectDataType(model):  # more adaptation to do here if tf is using nchw
++    if model.producer_name == 'tf2onnx':
++        return 'nhwc'
++    elif model.producer_name == 'pytorch':
++        return 'nchw'
++    else:
++        quit('[ERROR] unable to detect data layout')
++
++
++def dumpModel(model_onnx, output_filename, data_layout, verbose, user_annotation):
++    """Writes the neural network model in the \"sadl\" format to binary file.
++
++    Parameters
++    ----------
++    model : onnx model
++    output_filename : either str or None
++        Path to the binary file to which the neural network model
++        is written.
++    data_type: None, 'ncwh' or 'nwhc'
++    verbose : bool
++        Is additional information printed?
++    """
++    model_onnx_copy = copy.deepcopy(model_onnx)
++    if data_layout is None: data_layout = detectDataType(model_onnx_copy)
++
++    if verbose: print("[INFO] assume data type", data_layout)
++
++    if verbose > 1:
++        # remove data
++        gg = copy.deepcopy(model_onnx.graph)
++        for node in gg.initializer:
++            node.raw_data = np.array(0.).tobytes()
++        print("[INFO] original graph:\n", gg)
++        del gg
++
++    if data_layout != 'nhwc' and data_layout != 'nchw':
++        quit('[ERROR] unsupported layout: {}'.format(data_layout))
++
++    node_annotation = {}
++    annotate_graph(model_onnx_copy, node_annotation, data_layout, verbose)
++
++    for k, v in user_annotation.items():
++        if k in node_annotation:
++            if v.add_transpose_before is not None: node_annotation[k].add_transpose_before = v.add_transpose_before
++            if v.add_transpose_after is not None: node_annotation[k].add_transpose_after = v.add_transpose_after
++            if v.to_remove is not None: node_annotation[k].to_remove = v.to_remove
++            if v.to_transpose is not None: node_annotation[k].to_transpose = v.to_transpose
++        else:
++            print("[ERROR] unknown node user custom", k)
++            quit()
++
++    if verbose > 1: print("INFO] annotations:\n{" + "\n".join("{!r}: {!r},".format(k, v) for k, v in
++                                                              node_annotation.items()) + "}")  # print("[INFO] node annotations:", node_annotation)
++    my_graph, my_inputs, my_outputs = parse_onnx(model_onnx_copy, node_annotation, verbose=verbose)
++    dump_onnx(my_graph, my_inputs, my_outputs, output_filename, verbose=verbose)
++    if data_layout == 'nchw': print("[INFO] in SADL, your inputs and outputs has been changed from NCHW to NHWC")
++
++
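++# Programmatic entry point (mirrors the CLI below): `nchw`/`nhwc` force the input
++# layout, otherwise it is detected from the ONNX producer name.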
++def convert_to_sadl(model_onnx, output, nchw=False, nhwc=False, verbose=None, do_not_add_transpose_before=[],
++                    do_not_add_transpose_after=[]):
++    if model_onnx is None:
++        raise ValueError('[ERROR] You should specify an onnx model')
++
++    if output is None:
++        raise ValueError('[ERROR] You should specify an output file')
++
++    print("[INFO] ONNX converter")
++    if verbose is None: verbose = 0
++
++    user_annotation = {}
++    for node in do_not_add_transpose_before:
++        if node not in user_annotation:
++            user_annotation[node] = Node_Annotation()
++            user_annotation[node].to_remove = None
++            user_annotation[node].add_transpose_before = None
++            user_annotation[node].add_transpose_after = None
++            user_annotation[node].to_transpose = None
++        user_annotation[node].add_transpose_before = False
++
++    for node in do_not_add_transpose_after:
++        if node not in user_annotation:
++            user_annotation[node] = Node_Annotation()
++            user_annotation[node].to_remove = None
++            user_annotation[node].add_transpose_before = None
++            user_annotation[node].add_transpose_after = None
++            user_annotation[node].to_transpose = None
++        user_annotation[node].add_transpose_after = False
++
++    data_layout = None
++    if nchw:
++        data_layout = 'nchw'
++    elif nhwc:
++        data_layout = 'nhwc'
++
++    global INPUT_QUANTIZER
++    global NUMPY_TYPE
++    global MODEL_TYPE
++
++    for node in model_onnx.graph.node:
++        if node.input[0] == 'args_0':
++            for attri in node.attribute:
++                if attri.name == 'quantizer':
++                    INPUT_QUANTIZER = attri.i
++                    break
++
++    if INPUT_QUANTIZER is None:
++        INPUT_QUANTIZER = 0
++
++    for graph_input in model_onnx.graph.input:
++        elem_type = graph_input.type.tensor_type.elem_type
++        if elem_type is not None:
++            if elem_type == onnx.TensorProto.INT32:
++                NUMPY_TYPE = np.int32
++                MODEL_TYPE = DTYPE_SADL.INT32
++            elif elem_type == onnx.TensorProto.INT16:
++                NUMPY_TYPE = np.int16
++                MODEL_TYPE = DTYPE_SADL.INT16
++            else:
++                NUMPY_TYPE = np.float32
++                MODEL_TYPE = DTYPE_SADL.FLOAT
++            break
++
++    assert (INPUT_QUANTIZER is not None)
++    assert (NUMPY_TYPE is not None)
++    assert (MODEL_TYPE is not None)
++
++    dumpModel(model_onnx, output, data_layout, verbose, user_annotation)
++
++
++if __name__ == '__main__':
++    parser = argparse.ArgumentParser(prog='onnx2sadl conversion',
++                                     usage='NB: force run on CPU')
++    parser.add_argument('--input_onnx',
++                        action='store',
++                        nargs='?',
++                        type=str,
++                        help='name of the onnx file')
++    parser.add_argument('--output',
++                        action='store',
++                        nargs='?',
++                        type=str,
++                        help='name of model binary file')
++    parser.add_argument('--nchw', action='store_true')
++    parser.add_argument('--nhwc', action='store_true')
++    parser.add_argument('--verbose', action="count")
++    parser.add_argument('--do_not_add_transpose_before', action="store", nargs='+', default=[],
++                        help='specify a node for which adding a transpose before it will be disabled')
++    parser.add_argument('--do_not_add_transpose_after', action="store", nargs='+', default=[],
++                        help='specify a node for which adding a transpose after it will be disabled')
++
++    args = parser.parse_args()
++    if args.input_onnx is None:
++        raise ValueError('[ERROR] You should specify an onnx file')
++    if args.output is None:
++        raise ValueError('[ERROR] You should specify an output file')
++
++    print("[INFO] ONNX converter")
++    if args.verbose is None: args.verbose = 0
++
++    model_onnx = onnx.load(args.input_onnx)
++
++    convert_to_sadl(model_onnx, args.output, args.nchw, args.nhwc, args.verbose, args.do_not_add_transpose_before,
++                    args.do_not_add_transpose_after)
+\ No newline at end of file
diff --git a/training/training_scripts/NN_Post_Filtering/overfitting_and_quantisation.md b/training/training_scripts/NN_Post_Filtering/overfitting_and_quantisation.md
new file mode 100644
index 0000000000000000000000000000000000000000..f8f3ff9b8f58aca623987f11207950ec934f7397
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/overfitting_and_quantisation.md
@@ -0,0 +1,48 @@
+# Over-fitting
+
+For each test video sequence and QP pair, one of the base post-filter models is over-fitted.
+The over-fitting is done on inter-coded frames only, and the model to over-fit is chosen based on the
+performance of the base models (the one selected most often when deployed within VTM).
+
+To identify the model to be over-fitted, run inference using the base models only and select the one
+that is used most often over the full sequence.
+
+The following example shows how to over-fit base model 3 for the D_BlowingBubbles sequence at QP 42:
+
+```shell
+cd $BASE/scripts
+./overfitting.sh D_BlowingBubbles 42 model3
+```
+
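+Here `D_BlowingBubbles` is the sequence name, `42` the QP and `model3` the base model selected in the identification step above.
+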
+# Base model conversion to SADL and quantisation
+
+```shell
+DATASET=$BASE/post_filter_dataset/train_data
+PROPS=$BASE/scripts/resources/properties/train_data_properties.json
+BASE_MODELS=$BASE/finetuning/base_models
+OUTPUT_DIR=$BASE/finetuning/sadl
+mkdir -p ${OUTPUT_DIR}
+
+cd $BASE/NCTM
+
+# model0
+python run/convert_and_quantise_base_model.py --base_model_dir ${BASE_MODELS}/model0 \
+--dataset_dir $DATASET --properties_file $PROPS \
+--output_dir ${OUTPUT_DIR}
+
+# model1
+python run/convert_and_quantise_base_model.py --base_model_dir ${BASE_MODELS}/model1 \
+--dataset_dir $DATASET --properties_file $PROPS \
+--output_dir ${OUTPUT_DIR}
+
+# model2
+python run/convert_and_quantise_base_model.py --base_model_dir ${BASE_MODELS}/model2 \
+--dataset_dir $DATASET --properties_file $PROPS \
+--output_dir ${OUTPUT_DIR}
+
+# model3
+python run/convert_and_quantise_base_model.py --base_model_dir ${BASE_MODELS}/model3 \
+--dataset_dir $DATASET --properties_file $PROPS \
+--output_dir ${OUTPUT_DIR}
+
+```
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/__init__.py b/training/training_scripts/NN_Post_Filtering/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/create_filter_with_multiplier.py b/training/training_scripts/NN_Post_Filtering/scripts/create_filter_with_multiplier.py
new file mode 100644
index 0000000000000000000000000000000000000000..1370eee58345bff7af64c0b0185fcfd3c2b542a4
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/create_filter_with_multiplier.py
@@ -0,0 +1,66 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import click
+import tensorflow as tf
+
+from models.filter_with_multipliers import FilterWithMultipliers
+from util.file_system import check_directory
+
+
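+# CLI: builds a FilterWithMultipliers model, initialises its variables with a dummy
+# forward pass, loads the pre-trained base weights and saves the result to output_dir.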
+@click.command()
+@click.option(
+    "--base_model_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains the base model",
+)
+@click.option("--output_dir", default=None, type=click.Path(), help="Output directory")
+def run_model(base_model_dir, output_dir):
+    gpus = tf.config.list_physical_devices("GPU")
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+
+    check_directory(base_model_dir)
+
+    multiplier_model = FilterWithMultipliers()
+    # Initialise variables
+    multiplier_model(tf.zeros((1, 72, 72, 10)))
+    multiplier_model.load_pretrained_weights(base_model_dir)
+    multiplier_model.save(output_dir)
+
+
+if __name__ == "__main__":
+    """
+    Make sure to set the environment variable `CUDA_VISIBLE_DEVICES=<gpu_idx>` before launching the python process
+    """
+    run_model()
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/models/__init__.py b/training/training_scripts/NN_Post_Filtering/scripts/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/models/base_model.py b/training/training_scripts/NN_Post_Filtering/scripts/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..44435e50170e2f4a7e2049f73c6a2ea97f19685a
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/models/base_model.py
@@ -0,0 +1,374 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import os
+import sys
+import time
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, List, Tuple, Union
+
+import tensorflow as tf
+
+from models.filter_with_multipliers import FilterWithMultipliers, Multiplier
+from util import Colour, Metric
+from util.logging import log_epoch_metrics, save_time, TIME_FORMAT
+from util.metrics import compute_loss, compute_epoch_metrics
+
+
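+# Abstract training driver shared by the fine-tuning and over-fitting stages: it loads
+# the base models, sets up the optimiser, summary writers and the per-colour/per-metric
+# trackers used by its subclasses.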
+class BaseModel(ABC):
+    def __init__(
+        self,
+        stage: str,
+        base_model_dirs: List[str],
+        epochs: int,
+        lr: float,
+        batch_size: int,
+        block_size: int,
+        pad_size: int,
+        output_dir: str,
+    ) -> None:
+        """
+        Constructor
+        :param stage: training stage, i.e. fine-tuning or over-fitting
+        :param base_model_dirs: List of base model directories
+        :param epochs: Epochs
+        :param lr: Learning rate
+        :param batch_size: Batch size
+        :param block_size: Block size (without padding)
+        :param pad_size: Padding size
+        :param output_dir: Output directory for the log files and trained model
+        """
+        self._is_fine_tuning = stage == "fine-tuning"
+
+        self._num_models = len(base_model_dirs)
+        self._models = {}
+        self._base_models = {}
+        base_model_names = [""] * self._num_models
+
+        for model_idx in range(self._num_models):
+            base_model_dir = base_model_dirs[model_idx]
+            self._base_models[model_idx] = tf.saved_model.load(base_model_dir)
+            base_model_names[model_idx] = Path(base_model_dir).name
+
+            if self._is_fine_tuning:
+                self._models[model_idx] = tf.saved_model.load(base_model_dir)
+            else:
+                self._models[model_idx] = tf.keras.models.load_model(
+                    base_model_dir,
+                    custom_objects={
+                        "FilterWithMultipliers": FilterWithMultipliers,
+                        "Multiplier": Multiplier,
+                    },
+                )
+
+        self._epochs = epochs
+        self._lr = lr
+        self._optimiser = tf.optimizers.Adam(self._lr)
+        self._batch_size = batch_size
+
+        self._block_size = block_size
+        self._pad_size = pad_size
+
+        self._output_dir = output_dir
+        os.makedirs(self._output_dir, exist_ok=True)
+
+        self._model_output_dirs = {}
+        self._log_dirs = {}
+        self._log_files = {}
+        self._model_dirs = {}
+
+        self._train_summary = {}
+        self._valid_summary = {}
+
+        self._define_output_files(base_model_names)
+        self._initialise_metrics()
+
+    def _define_output_files(self, base_model_names: List[str]) -> None:
+        """
+        Defines the output directories and files: training output directory, summary directories,
+        output SavedModel directory and training time file
+        :param base_model_names: Names of the base model directories
+        """
+        self._time_file = os.path.join(self._output_dir, "train_time.csv")
+
+        model_type = "finetuned" if self._is_fine_tuning else "overfitted"
+
+        log_dirs = {}
+
+        for idx, base_model_name in enumerate(base_model_names):
+            self._model_output_dirs[idx] = os.path.join(
+                self._output_dir, f"{model_type}_{base_model_name}"
+            )
+            log_dirs[idx] = os.path.join(
+                self._model_output_dirs[idx],
+                f"{base_model_name}_{self._lr}_{self._batch_size}",
+            )
+            self._log_files[idx] = os.path.join(
+                self._model_output_dirs[idx], "training.csv"
+            )
+            self._model_dirs[idx] = os.path.join(
+                self._model_output_dirs[idx], "OutputModel"
+            )
+
+            self._train_summary[idx] = tf.summary.create_file_writer(
+                log_dirs[idx] + "_train"
+            )
+
+            if self._is_fine_tuning:
+                self._valid_summary[idx] = tf.summary.create_file_writer(
+                    log_dirs[idx] + "_valid"
+                )
+
+    def _initialise_metrics(self) -> None:
+        """
+        Initialises the training and validation metrics batch- and epoch-wise
+        """
+        self._train_batch_metrics = {}
+        self._train_epoch_metrics = {}
+
+        if self._is_fine_tuning:
+            self._valid_batch_metrics = {}
+            self._valid_epoch_metrics = {}
+
+        for model_idx in range(self._num_models):
+            self._train_batch_metrics[model_idx] = {}
+            self._train_epoch_metrics[model_idx] = {}
+
+            if self._is_fine_tuning:
+                self._valid_batch_metrics[model_idx] = {}
+                self._valid_epoch_metrics[model_idx] = {}
+
+            for colour in range(Colour.NUM_COLOURS):
+                self._train_batch_metrics[model_idx][colour] = {}
+                self._train_epoch_metrics[model_idx][colour] = {}
+
+                if self._is_fine_tuning:
+                    self._valid_batch_metrics[model_idx][colour] = {}
+                    self._valid_epoch_metrics[model_idx][colour] = {}
+
+                for metric in range(Metric.NUM_METRICS):
+                    self._train_batch_metrics[model_idx][colour][
+                        metric
+                    ] = tf.metrics.Mean()
+                    self._train_epoch_metrics[model_idx][colour][metric] = 0
+
+                    if self._is_fine_tuning:
+                        self._valid_batch_metrics[model_idx][colour][
+                            metric
+                        ] = tf.metrics.Mean()
+                        self._valid_epoch_metrics[model_idx][colour][metric] = 0
+
+    def compute_base_metrics(
+        self, base_model: tf.keras.Model, input_data: tf.Tensor, label_data: tf.Tensor
+    ) -> Tuple[Dict[int, tf.Tensor], Dict[int, tf.Tensor]]:
+        """
+        Computes PSNR for the VTM reconstruction and the initial model
+        :param base_model: Base model (before any fine-tuning or over-fitting takes place)
+        :param input_data: 4D tensor that includes the reconstruction, QP and boundary strength
+        :param label_data: Ground truth
+        :return: VTM PSNR and base model PSNR (sample-wise [B])
+        """
+        _, vtm_psnr = compute_loss(
+            label_data,
+            input_data[
+                :,
+                self._pad_size // 2 : self._pad_size // 2 + self._block_size,
+                self._pad_size // 2 : self._pad_size // 2 + self._block_size,
+                :6,
+            ],
+        )
+
+        base_prediction = base_model(input_data)
+        _, base_psnr = compute_loss(label_data, base_prediction)
+        return vtm_psnr, base_psnr
+
+    @abstractmethod
+    def step(
+        self, input_data: tf.Tensor, label_data: tf.Tensor, train: bool
+    ) -> Tuple[
+        Union[Dict[int, tf.Tensor], List[tf.Tensor]],
+        Union[Dict[int, tf.Tensor], List[tf.Tensor]],
+        Union[Dict[int, tf.Tensor], List[tf.Tensor]],
+        Union[Dict[int, tf.Tensor], List[tf.Tensor]],
+    ]:
+        """
+        Step (training and evaluation)
+        :param input_data: Input data
+        :param label_data: Ground-truth data
+        :param train: Is this a training step?
+        :return: MSE, PSNR, delta PSNR wrt VTM, delta PSNR wrt base model
+        """
+        raise NotImplementedError("Please Implement this method")
+
+    def _update_batch_metrics(
+        self,
+        batch_metrics: Dict[int, Dict[int, tf.metrics.Mean]],
+        mse: Dict[int, tf.Tensor],
+        psnr: Dict[int, tf.Tensor],
+        delta_psnr_wrt_vtm: Dict[int, tf.Tensor],
+        delta_psnr_wrt_to_base: Dict[int, tf.Tensor],
+    ) -> None:
+        """
+        Updates the batch metrics
+        :param batch_metrics: Per-colour, per-metric running means to be updated
+        :param mse: MSE
+        :param psnr: PSNR
+        :param delta_psnr_wrt_vtm: PSNR NN reconstruction - PSNR VTM reconstruction
+        :param delta_psnr_wrt_to_base: PSNR NN reconstruction - PSNR base NN reconstruction
+        """
+        for colour in range(Colour.NUM_COLOURS):
+            # blocks identical to the ground truth yield an infinite PSNR; mask them out before averaging
+            is_psnr_not_inf = tf.math.logical_not(tf.math.is_inf(psnr[colour]))
+
+            batch_metrics[colour][Metric.LOSS].update_state(
+                tf.nn.compute_average_loss(
+                    mse[colour], global_batch_size=self._batch_size
+                )
+            )
+            batch_metrics[colour][Metric.PSNR].update_state(
+                tf.reduce_mean(tf.ragged.boolean_mask(psnr[colour], is_psnr_not_inf))
+            )
+            batch_metrics[colour][Metric.DELTA_PSNR_WRT_VTM].update_state(
+                tf.nn.compute_average_loss(
+                    delta_psnr_wrt_vtm[colour], global_batch_size=self._batch_size
+                )
+            )
+            batch_metrics[colour][Metric.DELTA_PSNR_WRT_BASE].update_state(
+                tf.nn.compute_average_loss(
+                    delta_psnr_wrt_to_base[colour], global_batch_size=self._batch_size
+                )
+            )
+
+    def train_loop(
+        self, train_dataset: tf.data.Dataset, valid_dataset: tf.data.Dataset = None
+    ) -> None:
+        """
+        Train loop. The input data is already batched in the form (input_data, label_data)
+        :param train_dataset: Training data
+        :param valid_dataset: Validation data
+        """
+        best_epoch_metric = [sys.float_info.max] * self._num_models
+        curr_epoch_metric = [0] * self._num_models
+        num_no_improvements = [0] * self._num_models
+        limit_epochs_no_improvements = int(self._epochs * 0.5)
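+        # Early stopping: training ends once every model has gone half of the scheduled epochs without improving its best epoch loss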
+
+        start_time = time.strftime(TIME_FORMAT)
+        last_time = start_time
+
+        for epoch in range(self._epochs):
+            for input_data, label_data in train_dataset:
+                mses, psnrs, delta_psnrs_wrt_vtm, delta_psnrs_wrt_base = self.step(
+                    input_data, label_data, True
+                )
+                for model_idx in range(self._num_models):
+                    self._update_batch_metrics(
+                        self._train_batch_metrics[model_idx],
+                        mses[model_idx],
+                        psnrs[model_idx],
+                        delta_psnrs_wrt_vtm[model_idx],
+                        delta_psnrs_wrt_base[model_idx],
+                    )
+
+            for model_idx in range(self._num_models):
+                compute_epoch_metrics(
+                    self._train_batch_metrics[model_idx],
+                    self._train_epoch_metrics[model_idx],
+                )
+                log_epoch_metrics(
+                    self._train_summary[model_idx],
+                    self._train_epoch_metrics[model_idx],
+                    epoch,
+                )
+                curr_epoch_metric[model_idx] = self._train_epoch_metrics[model_idx][
+                    Colour.YCbCr
+                ][Metric.LOSS]
+
+            if valid_dataset is not None:
+                for input_data, label_data in valid_dataset:
+                    mses, psnrs, delta_psnrs_wrt_vtm, delta_psnrs_wrt_base = self.step(
+                        input_data, label_data, False
+                    )
+                    for model_idx in range(self._num_models):
+                        self._update_batch_metrics(
+                            self._valid_batch_metrics[model_idx],
+                            mses[model_idx],
+                            psnrs[model_idx],
+                            delta_psnrs_wrt_vtm[model_idx],
+                            delta_psnrs_wrt_base[model_idx],
+                        )
+
+                for model_idx in range(self._num_models):
+                    compute_epoch_metrics(
+                        self._valid_batch_metrics[model_idx],
+                        self._valid_epoch_metrics[model_idx],
+                    )
+                    log_epoch_metrics(
+                        self._valid_summary[model_idx],
+                        self._valid_epoch_metrics[model_idx],
+                        epoch,
+                    )
+
+                    curr_epoch_metric[model_idx] = self._valid_epoch_metrics[model_idx][
+                        Colour.YCbCr
+                    ][Metric.LOSS]
+
+            end_training = True
+            for model_idx in range(self._num_models):
+                if curr_epoch_metric[model_idx] < best_epoch_metric[model_idx]:
+                    tf.saved_model.save(
+                        self._models[model_idx], self._model_dirs[model_idx]
+                    )
+                    best_epoch_metric[model_idx] = curr_epoch_metric[model_idx]
+                    num_no_improvements[model_idx] = 0
+                else:
+                    num_no_improvements[model_idx] += 1
+
+                end_training = end_training and (
+                    num_no_improvements[model_idx] == limit_epochs_no_improvements
+                )
+
+            curr_time = time.strftime(TIME_FORMAT)
+            epoch_duration = datetime.strptime(
+                curr_time, TIME_FORMAT
+            ) - datetime.strptime(last_time, TIME_FORMAT)
+            last_time = curr_time
+            print(
+                f'Epoch: {epoch}; Performance: {", ".join([str(v.numpy()) for v in curr_epoch_metric])}; Time: {epoch_duration}'
+            )
+
+            if end_training:
+                print("Training ended earlier!")
+                break
+
+        end_time = time.strftime(TIME_FORMAT)
+        save_time(self._time_file, start_time, end_time)
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/models/filter_with_multipliers.py b/training/training_scripts/NN_Post_Filtering/scripts/models/filter_with_multipliers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff6b12c334da4119cd803ab955689184ea547954
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/models/filter_with_multipliers.py
@@ -0,0 +1,560 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import tensorflow as tf
+from tensorflow.keras import Model
+from tensorflow.keras.layers import Add, Conv2D, InputLayer, Lambda, Layer, LeakyReLU
+
+
+class Multiplier(Layer):
+    """
+    Per-channel multiplier layer. The weights are initialised with ones, so the layer starts as an identity scaling
+    """
+
+    def __init__(self, units=1, **kwargs):
+        super(Multiplier, self).__init__(**kwargs)
+        self.units = units
+        self._multiplier = None
+
+    def build(self, input_shape):
+        self._multiplier = tf.Variable(
+            initial_value=tf.ones_initializer()(shape=(self.units,), dtype="float32"),
+            trainable=True,
+        )
+
+    def call(self, inputs, **kwargs):
+        return inputs * self._multiplier
+
+    def get_config(self):
+        config = super(Multiplier, self).get_config()
+        config.update({"units": self.units})
+        return config
+
+
+class FilterWithMultipliers(Model):
+    """
+    Qualcomm's NN filter with multiplier layers.
+    Let W be the convolution kernel, b the bias terms and c the per-channel multipliers. Each layer is applied as
+    (W * x + b) * c
+    """
+
+    def __init__(self, **kwargs):
+        super(FilterWithMultipliers, self).__init__()
+        self.built = True
+
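+        # Every convolution below is followed by a per-channel Multiplier; the network stacks 1x1/3x3 bottleneck blocks with
+        # LeakyReLU activations and ends with a 6-channel 3x3 convolution whose output is added to the cropped input
+        # reconstruction as a residual.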
+        self._input = InputLayer(input_shape=[72, 72, 10])
+
+        self._conv1 = Conv2D(
+            filters=72,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier1 = Multiplier(units=72)
+        self._leaky1 = LeakyReLU(alpha=0.2)
+
+        self._conv2 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier2 = Multiplier(units=72)
+        self._leaky2 = LeakyReLU(alpha=0.2)
+
+        self._conv3 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier3 = Multiplier(units=24)
+
+        self._conv4 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier4 = Multiplier(units=24)
+
+        self._conv5 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier5 = Multiplier(units=72)
+        self._leaky3 = LeakyReLU(alpha=0.2)
+
+        self._conv6 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier6 = Multiplier(units=24)
+
+        self._conv7 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier7 = Multiplier(units=24)
+
+        self._conv8 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier8 = Multiplier(units=72)
+        self._leaky4 = LeakyReLU(alpha=0.2)
+
+        self._conv9 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier9 = Multiplier(units=24)
+
+        self._conv10 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier10 = Multiplier(units=24)
+
+        self._conv11 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier11 = Multiplier(units=72)
+        self._leaky5 = LeakyReLU(alpha=0.2)
+
+        self._conv12 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier12 = Multiplier(units=24)
+
+        self._conv13 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier13 = Multiplier(units=24)
+
+        self._conv14 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier14 = Multiplier(units=72)
+        self._leaky6 = LeakyReLU(alpha=0.2)
+
+        self._conv15 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier15 = Multiplier(units=24)
+
+        self._conv16 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier16 = Multiplier(units=24)
+
+        self._conv17 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier17 = Multiplier(units=72)
+        self._leaky7 = LeakyReLU(alpha=0.2)
+
+        self._conv18 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier18 = Multiplier(units=24)
+
+        self._conv19 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier19 = Multiplier(units=24)
+
+        self._conv20 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier20 = Multiplier(units=72)
+        self._leaky8 = LeakyReLU(alpha=0.2)
+
+        self._conv21 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier21 = Multiplier(units=24)
+
+        self._conv22 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier22 = Multiplier(units=24)
+
+        self._conv23 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier23 = Multiplier(units=72)
+        self._leaky9 = LeakyReLU(alpha=0.2)
+
+        self._conv24 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier24 = Multiplier(units=24)
+
+        self._conv25 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier25 = Multiplier(units=24)
+
+        self._conv26 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier26 = Multiplier(units=72)
+        self._leaky10 = LeakyReLU(alpha=0.2)
+
+        self._conv27 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier27 = Multiplier(units=24)
+
+        self._conv28 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier28 = Multiplier(units=24)
+
+        self._conv29 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier29 = Multiplier(units=72)
+        self._leaky11 = LeakyReLU(alpha=0.2)
+
+        self._conv30 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier30 = Multiplier(units=24)
+
+        self._conv31 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier31 = Multiplier(units=24)
+
+        self._conv32 = Conv2D(
+            filters=72,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier32 = Multiplier(units=72)
+        self._leaky12 = LeakyReLU(alpha=0.2)
+
+        self._conv33 = Conv2D(
+            filters=24,
+            kernel_size=(1, 1),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier33 = Multiplier(units=24)
+
+        self._conv34 = Conv2D(
+            filters=24,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier34 = Multiplier(units=24)
+
+        self._conv35 = Conv2D(
+            filters=6,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="same",
+            dilation_rate=(1, 1),
+        )
+        self._multiplier35 = Multiplier(units=6)
+
+        self._slice1 = Lambda(lambda x: x[:, 4:68, 4:68, :6])
+        self._slice2 = Lambda(lambda x: x[:, 4:68, 4:68, :])
+        self._add = Add()
+
+        self.build([None, 72, 72, 10])
+
+    def load_pretrained_weights(self, base_model_dir: str) -> None:
+        """
+        Loads weights from the pretrained model
+        :param base_model_dir: Absolute path to the base model
+        """
+        base_model = tf.keras.models.load_model(base_model_dir)
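+        # Weights are copied by matching variable names between the pretrained model and this instance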
+        for src_var in base_model.variables:
+            var_name = src_var.name
+
+            for dst_var in self.variables:
+                if var_name == dst_var.name:
+                    dst_var.assign(src_var.numpy())
+                    break
+
+    def call(self, input_1):
+        """
+        Applies CNN
+        :param input_1: Patches of size 72x72x10
+        :return: filtered patches
+        """
+        x = self._input(input_1)
+
+        y = self._conv1(x)
+        y = self._multiplier1(y)
+        y = self._leaky1(y)
+
+        y = self._conv2(y)
+        y = self._multiplier2(y)
+        y = self._leaky2(y)
+
+        y = self._conv3(y)
+        y = self._multiplier3(y)
+
+        y = self._conv4(y)
+        y = self._multiplier4(y)
+
+        y = self._conv5(y)
+        y = self._multiplier5(y)
+        y = self._leaky3(y)
+
+        y = self._conv6(y)
+        y = self._multiplier6(y)
+
+        y = self._conv7(y)
+        y = self._multiplier7(y)
+
+        y = self._conv8(y)
+        y = self._multiplier8(y)
+        y = self._leaky4(y)
+
+        y = self._conv9(y)
+        y = self._multiplier9(y)
+
+        y = self._conv10(y)
+        y = self._multiplier10(y)
+
+        y = self._conv11(y)
+        y = self._multiplier11(y)
+        y = self._leaky5(y)
+
+        y = self._conv12(y)
+        y = self._multiplier12(y)
+
+        y = self._conv13(y)
+        y = self._multiplier13(y)
+
+        y = self._conv14(y)
+        y = self._multiplier14(y)
+        y = self._leaky6(y)
+
+        y = self._conv15(y)
+        y = self._multiplier15(y)
+
+        y = self._conv16(y)
+        y = self._multiplier16(y)
+
+        y = self._conv17(y)
+        y = self._multiplier17(y)
+        y = self._leaky7(y)
+
+        y = self._conv18(y)
+        y = self._multiplier18(y)
+
+        y = self._conv19(y)
+        y = self._multiplier19(y)
+
+        y = self._conv20(y)
+        y = self._multiplier20(y)
+        y = self._leaky8(y)
+
+        y = self._conv21(y)
+        y = self._multiplier21(y)
+
+        y = self._conv22(y)
+        y = self._multiplier22(y)
+
+        y = self._conv23(y)
+        y = self._multiplier23(y)
+        y = self._leaky9(y)
+
+        y = self._conv24(y)
+        y = self._multiplier24(y)
+
+        y = self._conv25(y)
+        y = self._multiplier25(y)
+
+        y = self._conv26(y)
+        y = self._multiplier26(y)
+        y = self._leaky10(y)
+
+        y = self._conv27(y)
+        y = self._multiplier27(y)
+
+        y = self._conv28(y)
+        y = self._multiplier28(y)
+
+        y = self._conv29(y)
+        y = self._multiplier29(y)
+        y = self._leaky11(y)
+
+        y = self._conv30(y)
+        y = self._multiplier30(y)
+
+        y = self._conv31(y)
+        y = self._multiplier31(y)
+
+        y = self._conv32(y)
+        y = self._multiplier32(y)
+        y = self._leaky12(y)
+
+        y = self._conv33(y)
+        y = self._multiplier33(y)
+
+        y = self._conv34(y)
+        y = self._multiplier34(y)
+
+        y = self._conv35(y)
+        y = self._multiplier35(y)
+
+        x = self._slice1(x)
+        y = self._slice2(y)
+        y = self._add([x, y])
+        return y
+
+    def get_config(self):
+        config = super(FilterWithMultipliers, self).get_config()
+        config.update({"name": self.name})
+        return config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**config)
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/models/joint_models.py b/training/training_scripts/NN_Post_Filtering/scripts/models/joint_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cacad6be6ca0c03a6451e73bb8dc40f2522f210
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/models/joint_models.py
@@ -0,0 +1,146 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from typing import Dict, List, Tuple
+
+import tensorflow as tf
+
+from models.base_model import BaseModel
+from util import Colour
+from util.image_ops import add_zeros_to_image
+from util.metrics import compute_loss, compute_psnr_gain
+
+
+class JointModels(BaseModel):
+    def __init__(
+        self,
+        base_model_dirs: List[str],
+        epochs: int,
+        lr: float,
+        batch_size: int,
+        block_size: int,
+        pad_size: int,
+        output_dir: str,
+    ) -> None:
+        super().__init__(
+            "fine-tuning",
+            base_model_dirs,
+            epochs,
+            lr,
+            batch_size,
+            block_size,
+            pad_size,
+            output_dir,
+        )
+
+    @tf.function
+    def step(
+        self, input_data: tf.Tensor, label_data: tf.Tensor, train: bool
+    ) -> Tuple[
+        Dict[int, Dict[int, tf.Tensor]],
+        Dict[int, Dict[int, tf.Tensor]],
+        Dict[int, Dict[int, tf.Tensor]],
+        Dict[int, Dict[int, tf.Tensor]],
+    ]:
+        input_data = add_zeros_to_image(input_data)
+
+        vtm_psnrs = {}
+        base_psnrs = {}
+
+        for model_idx in self._base_models:
+            vtm_psnrs[model_idx], base_psnrs[model_idx] = self.compute_base_metrics(
+                self._base_models[model_idx], input_data, label_data
+            )
+
+        if train:
+            variables = []
+            for model_idx in self._models:
+                variables.append(self._models[model_idx].variables)
+
+            with tf.GradientTape(persistent=True) as tape:
+                tape.watch(variables)
+                predictions = {}
+                pred_mses = {}
+                pred_psnrs = {}
+                delta_psnrs_wrt_vtm = {}
+                delta_psnrs_wrt_base = {}
+                losses = {}
+                losses_exp = []
+
+                for model_idx in self._models:
+                    predictions[model_idx] = self._models[model_idx](input_data)
+                    pred_mses[model_idx], pred_psnrs[model_idx] = compute_loss(
+                        label_data, predictions[model_idx]
+                    )
+                    delta_psnrs_wrt_vtm[model_idx] = compute_psnr_gain(
+                        pred_psnrs[model_idx], vtm_psnrs[model_idx]
+                    )
+                    delta_psnrs_wrt_base[model_idx] = compute_psnr_gain(
+                        pred_psnrs[model_idx], base_psnrs[model_idx]
+                    )
+                    losses_exp.append(tf.math.exp(-pred_mses[model_idx][Colour.YCbCr]))
+                loss_exp_all = tf.math.add_n(losses_exp)
+
+                # loss weights
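+                # alpha = exp(-mse_i) / sum_j exp(-mse_j): a per-sample softmax over the negated YCbCr MSEs,
+                # so the model with the lower error on a sample receives the larger weight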
+                for model_idx, loss_exp in enumerate(losses_exp):
+                    alpha = tf.math.divide(loss_exp, loss_exp_all)
+                    losses[model_idx] = tf.nn.compute_average_loss(
+                        tf.math.multiply(alpha, pred_mses[model_idx][Colour.YCbCr]),
+                        global_batch_size=self._batch_size,
+                    )
+
+            # update gradients
+            for model_idx in losses:
+                gradients = tape.gradient(losses[model_idx], variables[model_idx])
+                self._optimiser.apply_gradients(zip(gradients, variables[model_idx]))
+
+            del tape
+        else:
+            predictions = {}
+            pred_mses = {}
+            pred_psnrs = {}
+            delta_psnrs_wrt_vtm = {}
+            delta_psnrs_wrt_base = {}
+
+            for model_idx in self._models:
+                predictions[model_idx] = self._models[model_idx](input_data)
+                pred_mses[model_idx], pred_psnrs[model_idx] = compute_loss(
+                    label_data, predictions[model_idx]
+                )
+                delta_psnrs_wrt_vtm[model_idx] = compute_psnr_gain(
+                    pred_psnrs[model_idx], vtm_psnrs[model_idx]
+                )
+                delta_psnrs_wrt_base[model_idx] = compute_psnr_gain(
+                    pred_psnrs[model_idx], base_psnrs[model_idx]
+                )
+
+        return pred_mses, pred_psnrs, delta_psnrs_wrt_vtm, delta_psnrs_wrt_base
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/models/single_model.py b/training/training_scripts/NN_Post_Filtering/scripts/models/single_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a293e022c5de0d4c12e539d3837a63698a639f5c
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/models/single_model.py
@@ -0,0 +1,109 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from typing import Dict, List, Tuple
+
+import tensorflow as tf
+
+from models.base_model import BaseModel
+from util import Colour
+from util.image_ops import add_zeros_to_image
+from util.metrics import compute_loss, compute_psnr_gain
+
+
+class SingleModel(BaseModel):
+    def __init__(
+        self,
+        stage: str,
+        base_model_dirs: List[str],
+        epochs: int,
+        lr: float,
+        batch_size: int,
+        block_size: int,
+        pad_size: int,
+        output_dir: str,
+    ) -> None:
+        super().__init__(
+            stage,
+            base_model_dirs,
+            epochs,
+            lr,
+            batch_size,
+            block_size,
+            pad_size,
+            output_dir,
+        )
+
+    def step(
+        self, input_data: tf.Tensor, label_data: tf.Tensor, train: bool
+    ) -> Tuple[
+        List[Dict[int, tf.Tensor]],
+        List[Dict[int, tf.Tensor]],
+        List[Dict[int, tf.Tensor]],
+        List[Dict[int, tf.Tensor]],
+    ]:
+        input_data = add_zeros_to_image(input_data)
+
+        vtm_psnr, base_psnr = self.compute_base_metrics(
+            self._base_models[0], input_data, label_data
+        )
+
+        if train:
+            with tf.GradientTape() as tape:
+                prediction = self._models[0](input_data)
+                pred_mse, pred_psnr = compute_loss(label_data, prediction)
+
+                delta_psnr_wrt_vtm = compute_psnr_gain(pred_psnr, vtm_psnr)
+                delta_psnr_wrt_base = compute_psnr_gain(pred_psnr, base_psnr)
+
+                loss = tf.nn.compute_average_loss(
+                    pred_mse[Colour.YCbCr], global_batch_size=self._batch_size
+                )
+
+                if self._is_fine_tuning:
+                    variables = tape.watched_variables()
+                else:
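+                    # Over-fitting: only the Multiplier weights are trained; the convolution kernels and biases stay frozen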
+                    variables = [
+                        v for v in tape.watched_variables() if "multiplier" in v.name
+                    ]
+
+            gradients = tape.gradient(loss, variables)
+            self._optimiser.apply_gradients(zip(gradients, variables))
+            del tape
+        else:
+            prediction = self._models[0](input_data)
+            pred_mse, pred_psnr = compute_loss(label_data, prediction)
+
+            delta_psnr_wrt_vtm = compute_psnr_gain(pred_psnr, vtm_psnr)
+            delta_psnr_wrt_base = compute_psnr_gain(pred_psnr, base_psnr)
+
+        return [pred_mse], [pred_psnr], [delta_psnr_wrt_vtm], [delta_psnr_wrt_base]
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/overfitting.sh b/training/training_scripts/NN_Post_Filtering/scripts/overfitting.sh
new file mode 100755
index 0000000000000000000000000000000000000000..6b1414de2027819f1e383f36452608b7dce61749
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/overfitting.sh
@@ -0,0 +1,66 @@
+#!/usr/bin/bash
+
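+# Usage: overfitting.sh <SEQ_NAME> <SEQ_QP> <BASE_MODEL_NAME>
+# ${BASE} is assumed to be exported by the caller and to point to the root of the post-filter training workspace.
+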
+SEQ_NAME=$1
+SEQ_QP=$2
+BASE_MODEL_NAME=$3
+
+conda activate nn-post-filter
+
+################################################### Setup ###################################################
+
+DATASET=${BASE}/post_filter_dataset/valid_data
+PROPS=${BASE}/scripts/resources/properties/valid_data_properties.json
+
+ROOT_OUTPUT_DIR=${BASE}/overfitting
+mkdir -p "${ROOT_OUTPUT_DIR}"
+
+SADL_DIR=${ROOT_OUTPUT_DIR}/sadl
+mkdir -p "${SADL_DIR}"
+
+NNR_DIR=${ROOT_OUTPUT_DIR}/nnr
+mkdir -p "${NNR_DIR}"
+
+OVER_FITTING_DIR=${ROOT_OUTPUT_DIR}/${SEQ_NAME}_${SEQ_QP}
+mkdir -p "${OVER_FITTING_DIR}"
+
+BASE_MODEL=$BASE/finetuning/base_models/${BASE_MODEL_NAME}
+
+############################################### Over-fitting ################################################
+
+cd ${BASE}/scripts || exit
+
+python training.py --stage over-fitting --epochs 200 \
+--train_dir ${DATASET} --train_prop_file ${PROPS} --train_seq_qp "${SEQ_QP}" --seq_name "${SEQ_NAME}" \
+--base_model_dir "${BASE_MODEL}" --output_dir "${OVER_FITTING_DIR}" --cache_dataset
+
+############################################## NNC/NNR encoding ##############################################
+
+cd ${BASE}/NCTM || exit
+
+python run/encode_multipliers.py --seq_name "${SEQ_NAME}" --seq_qp "${SEQ_QP}" \
+--base_model_dir "${BASE_MODEL}" \
+--overfitted_model_dir "${OVER_FITTING_DIR}"/overfitted_"${BASE_MODEL_NAME}"/OutputModel \
+--dataset_dir "${DATASET}" --properties_file "${PROPS}" \
+--output_dir "${NNR_DIR}" --cache_dataset
+
+#################################### NNC/NNR decoding and model quantisation ###################################
+#
+# Int16 model
+
+cd ${BASE}/NCTM || exit
+
+NNR_BITSTREAM_INT16=${NNR_DIR}/${SEQ_NAME}_${SEQ_QP}_int16.nnr
+
+python run/decode_multipliers.py --quantise \
+--base_model_dir "${BASE_MODEL}" --nnr_bitstream "${NNR_BITSTREAM_INT16}" \
+--output_file "${SADL_DIR}/${SEQ_NAME}_${SEQ_QP}_int16.sadl"
+
+# Float model
+
+cd ${BASE}/NCTM || exit
+
+NNR_BITSTREAM_FLOAT=${NNR_DIR}/${SEQ_NAME}_${SEQ_QP}_float.nnr
+
+python run/decode_multipliers.py \
+--base_model_dir "${BASE_MODEL}" --nnr_bitstream "${NNR_BITSTREAM_FLOAT}" \
+--output_file "${SADL_DIR}/${SEQ_NAME}_${SEQ_QP}_float.sadl"
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/prepare_to_code.py b/training/training_scripts/NN_Post_Filtering/scripts/prepare_to_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..efafe61e4efa665975fc958ebef0c23c22bd0d09
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/prepare_to_code.py
@@ -0,0 +1,70 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from pathlib import Path
+
+import click
+
+from util.file_system import check_directory
+from util.preprocessing import convert_mp4_to_yuv, convert_png_to_yuv
+
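+# Example invocation (the paths below are placeholders):
+#   python prepare_to_code.py --input_dir <DIV2K_png_dir> --output_dir <DIV2K_yuv_dir> --dataset DIV2K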
+
+@click.command()
+@click.option(
+    "--input_dir",
+    default=None,
+    type=click.Path(),
+    help="Input directory",
+)
+@click.option(
+    "--output_dir",
+    default=None,
+    type=click.Path(),
+    help="Output directory",
+)
+@click.option(
+    "--dataset",
+    default="DIV2K",
+    type=click.Choice(["BVI-DVC", "DIV2K"], case_sensitive=True),
+    help="Dataset name",
+)
+def prepare_to_code(input_dir, output_dir, dataset):
+    check_directory(input_dir)
+
+    if "BVI-DVC" in dataset:
+        convert_mp4_to_yuv(Path(input_dir), Path(output_dir))
+    elif "DIV2K" in dataset:
+        convert_png_to_yuv(Path(input_dir), Path(output_dir))
+
+
+if __name__ == "__main__":
+    prepare_to_code()
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/prepare_to_train.py b/training/training_scripts/NN_Post_Filtering/scripts/prepare_to_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27d08126dec2dd6f8a027a2ab3efeec118dbe95
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/prepare_to_train.py
@@ -0,0 +1,81 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from pathlib import Path
+
+import click
+
+from util.file_system import check_directory
+from util.preprocessing import convert_yuv_to_png
+
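+# Example invocation (the paths below are placeholders):
+#   python prepare_to_train.py --orig_dir <orig_yuv_dir> --deco_dir <decoded_yuv_dir> --output_dir <dataset_root> --dataset JVET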
+
+@click.command()
+@click.option(
+    "--orig_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains original data",
+)
+@click.option(
+    "--deco_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains decoded data",
+)
+@click.option("--output_dir", default=None, type=click.Path(), help="Output directory")
+@click.option(
+    "--dataset",
+    default="DIV2K",
+    type=click.Choice(["BVI-DVC", "DIV2K", "JVET"], case_sensitive=True),
+    help="Dataset name",
+)
+def prepare_to_train(orig_dir, deco_dir, output_dir, dataset):
+    check_directory(orig_dir)
+    orig_dir = Path(orig_dir)
+
+    check_directory(deco_dir)
+    deco_dir = Path(deco_dir)
+
+    output_dir = Path(output_dir)
+    if dataset in ["BVI-DVC", "DIV2K"]:
+        output_dir = output_dir / "train_data"
+    elif "JVET" in dataset:
+        output_dir = output_dir / "valid_data"
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    convert_yuv_to_png(orig_dir, output_dir / "orig", dataset)
+    convert_yuv_to_png(deco_dir, output_dir / "deco", dataset)
+
+
+if __name__ == "__main__":
+    prepare_to_train()
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/print_overfit_cmd.py b/training/training_scripts/NN_Post_Filtering/scripts/print_overfit_cmd.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc1cab79bf8b9e11e218c7d1c7049f85b0303b7
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/print_overfit_cmd.py
@@ -0,0 +1,39 @@
+import json
+import click
+
+
+@click.command()
+@click.option(
+    "--train_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains training images",
+)
+@click.option(
+    "--root_base_model_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains the base model",
+)
+@click.option(
+    "--root_output_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains the base model",
+)
+def main(train_dir, root_base_model_dir, root_output_dir):
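+    # models_to_be_overfitted.json maps sequence name -> QP -> index of the fine-tuned base model to be over-fitted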
+    with open('resources/base_models/models_to_be_overfitted.json', 'r') as stream:
+        config = json.load(stream)
+
+    for seq_name in config:
+        for qp in config[seq_name]:
+            model = config[seq_name][qp]
+            command = f'python training.py --stage over-fitting --epochs 200 --train_dir {train_dir} ' \
+                      f'--train_prop_file resources/properties/valid_data_properties.json --train_seq_qp {qp} --seq_name {seq_name} ' \
+                      f'--base_model_dir {root_base_model_dir}/finetuned_model{model}/OutputModel --output_dir {root_output_dir}/{seq_name}_{qp}'
+            print(command)
+        print()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/saved_model.pb b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..3b418ffd615c02a914f2f174f9393afabf522e1e
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/saved_model.pb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5bd123900fbcb94b38316a89a84beb4c8ef774fb7b61aad9ba58070d70e0f60d
+size 1293469
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.data-00000-of-00002 b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.data-00000-of-00002
new file mode 100644
index 0000000000000000000000000000000000000000..eeac87951bf03ee87ef2b972e14625d7cf7a5f5f
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.data-00000-of-00002
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac07200310c3bc120c4d6322cc39782ece0e0fd4ea0a283c6d5922fa1660ebc0
+size 45993
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.data-00001-of-00002 b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.data-00001-of-00002
new file mode 100644
index 0000000000000000000000000000000000000000..f160989ef8d27c8066c182c79305f9cb689ac86c
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.data-00001-of-00002
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:140591397cf8ee99e31a22824132d99786b6ec2db126d9d3d2128e5f6c2f32e8
+size 1292060
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.index b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..ad5c0f7a7d33d7a6c66d4c3e2a20c23f273a69f5
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model1/variables/variables.index
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea41dfbae1b1aa9065485a0642702c807276d3435eaa32ac2ef2f4068783ccc0
+size 15349
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/saved_model.pb b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..b08b4fa090f4256c2ee2bb9a7644e7559814f08a
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/saved_model.pb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6aacc808b236f08b659099c3a3ce499ad50d36545c29b20f97b1f77035e23ec
+size 1283519
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.data-00000-of-00002 b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.data-00000-of-00002
new file mode 100644
index 0000000000000000000000000000000000000000..5b5f1ee3f58995b7b66ec83d0a17c4dce5e91271
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.data-00000-of-00002
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6fea2bb43baa043f18e1eb27599b8730d9cfb7d147754789a77dae2283c9e8e
+size 45243
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.data-00001-of-00002 b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.data-00001-of-00002
new file mode 100644
index 0000000000000000000000000000000000000000..94eb33fa491307c0be30d31d2196a5f17c008ed6
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.data-00001-of-00002
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc82a05a0ec06fca2c2b5a3d8a81a9f27b2b478cad41cf28a4fb73d309d1250a
+size 1292060
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.index b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..be373b12e83bf1fe51b8ac70722850e7b1c3b15d
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model2/variables/variables.index
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbb6a0e6e8ed55e7a4617a2d9afbc635fba2e95c2e131f892ca874316ed42202
+size 15349
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/saved_model.pb b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/saved_model.pb
new file mode 100644
index 0000000000000000000000000000000000000000..e59212be0b51be629a92dafe980e2ae020ea3792
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/saved_model.pb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9cbc9135249ee3cea5e06f8435c42edad810b7b0f18e56aeb4fb6015ab818011
+size 1283519
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.data-00000-of-00002 b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.data-00000-of-00002
new file mode 100644
index 0000000000000000000000000000000000000000..ffa9120ec33c2e0cc326d129f959314bab9f9f98
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.data-00000-of-00002
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddf9673c72358ee5df377976e52e03b44605d70373f66246b79f3e2c692c4c80
+size 45243
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.data-00001-of-00002 b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.data-00001-of-00002
new file mode 100644
index 0000000000000000000000000000000000000000..e70037dfe730281de07645d424629839972c4f84
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.data-00001-of-00002
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f5d295a721f5d49763d250716fe5f7805dc144c8a3c12b87608db23e7c9304b1
+size 1292060
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.index b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.index
new file mode 100644
index 0000000000000000000000000000000000000000..d715e37e0e23c938aff43aa003a5e7ce72d7da5f
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/orig_base_models/model3/variables/variables.index
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19a36d0274897e0561ba1efb5695f568246c9a6b35edb95a669dcb6d69689a87
+size 15349
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/jvet_labels.json b/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/jvet_labels.json
new file mode 100644
index 0000000000000000000000000000000000000000..ae4c6beac9f3a9efef7d97607c2594b144a9f899
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/jvet_labels.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6560d19e728eb84e9812abb984a90cdd686d0366be9fe83e8cfd479942226fc4
+size 1806
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/train_data_properties.json b/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/train_data_properties.json
new file mode 100644
index 0000000000000000000000000000000000000000..c031c9422b6868a7ef077014d7948a7712a2e853
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/train_data_properties.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ebf2b615d6ab665d57ab152ec7008171e8094398dec086a67f15237900603d4
+size 739312
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/valid_data_properties.json b/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/valid_data_properties.json
new file mode 100644
index 0000000000000000000000000000000000000000..d0441bd08759c0c086a15db50ef570829bc624e3
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/resources/properties/valid_data_properties.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79cb514251c333ef987c4455860cf09ca6f533c05353bf750442ebdca435f668
+size 3045
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/training.py b/training/training_scripts/NN_Post_Filtering/scripts/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..709d58229d287e679b9ef74259014323f6061ddc
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/training.py
@@ -0,0 +1,263 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import random
+import sys
+
+import click
+import tensorflow as tf
+
+from models.joint_models import JointModels
+from models.single_model import SingleModel
+from util.dataset import Dataset
+
+
+@click.command()
+@click.option(
+    "--stage",
+    default="over-fitting",
+    type=click.Choice(["fine-tuning", "over-fitting"]),
+    help="Process stage",
+)
+@click.option(
+    "--train_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains training images",
+)
+@click.option(
+    "--train_prop_file",
+    default=None,
+    type=click.Path(),
+    help="JSON file with sequence properties for training datasets",
+)
+@click.option(
+    "--train_seq_qp",
+    default=list(),
+    multiple=True,
+    help="List of sequence QPs for training",
+)
+@click.option(
+    "--valid_dir",
+    default=None,
+    type=click.Path(),
+    help="Directory that contains decoded validation images",
+)
+@click.option(
+    "--valid_prop_file",
+    default=None,
+    type=click.Path(),
+    help="JSON file with sequence properties for validation datasets",
+)
+@click.option(
+    "--valid_seq_qp",
+    default=list(),
+    multiple=True,
+    help="List of sequence QPs for validation",
+)
+@click.option("--bit_depth", default=10, type=int, help="Sequence bit depth")
+@click.option(
+    "--block_size", default=64, type=int, help="Block/patch size of the label"
+)
+@click.option(
+    "--pad_size", default=8, type=int, help="Padding size (sum of left and right)"
+)
+@click.option(
+    "--use_frame_type",
+    default=False,
+    is_flag=True,
+    help="Frames to be selected based on the frame type",
+)
+@click.option(
+    "--frame_type",
+    default="B",
+    type=click.Choice(["I", "P", "B"]),
+    help="Frame type",
+)
+@click.option(
+    "--use_frame_qp",
+    default=False,
+    is_flag=True,
+    help="Frames to be selected based on the frame QP",
+)
+@click.option(
+    "--min_frame_qp", default=0, type=int, help="Minimum frame QP (inclueded)"
+)
+@click.option("--max_frame_qp", default=0, type=int, help="Maximum frame QP (included)")
+@click.option(
+    "--use_random_patches",
+    default=False,
+    is_flag=True,
+    help="Patches are selected randomly",
+)
+@click.option(
+    "--num_patches",
+    default=10,
+    type=int,
+    help="Number of patches to extract from each frame",
+)
+@click.option("--seq_name", default=None, type=str, help="Sequence name/tag")
+@click.option(
+    "--joint_training", default=False, is_flag=True, help="Use joint training"
+)
+@click.option("--epochs", default=50, type=int, help="Max number of epochs")
+@click.option("--lr", default=1e-3, type=float, help="Learning rate")
+@click.option("--batch_size", default=64, type=int, help="Batch size")
+@click.option(
+    "--base_model_dir",
+    default=list(),
+    multiple=True,
+    type=click.Path(),
+    help="Directory that contains the base model",
+)
+@click.option(
+    "--output_dir",
+    default="/tmp/output_dir",
+    type=click.Path(),
+    help="Output directory",
+)
+@click.option(
+    "--cache_dataset",
+    default=False,
+    type=bool,
+    is_flag=True,
+    help="Cache the dataset in RAM",
+)
+def run_model(
+    stage,
+    train_dir,
+    train_prop_file,
+    train_seq_qp,
+    valid_dir,
+    valid_prop_file,
+    valid_seq_qp,
+    bit_depth,
+    block_size,
+    pad_size,
+    use_frame_type,
+    frame_type,
+    use_frame_qp,
+    min_frame_qp,
+    max_frame_qp,
+    use_random_patches,
+    num_patches,
+    seq_name,
+    joint_training,
+    epochs,
+    lr,
+    batch_size,
+    base_model_dir,
+    output_dir,
+    cache_dataset,
+):
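+    # Allocate GPU memory on demand instead of reserving the whole device up front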
+    gpus = tf.config.list_physical_devices("GPU")
+    for gpu in gpus:
+        tf.config.experimental.set_memory_growth(gpu, True)
+
+    if stage == "over-fitting":
+        assert (
+            not joint_training
+        ), "The joint training is only to be used in the fine-tuning stage"
+        assert len(base_model_dir) == 1, "Only one model is over-fitted at a time"
+
+    if joint_training:
+        model = JointModels(
+            base_model_dir, epochs, lr, batch_size, block_size, pad_size, output_dir
+        )
+    else:
+        model = SingleModel(
+            stage,
+            base_model_dir,
+            epochs,
+            lr,
+            batch_size,
+            block_size,
+            pad_size,
+            output_dir,
+        )
+
+    train_data = Dataset(
+        train_dir,
+        train_prop_file,
+        train_seq_qp,
+        bit_depth,
+        block_size,
+        pad_size,
+        use_frame_type,
+        frame_type,
+        use_frame_qp,
+        min_frame_qp,
+        max_frame_qp,
+        use_random_patches,
+        num_patches,
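+        # is_validation=False: this dataset is the training split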
+        False,
+        cache_dataset,
+    )
+    train_data = train_data.create(seq_name, batch_size)
+
+    if stage == "fine-tuning":
+        valid_data = Dataset(
+            valid_dir,
+            valid_prop_file,
+            valid_seq_qp,
+            bit_depth,
+            block_size,
+            pad_size,
+            use_frame_type,
+            frame_type,
+            False,
+            0,
+            0,
+            use_random_patches,
+            num_patches,
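+            # is_validation=True: the JVET validation split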
+            True,
+            cache_dataset,
+        )
+        valid_data = valid_data.create(None, batch_size)
+    else:
+        valid_data = None
+
+    model.train_loop(train_data, valid_data)
+
+
+def main(*args, **kwargs):
+    print("call: {}".format(" ".join(sys.argv)))
+    run_model(*args, **kwargs)
+
+
+if __name__ == "__main__":
+    """
+    Make sure to set the environment variable `CUDA_VISIBLE_DEVICES=<gpu_idx>` before launching the python process
+    """
+    random.seed(1234)
+    tf.random.set_seed(1234)
+    main()
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/__init__.py b/training/training_scripts/NN_Post_Filtering/scripts/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..79faef8843ba7c7962983ff8fdda743f001b588b
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/__init__.py
@@ -0,0 +1,57 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+
+class Colour:
+    Y = 0
+    Cb = 1
+    Cr = 2
+    YCbCr = 3
+    NUM_COLOURS = 4
+
+
+class Metric:
+    LOSS = 0
+    PSNR = 1
+    DELTA_PSNR_WRT_VTM = 2
+    DELTA_PSNR_WRT_BASE = 3
+    NUM_METRICS = 4
+
+
+COLOUR_LABEL = {Colour.Y: "Y", Colour.Cb: "Cb", Colour.Cr: "Cr", Colour.YCbCr: "YCbCr"}
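+# Channel weights used when combining the per-colour losses: 4/6 for luma, 1/6 for each chroma channel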
+COLOUR_WEIGHTS = {Colour.Y: 4.0 / 6.0, Colour.Cb: 1.0 / 6.0, Colour.Cr: 1.0 / 6.0}
+METRIC_LABEL = {
+    Metric.LOSS: "Loss",
+    Metric.PSNR: "PSNR",
+    Metric.DELTA_PSNR_WRT_VTM: "dPSNR_wrt_VTM",
+    Metric.DELTA_PSNR_WRT_BASE: "dPSNR_wrt_base",
+}
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/dataset.py b/training/training_scripts/NN_Post_Filtering/scripts/util/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a49f20407477e1633622774bbfe7c1b89c6b252
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/dataset.py
@@ -0,0 +1,600 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import random
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from util.file_system import (
+    check_directory,
+    list_dirs,
+    list_selected_dirs,
+    read_json_file,
+)
+from util.image_ops import (
+    extract_patches,
+    extract_random_patches,
+    interleave_image,
+    pad_image,
+    read_image,
+)
+
+
+class Dataset:
+    def __init__(
+        self,
+        root_dir: str,
+        properties_file: str,
+        sequence_qps: Union[Tuple[str], List[str]],
+        bit_depth: int,
+        block_size: int,
+        pad_size: int,
+        use_frame_type: bool,
+        frame_type: str,
+        use_frame_qps: bool,
+        min_frame_qp: int,
+        max_frame_qp: int,
+        use_random_patches: bool,
+        num_patches: int,
+        is_validation: bool,
+        cache_dataset: bool,
+    ):
+        """
+        Constructor
+        :param root_dir: Root directory for the dataset
+        :param properties_file: JSON file with sequence properties for the dataset
+        :param sequence_qps: List of sequence QPs
+        :param bit_depth: Bit-depth for the data
+        :param block_size: Block size, without padding
+        :param pad_size: Padding size
+        :param use_frame_type: Enable/disable the frame selection based on its type
+        :param frame_type: Frame type (i.e. I or B)
+        :param use_frame_qps: Enable/disable the frame selection based on the frame QP
+        :param min_frame_qp: Minimum frame QP (inclusive)
+        :param max_frame_qp: Maximum frame QP (inclusive)
+        :param use_random_patches: Enable/disable the extraction of random patches
+        :param num_patches: Number of random patches to be extracted
+        :param is_validation: Whether this is the validation dataset (i.e. the JVET sequences)
+        :param cache_dataset: Enable/disable dataset caching in memory
+        """
+        check_directory(root_dir)
+        self._deco_dir = Path(root_dir) / "deco"
+        self._orig_dir = Path(root_dir) / "orig"
+
+        self._sequence_qps = sequence_qps
+        self._bit_depth = bit_depth
+        self._block_size = block_size
+        self._pad_size = pad_size
+
+        self._seq_config = None
+        self._seq_width = None
+        self._seq_height = None
+        self._seq_num_blocks = None
+        self._load_seq_properties(properties_file)
+
+        self._use_frame_type = use_frame_type
+        self._frame_type = frame_type
+
+        self._use_frame_qps = use_frame_qps
+        self._min_frame_qp = min_frame_qp
+        self._max_frame_qp = max_frame_qp
+
+        self._use_random_patches = use_random_patches
+        if self._use_random_patches:
+            assert num_patches > 0, "At least one patch must be extracted per frame"
+        self._num_patches = num_patches
+
+        self._is_validation = is_validation
+        self._cache_dataset = cache_dataset
+
+    def _load_seq_properties(self, properties_file: str) -> None:
+        """
+        Loads the properties file that contains the sequence info, such as tags and dimensions. In addition, three
+        tables are created to map the sequence name to: (1) width, (2) height and (3) number of non-overlapping
+        blocks available
+        :param properties_file: Absolute path to the dataset properties file
+        """
+        self._seq_config = read_json_file(Path(properties_file))
+
+        tags = []
+        widths = []
+        heights = []
+        num_blocks = []
+
+        for seq_tag in self._seq_config.keys():
+            tags.append(seq_tag)
+            w = self._seq_config[seq_tag]["width"]
+            h = self._seq_config[seq_tag]["height"]
+
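+            # Ceil-divide the frame dimensions by the luma block size (2 * block_size) to count the non-overlapping blocks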
+            num_h_blocks = w // (self._block_size * 2)
+            if w % (self._block_size * 2) > 0:
+                num_h_blocks += 1
+
+            num_v_blocks = h // (self._block_size * 2)
+            if h % (self._block_size * 2) > 0:
+                num_v_blocks += 1
+
+            widths.append(w)
+            heights.append(h)
+            num_blocks.append(num_h_blocks * num_v_blocks)
+
+        tags = tf.constant(tags)
+        widths = tf.constant(widths)
+        heights = tf.constant(heights)
+        num_blocks = tf.constant(num_blocks)
+
+        self._seq_width = tf.lookup.StaticHashTable(
+            tf.lookup.KeyValueTensorInitializer(tags, widths, tf.string, tf.int32),
+            default_value=tf.constant(-1),
+        )
+        self._seq_height = tf.lookup.StaticHashTable(
+            tf.lookup.KeyValueTensorInitializer(tags, heights, tf.string, tf.int32),
+            default_value=tf.constant(-1),
+        )
+        self._seq_num_blocks = tf.lookup.StaticHashTable(
+            tf.lookup.KeyValueTensorInitializer(tags, num_blocks, tf.string, tf.int32),
+            default_value=tf.constant(-1),
+        )
+
+    def _get_file_list(
+        self, deco_seq_dirs: List[Path]
+    ) -> Tuple[
+        List[str],
+        List[str],
+        List[str],
+        List[str],
+        List[str],
+        List[str],
+        List[float],
+        List[np.array],
+        List[np.array],
+    ]:
+        """
+        Gets the filenames of the images to be processed (reconstruction and original), the frame QP and the random
+        top-left corner positions to extract the patches
+        :param deco_seq_dirs: List of decoded sequence directories
+        :return: Lists of reconstruction and original images (one luma and two chroma each), list of frame QPs and list
+        of top-left corner positions
+        """
+        orig_y = []
+        orig_u = []
+        orig_v = []
+
+        reco_y = []
+        reco_u = []
+        reco_v = []
+
+        qp = []
+
+        pos_x = []
+        pos_y = []
+
+        luma_bs = self._block_size * 2
+
+        # The number of frames extracted per sequence is hard-coded per class so that roughly
+        # 500 frames are selected from each class:
+        # class A: 1494 frames, 6 sequences -> 84 frames per sequence
+        # class B: 2800 frames, 5 sequences -> 100 frames per sequence
+        # class C: 1900 frames, 4 sequences -> 125 frames per sequence
+        # class D: 1900 frames, 4 sequences -> 125 frames per sequence
+        # class F: 1900 frames, 4 sequences -> 125 frames per sequence
+        if self._is_validation:
+            frame_dict = {"A": 84, "B": 100, "C": 125, "D": 125, "F": 125}
+
+        for deco_seq_dir in deco_seq_dirs:
+            seq_name = deco_seq_dir.name
+
+            if self._is_validation:
+                num_frames = frame_dict[seq_name[0]]
+
+            w = self._seq_config[seq_name]["width"]
+            h = self._seq_config[seq_name]["height"]
+
+            if self._use_random_patches:
+                max_x = (
+                    luma_bs * (w // luma_bs)
+                    if w % luma_bs > 0
+                    else luma_bs * (w // luma_bs) - luma_bs
+                )
+                x_range = range(self._pad_size, max_x + 1, 2)
+
+                max_y = (
+                    luma_bs * (h // luma_bs)
+                    if h % luma_bs > 0
+                    else luma_bs * (h // luma_bs) - luma_bs
+                )
+                y_range = range(self._pad_size, max_y + 1, 2)
+
+            if len(self._sequence_qps) > 0:
+                qp_dirs = list_selected_dirs(deco_seq_dir, self._sequence_qps)
+            else:
+                qp_dirs = list_dirs(deco_seq_dir)
+
+            for qp_dir in qp_dirs:
+                frames_info = read_json_file(qp_dir / "frames_info.json")
+
+                reco_img_dir = qp_dir / "images"
+                orig_img_dir = self._orig_dir / seq_name / "images"
+
+                reco_files = reco_img_dir.glob("*_y.png")
+
+                if self._is_validation:
+                    reco_files_cp = reco_img_dir.glob("*_y.png")
+                    total_frames = len(list(reco_files_cp))
+                    random_pocs = np.random.choice(
+                        total_frames - 1, num_frames, replace=False
+                    )
+
+                for reco_file in reco_files:
+                    curr_poc = str(int(reco_file.stem.split("_")[0]))
+
+                    if self._is_validation and int(curr_poc) not in random_pocs:
+                        continue
+
+                    curr_frame_type = frames_info[curr_poc]["frame_type"]
+                    curr_frame_qp = frames_info[curr_poc]["QP"]
+
+                    if self._use_frame_type and self._frame_type != curr_frame_type:
+                        continue
+
+                    if self._use_frame_qps and (
+                        curr_frame_qp < self._min_frame_qp
+                        or curr_frame_qp > self._max_frame_qp
+                    ):
+                        continue
+
+                    orig_y.append(str(orig_img_dir / reco_file.name))
+                    orig_u.append(
+                        str(orig_img_dir / f'{reco_file.name.replace("y", "u")}')
+                    )
+                    orig_v.append(
+                        str(orig_img_dir / f'{reco_file.name.replace("y", "v")}')
+                    )
+
+                    reco_y.append(str(reco_file))
+                    reco_u.append(
+                        str(reco_img_dir / f'{reco_file.name.replace("y", "u")}')
+                    )
+                    reco_v.append(
+                        str(reco_img_dir / f'{reco_file.name.replace("y", "v")}')
+                    )
+
+                    qp.append(float(curr_frame_qp))
+
+                    if self._use_random_patches:
+                        pos_x.append(
+                            np.array(random.sample(x_range, self._num_patches))
+                        )
+                        pos_y.append(
+                            np.array(random.sample(y_range, self._num_patches))
+                        )
+                    else:
+                        pos_x.append(0)
+                        pos_y.append(0)
+
+        return orig_y, orig_u, orig_v, reco_y, reco_u, reco_v, qp, pos_x, pos_y
+
+    @tf.function
+    def read_images(
+        self, orig_y, orig_u, orig_v, reco_y, reco_u, reco_v, qp, pos_x, pos_y
+    ):
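+        # Loads one original/reconstructed frame pair (Y, U and V planes) and derives the QP-step input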
+        seq_tag = tf.strings.split(orig_y, "/")[-3]
+        width = self._seq_width.lookup(seq_tag)
+        height = self._seq_height.lookup(seq_tag)
+
+        orig_y = read_image(orig_y, self._bit_depth, width, height)
+        orig_u = read_image(orig_u, self._bit_depth, width // 2, height // 2)
+        orig_v = read_image(orig_v, self._bit_depth, width // 2, height // 2)
+
+        reco_y = read_image(reco_y, self._bit_depth, width, height)
+        reco_u = read_image(reco_u, self._bit_depth, width // 2, height // 2)
+        reco_v = read_image(reco_v, self._bit_depth, width // 2, height // 2)
+
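+        # Convert the frame QP to a quantisation step relative to QP 42 (the step doubles every 6 QP)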
+        qp_step = tf.math.pow(2.0, (qp - 42) / 6.0)
+        pos_x = tf.cast(pos_x, tf.int32)
+        pos_y = tf.cast(pos_y, tf.int32)
+
+        return (
+            seq_tag,
+            orig_y,
+            orig_u,
+            orig_v,
+            reco_y,
+            reco_u,
+            reco_v,
+            qp_step,
+            pos_x,
+            pos_y,
+        )
+
+    @tf.function
+    def pre_process_input(self, seq_tag, y, u, v, qp_step, pos_x, pos_y):
+        """
+        Creates input patches
+        :param seq_tag: Sequence tag/name
+        :param y: luma image
+        :param u: cb image
+        :param v: cr image
+        :param qp_step: QP step
+        :param pos_x: left corner positions
+        :param pos_y: top corner positions
+        :return: Input patches
+        """
+        with tf.device("/gpu:0"):
+            width = self._seq_width.lookup(seq_tag)
+            height = self._seq_height.lookup(seq_tag)
+
+            pos_xx = (pos_x - self._pad_size) // 2
+            pos_yy = (pos_y - self._pad_size) // 2
+
+            y = pad_image(y, width, height, self._block_size * 2, self._pad_size)
+            u = pad_image(
+                u, width // 2, height // 2, self._block_size, self._pad_size // 2
+            )
+            v = pad_image(
+                v, width // 2, height // 2, self._block_size, self._pad_size // 2
+            )
+
+            y_tl, y_tr, y_bl, y_br = interleave_image(y)
+
+            if self._use_random_patches:
+                y_tl = extract_random_patches(
+                    y_tl,
+                    self._block_size + self._pad_size,
+                    pos_xx,
+                    pos_yy,
+                    self._num_patches,
+                )
+                y_tr = extract_random_patches(
+                    y_tr,
+                    self._block_size + self._pad_size,
+                    pos_xx,
+                    pos_yy,
+                    self._num_patches,
+                )
+                y_bl = extract_random_patches(
+                    y_bl,
+                    self._block_size + self._pad_size,
+                    pos_xx,
+                    pos_yy,
+                    self._num_patches,
+                )
+                y_br = extract_random_patches(
+                    y_br,
+                    self._block_size + self._pad_size,
+                    pos_xx,
+                    pos_yy,
+                    self._num_patches,
+                )
+                u = extract_random_patches(
+                    u,
+                    self._block_size + self._pad_size,
+                    pos_xx,
+                    pos_yy,
+                    self._num_patches,
+                )
+                v = extract_random_patches(
+                    v,
+                    self._block_size + self._pad_size,
+                    pos_xx,
+                    pos_yy,
+                    self._num_patches,
+                )
+
+                qp_step = tf.fill(
+                    [
+                        self._num_patches,
+                        self._block_size + self._pad_size,
+                        self._block_size + self._pad_size,
+                        1,
+                    ],
+                    qp_step,
+                )
+            else:
+                y_tl = extract_patches(
+                    y_tl, self._block_size + self._pad_size, self._block_size
+                )
+                y_tr = extract_patches(
+                    y_tr, self._block_size + self._pad_size, self._block_size
+                )
+                y_bl = extract_patches(
+                    y_bl, self._block_size + self._pad_size, self._block_size
+                )
+                y_br = extract_patches(
+                    y_br, self._block_size + self._pad_size, self._block_size
+                )
+                u = extract_patches(
+                    u, self._block_size + self._pad_size, self._block_size
+                )
+                v = extract_patches(
+                    v, self._block_size + self._pad_size, self._block_size
+                )
+
+                qp_step = tf.fill(
+                    [
+                        self._seq_num_blocks[seq_tag],
+                        self._block_size + self._pad_size,
+                        self._block_size + self._pad_size,
+                        1,
+                    ],
+                    qp_step,
+                )
+
+            return tf.concat([y_tl, y_tr, y_bl, y_br, u, v, qp_step], axis=3)
+
+    @tf.function
+    def pre_process_label(self, seq_tag, y, u, v, pos_x, pos_y):
+        """
+        Creates label patches
+        :param seq_tag: Sequence tag/name
+        :param y: luma image
+        :param u: cb image
+        :param v: cr image
+        :param pos_x: left corner positions
+        :param pos_y: top corner positions
+        :return: Label patches
+        """
+        with tf.device("/gpu:0"):
+            pos_x = pos_x - self._pad_size
+            pos_y = pos_y - self._pad_size
+
+            width = self._seq_width.lookup(seq_tag)
+            height = self._seq_height.lookup(seq_tag)
+
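+            # The mask marks the original (non-padded) samples so that the loss can ignore the padding added below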
+            mask = tf.ones_like(y)
+
+            block_size = self._block_size * 2
+
+            mod = tf.math.floormod(height, block_size)
+            out_height = tf.cond(
+                tf.greater(mod, 0), lambda: height + block_size - mod, lambda: height
+            )
+
+            mod = tf.math.floormod(width, block_size)
+            out_width = tf.cond(
+                tf.greater(mod, 0), lambda: width + block_size - mod, lambda: width
+            )
+
+            y = tf.image.pad_to_bounding_box(y, 0, 0, out_height, out_width)
+            mask = tf.image.pad_to_bounding_box(mask, 0, 0, out_height, out_width)
+            u = tf.image.pad_to_bounding_box(u, 0, 0, out_height // 2, out_width // 2)
+            v = tf.image.pad_to_bounding_box(v, 0, 0, out_height // 2, out_width // 2)
+
+            if self._use_random_patches:
+                y = extract_random_patches(
+                    y, block_size, pos_x, pos_y, self._num_patches
+                )
+                mask = extract_random_patches(
+                    mask, block_size, pos_x, pos_y, self._num_patches
+                )
+                u = extract_random_patches(
+                    u, self._block_size, pos_x // 2, pos_y // 2, self._num_patches
+                )
+                v = extract_random_patches(
+                    v, self._block_size, pos_x // 2, pos_y // 2, self._num_patches
+                )
+            else:
+                y = extract_patches(y, block_size, block_size)
+                mask = extract_patches(mask, block_size, block_size)
+                u = extract_patches(u, self._block_size, self._block_size)
+                v = extract_patches(v, self._block_size, self._block_size)
+
+            y_tl, y_tr, y_bl, y_br = interleave_image(y)
+            mask_tl, mask_tr, mask_bl, mask_br = interleave_image(mask)
+
+            return tf.concat(
+                [
+                    y_tl,
+                    y_tr,
+                    y_bl,
+                    y_br,
+                    u,
+                    v,
+                    mask_tl,
+                    mask_tr,
+                    mask_bl,
+                    mask_br,
+                    mask_tl,
+                    mask_tl,
+                ],
+                axis=3,
+            )
+
+    def _apply_pipeline(
+        self, deco_seq_dirs: List[Path], batch_size: int, seed: int
+    ) -> tf.data.Dataset:
+        """
+        Applies the data pipeline
+        :param deco_seq_dirs: List of decoded sequence directories
+        :param batch_size: Batch size
+        :param seed: Seed for "random" operations
+        :return: dataset
+        """
+        file_list = self._get_file_list(deco_seq_dirs)
+
+        dataset = tf.data.Dataset.from_tensor_slices(file_list)
+        dataset = dataset.shuffle(
+            buffer_size=len(dataset), seed=seed, reshuffle_each_iteration=False
+        )
+
+        dataset = dataset.interleave(
+            lambda *args: tf.data.Dataset.from_tensors(self.read_images(*args)),
+            num_parallel_calls=tf.data.experimental.AUTOTUNE,
+        )
+
+        dataset = dataset.interleave(
+            lambda seq_tag, orig_y, orig_u, orig_v, reco_y, reco_u, reco_v, qp, pos_x, pos_y: tf.data.Dataset.zip(
+                (
+                    tf.data.Dataset.from_tensor_slices(
+                        self.pre_process_input(
+                            seq_tag, reco_y, reco_u, reco_v, qp, pos_x, pos_y
+                        )
+                    ),
+                    tf.data.Dataset.from_tensor_slices(
+                        self.pre_process_label(
+                            seq_tag, orig_y, orig_u, orig_v, pos_x, pos_y
+                        )
+                    ),
+                )
+            ),
+            num_parallel_calls=tf.data.experimental.AUTOTUNE,
+        )
+
+        dataset = dataset.batch(batch_size, drop_remainder=True)
+        if self._cache_dataset:
+            dataset = dataset.cache()
+
+        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+        return dataset
+
+    def create(
+        self, seq_name: Optional[str], batch_size: int, seed: int = 1234
+    ) -> tf.data.Dataset:
+        """
+        Creates the dataset
+        :param seq_name: Sequence name
+        :param batch_size: Batch size
+        :param seed: Seed for "random" operations
+        :return: Dataset
+        """
+        if seq_name is None:
+            deco_seq_dirs = list_dirs(self._deco_dir)
+            random.shuffle(deco_seq_dirs)
+        else:
+            deco_seq_dirs = [self._deco_dir / seq_name]
+
+        dataset = self._apply_pipeline(deco_seq_dirs, batch_size, seed)
+        return dataset
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/file_system.py b/training/training_scripts/NN_Post_Filtering/scripts/util/file_system.py
new file mode 100644
index 0000000000000000000000000000000000000000..b31019dbfe6162012723055a80b882a3777d921c
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/file_system.py
@@ -0,0 +1,127 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import json
+import os
+from pathlib import Path
+from typing import Dict, List, Tuple, Union
+
+
+def check_directory(input_path: str) -> None:
+    """
+    Checks whether the given path exists and corresponds to a directory
+    :param input_path: Absolute path
+    :return:
+    """
+    assert os.path.exists(input_path) and os.path.isdir(
+        input_path
+    ), f"{input_path} is not a directory"
+
+
+def check_file(input_path: str) -> None:
+    """
+    Checks whether the given path exists and corresponds to a file
+    :param input_path: Absolute path
+    """
+    assert os.path.exists(input_path) and os.path.isfile(
+        input_path
+    ), f"{input_path} is not a file"
+
+
+def list_dirs(input_dir: Path) -> List[Path]:
+    """
+    Lists the subdirectories of the given directory
+    :param input_dir: Input directory
+    :return: List of directories
+    """
+    return [f for f in input_dir.iterdir() if f.is_dir()]
+
+
+def list_selected_dirs(
+    input_dir: Path, pattern: Union[Tuple[str], List[str]]
+) -> List[Path]:
+    """
+    Lists the subdirectories whose names appear in the given collection
+    :param input_dir: Input directory
+    :param pattern: Collection of directory names to select
+    :return: List of directories
+    """
+    return [f for f in input_dir.iterdir() if f.is_dir() and f.name in pattern]
+
+
+def read_json_file(json_path: Path) -> Dict:
+    """
+    Reads JSON file
+    :param json_path: Absolute path to the JSON file
+    :return: Dictionary containing JSON file data
+    """
+    assert json_path.exists() and json_path.is_file(), f"{json_path} is not a file"
+    with open(json_path, "r") as stream:
+        config = json.load(stream)
+    return config
+
+
+def write_json_file(content: Dict, output_file: Path) -> None:
+    """
+    Writes a dictionary to a JSON file
+    :param content: Dictionary to be saved
+    :param output_file: Absolute path to the JSON file
+    """
+    assert (
+        output_file.parent.exists() and output_file.parent.is_dir()
+    ), f"The parent directory {output_file.parent} does not exist"
+    with open(output_file, "w") as stream:
+        json.dump(content, stream, sort_keys=True, indent=4)
+
+
+def create_vtm_config_file(
+    cfg_file: Path, filename: Path, width: int, height: int, fps: int, num_frames: int
+) -> None:
+    """
+    Creates the sequence config file for VTM encoding
+    :param cfg_file: Output file name
+    :param filename: YUV file name
+    :param width: Width of the YUV
+    :param height: Height of the YUV
+    :param fps: Frame rate of the YUV
+    :param num_frames: Number of frames to be encoded
+    """
+    with open(cfg_file, "w") as stream:
+        stream.write(f"InputFile:           {filename}\n")
+        stream.write(f"SourceWidth:         {width}\n")
+        stream.write(f"SourceHeight:        {height}\n")
+        stream.write(f"InputBitDepth:       10\n")
+        stream.write(f"InputChromaFormat:   420\n")
+        stream.write(f"FrameRate:           {fps}\n")
+        stream.write(f"FrameSkip:           0\n")
+        stream.write(f"FramesToBeEncoded:   {num_frames}\n")
+        stream.write(f"Level:               5.1\n")
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/image_ops.py b/training/training_scripts/NN_Post_Filtering/scripts/util/image_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef7bce83fb2d6ceeaac828826d94cdb065b35e31
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/image_ops.py
@@ -0,0 +1,204 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from typing import Tuple
+
+import tensorflow as tf
+
+
+@tf.function
+def read_image(filename: str, bit_depth: int, width: int, height: int) -> tf.Tensor:
+    """
+    Reads an image
+    :param filename: Absolute path to the image
+    :param bit_depth: Target bit-depth
+    :param width: Width of the image
+    :param height: Height of the image
+    :return: 4D tensor BHWC
+    """
+    image = tf.io.read_file(filename)
+    image = tf.image.decode_png(image, 1, tf.uint16)
+    image = tf.cast(tf.image.resize(image, [height, width]), tf.uint16)
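+    # decode_png returns 16-bit samples; shift down to the target bit depth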
+    image = tf.bitwise.right_shift(image, 16 - bit_depth)
+    image = tf.expand_dims(image, axis=0)
+    image = normalise_image(image, bit_depth)
+    return image
+
+
+@tf.function
+def normalise_image(image: tf.Tensor, bit_depth: int) -> tf.Tensor:
+    """
+    Normalises an image to the range [0, 1]
+    :param image: Input image
+    :param bit_depth: Bit-depth of the image
+    :return: Normalised image, 4D tensor BHWC
+    """
+    image = tf.cast(image, tf.float32) / (2**bit_depth - 1)
+    return image
+
+
+@tf.function
+def pad_image(
+    in_image: tf.Tensor, width: int, height: int, block_size: int, pad_size: int
+) -> tf.Tensor:
+    """
+    Applies padding to the input image
+    :param in_image: Input image
+    :param width: Width of the image
+    :param height: Height of the image
+    :param block_size: Size of the actual block (final output size)
+    :param pad_size: Number of samples added to each side of the block size
+    :return: Padded image
+    """
+    left = tf.expand_dims(in_image[:, :, 0, :], axis=2)
+    left = tf.tile(left, [1, 1, pad_size, 1])
+
+    right = tf.expand_dims(in_image[:, :, -1, :], axis=2)
+
+    mod = tf.math.floormod(width, block_size)
+    right = tf.cond(
+        tf.greater(mod, 0),
+        lambda: tf.tile(right, [1, 1, pad_size + block_size - mod, 1]),
+        lambda: tf.tile(right, [1, 1, pad_size, 1]),
+    )
+
+    out_image = tf.concat([left, in_image, right], axis=2)
+
+    top = tf.expand_dims(out_image[:, 0, :, :], axis=1)
+    top = tf.tile(top, [1, pad_size, 1, 1])
+
+    bottom = tf.expand_dims(out_image[:, -1, :, :], axis=1)
+
+    mod = tf.math.floormod(height, block_size)
+    bottom = tf.cond(
+        tf.greater(mod, 0),
+        lambda: tf.tile(bottom, [1, pad_size + block_size - mod, 1, 1]),
+        lambda: tf.tile(bottom, [1, pad_size, 1, 1]),
+    )
+
+    out_image = tf.concat([top, out_image, bottom], axis=1)
+    return out_image
+
+
+@tf.function
+def interleave_image(
+    image: tf.Tensor,
+) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
+    """
+    Interleaves an image into four partitions.
+    Example: http://casu.ast.cam.ac.uk/surveys-projects/wfcam/technical/interleaving
+    :param image: Input image
+    :return: Four image partitions
+    """
+    tl = image[:, 0::2, 0::2, :]
+    tr = image[:, 0::2, 1::2, :]
+    bl = image[:, 1::2, 0::2, :]
+    br = image[:, 1::2, 1::2, :]
+    return tl, tr, bl, br
+
+
+@tf.function
+def de_interleave_luma(block: tf.Tensor, block_size: int = 2) -> tf.Tensor:
+    """
+    De-interleaves four image partitions into a single image
+    :param block: The four partitions stacked in the channel dimension (BHWC with C = 4)
+    :param block_size: Block size used by depth-to-space (2 for the 2x2 interleaving)
+    :return: Full image
+    """
+    return tf.nn.depth_to_space(block, block_size)
+
+
+@tf.function
+def extract_patches(image: tf.Tensor, block_size: int, step: int) -> tf.Tensor:
+    """
+    Extracts patches of an image in Z-scan order
+    :param image: Input image
+    :param block_size: Block/patch size
+    :param step: Step size
+    :return: Patches concatenated in the batch dimension
+    """
+    patches = tf.image.extract_patches(
+        image, [1, block_size, block_size, 1], [1, step, step, 1], [1, 1, 1, 1], "VALID"
+    )
+    patches = tf.reshape(patches, [-1, block_size, block_size, image.shape[-1]])
+
+    return patches
+
+
+@tf.function
+def extract_random_patches(
+    image: tf.Tensor, block_size: int, pos_x, pos_y, num_patches: int
+) -> tf.Tensor:
+    """
+    Extracts random patches out of the input image
+    :param image: Input image 4D tensor
+    :param block_size: Patch size
+    :param pos_x: Left corner position
+    :param pos_y: Top corner position
+    :param num_patches: Number of patches to be extracted
+    :return: Patches concatenated in the batch dimension
+    """
+    patches = []
+
+    for i in range(num_patches):
+        patch = tf.image.crop_to_bounding_box(
+            image, pos_y[i], pos_x[i], block_size, block_size
+        )
+        patch = tf.squeeze(patch, axis=0)
+        patches.append(patch)
+
+    patches = tf.stack(patches, axis=0)
+    return patches
+
+
+@tf.function
+def merge_images(first: tf.Tensor, second: tf.Tensor) -> tf.Tensor:
+    """
+    Merges two images in the channel dimension
+    :param first: First image
+    :param second: Second image
+    :return: Merged images
+    """
+    return tf.concat([first, second], axis=3)
+
+
+@tf.function
+def add_zeros_to_image(image: tf.Tensor, channels: int = 3) -> tf.Tensor:
+    """
+    Add zero-filled channels to the input image
+    :param image: Input image
+    :param channels: Number of zero-filled channels to add
+    :return: Image with zero padded channels
+    """
+    zero_bs = tf.zeros_like(image)[:, :, :, :channels]
+    image = merge_images(image, zero_bs)
+    return image
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/logging.py b/training/training_scripts/NN_Post_Filtering/scripts/util/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..f30322c734a671fd734d54e21b65a1719948034f
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/logging.py
@@ -0,0 +1,75 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from datetime import datetime
+from typing import Dict
+
+import tensorflow as tf
+
+from util import Colour, COLOUR_LABEL, Metric, METRIC_LABEL
+
+TIME_FORMAT: str = "%Y%m%d-%H%M%S"
+
+
+def save_time(filename: str, start_time: str, end_time: str) -> None:
+    """
+    Writes the start time, end time and duration (end time - start time) to a file
+    :param filename: Output file
+    :param start_time: Start time (string format)
+    :param end_time: End time (string format)
+    """
+    duration = datetime.strptime(end_time, TIME_FORMAT) - datetime.strptime(
+        start_time, TIME_FORMAT
+    )
+    with open(filename, "w") as stream:
+        stream.writelines("start,end,duration\n")
+        stream.writelines(f"{start_time},{end_time},{duration}\n")
+
+
+def log_epoch_metrics(
+    summary_writer: tf.summary.SummaryWriter,
+    epoch_metrics: Dict[int, Dict[int, float]],
+    epoch: int,
+) -> None:
+    """
+    Logs metrics for an epoch. The data can be visualised in TensorBoard
+    :param summary_writer: Summary writer
+    :param epoch_metrics: Metrics associated to the given epoch index
+    :param epoch: Epoch index
+    """
+    with summary_writer.as_default(step=epoch):
+        for colour in range(Colour.NUM_COLOURS):
+            for metric in range(Metric.NUM_METRICS):
+                tf.summary.scalar(
+                    f"{COLOUR_LABEL[colour]}_{METRIC_LABEL[metric]}",
+                    epoch_metrics[colour][metric],
+                )
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/metrics.py b/training/training_scripts/NN_Post_Filtering/scripts/util/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..afdf572768b567c2e00c59d97e6e81fdb82074a4
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/metrics.py
@@ -0,0 +1,139 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from typing import Dict, Tuple
+
+import tensorflow as tf
+
+from util import Colour, COLOUR_WEIGHTS, Metric
+from util.image_ops import de_interleave_luma
+
+
+@tf.function
+def compute_metrics(
+    ground_truth: tf.Tensor, prediction: tf.Tensor
+) -> Tuple[tf.Tensor, tf.Tensor]:
+    """
+    Computes the MSE and the PSNR between two 4D tensors (single channel); the metrics are computed per batch element
+    :param ground_truth: Ground-truth tensor
+    :param prediction: Test tensor
+    :return: MSE and PSNR
+    """
+    mse = tf.reduce_mean(
+        tf.math.squared_difference(ground_truth, prediction), axis=[1, 2, 3]
+    )
+    psnr = 10.0 * (tf.math.log(1.0 / mse) / tf.math.log(10.0))
+    return mse, psnr
+
+
+@tf.function
+def compute_loss(
+    ground_truth: tf.Tensor,
+    prediction: tf.Tensor,
+    loss_weights: Dict[int, float] = COLOUR_WEIGHTS,
+) -> Tuple[Dict[int, tf.Tensor], Dict[int, tf.Tensor]]:
+    """
+    Computes the loss function and the associated PSNR for all channels
+    :param ground_truth: Ground-truth tensor
+    :param prediction: Test tensor
+    :param loss_weights: Weights used to compute the average across all channels
+    :return: Channel-wise loss and channel-wise PSNR
+    """
+    mask = ground_truth[:, :, :, 6:]
+    ground_truth = ground_truth[:, :, :, :6]
+    ground_truth = tf.multiply(ground_truth, mask)
+    prediction = tf.multiply(prediction, mask)
+
+    y_orig = de_interleave_luma(ground_truth[:, :, :, :4])
+    y_pred = de_interleave_luma(prediction[:, :, :, :4])
+    y_mse, y_psnr = compute_metrics(y_orig, y_pred)
+
+    cb_mse, cb_psnr = compute_metrics(
+        ground_truth[:, :, :, 4:5], prediction[:, :, :, 4:5]
+    )
+    cr_mse, cr_psnr = compute_metrics(
+        ground_truth[:, :, :, 5:6], prediction[:, :, :, 5:6]
+    )
+
+    y_weight = loss_weights[Colour.Y]
+    cb_weight = loss_weights[Colour.Cb]
+    cr_weight = loss_weights[Colour.Cr]
+
+    mse = y_weight * y_mse + cb_weight * cb_mse + cr_weight * cr_mse
+    psnr = y_weight * y_psnr + cb_weight * cb_psnr + cr_weight * cr_psnr
+
+    mse = {Colour.Y: y_mse, Colour.Cb: cb_mse, Colour.Cr: cr_mse, Colour.YCbCr: mse}
+    psnr = {
+        Colour.Y: y_psnr,
+        Colour.Cb: cb_psnr,
+        Colour.Cr: cr_psnr,
+        Colour.YCbCr: psnr,
+    }
+
+    return mse, psnr
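+
+# Usage sketch (illustrative shapes; assumes NHWC tensors with 4 interleaved luma
+# channels, Cb, Cr and, for the ground truth, a trailing mask channel):
+#   gt = tf.random.uniform([8, 72, 72, 7]); pred = tf.random.uniform([8, 72, 72, 6])
+#   mse, psnr = compute_loss(gt, pred)
+#   mse[Colour.YCbCr]  # weighted combination, one value per batch element (shape [8])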
+
+
+@tf.function
+def compute_psnr_gain(
+    test_psnr: Dict[int, tf.Tensor], base_psnr: Dict[int, tf.Tensor]
+) -> Dict[int, tf.Tensor]:
+    """
+    Computes the PSNR gain (delta PSNR = filtered reconstruction PSNR - VTM reconstruction PSNR).
+    Note that if any input PSNR is infinite, the PSNR gain is zero
+    :param test_psnr: PSNR of the filtered reconstruction
+    :param base_psnr: PSNR of the VTM reconstruction
+    :return: PSNR gain
+    """
+    psnr_gain = {}
+    for colour in test_psnr.keys():
+        diff = test_psnr[colour] - base_psnr[colour]
+        is_inf = tf.reduce_any(
+            [tf.math.is_inf(test_psnr[colour]), tf.math.is_inf(base_psnr[colour])],
+            axis=0,
+        )
+        psnr_gain[colour] = tf.where(is_inf, 0.0, diff)
+    return psnr_gain
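+
+# Example (illustrative values): an infinite PSNR on either side maps to a gain of 0
+#   gain = compute_psnr_gain({Colour.Y: tf.constant([35.0, float("inf")])},
+#                            {Colour.Y: tf.constant([34.0, 40.0])})
+#   gain[Colour.Y]  # -> [1.0, 0.0]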
+
+
+def compute_epoch_metrics(
+    batch_metrics: Dict[int, Dict[int, tf.metrics.Mean]],
+    epoch_metrics: Dict[int, Dict[int, tf.Tensor]],
+) -> None:
+    """
+    Stores the accumulated per-batch metric means into the epoch metrics and resets the batch accumulators
+    :param batch_metrics: Per-batch metric accumulators (tf.metrics.Mean)
+    :param epoch_metrics: Epoch metrics, updated in place with the accumulated results
+    """
+    for colour in range(Colour.NUM_COLOURS):
+        for metric in range(Metric.NUM_METRICS):
+            epoch_metrics[colour][metric] = batch_metrics[colour][metric].result()
+            batch_metrics[colour][metric].reset_states()
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/parsing.py b/training/training_scripts/NN_Post_Filtering/scripts/util/parsing.py
new file mode 100644
index 0000000000000000000000000000000000000000..d77e9fcfcedc1925e40d66251849df8d5e1a8f72
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/parsing.py
@@ -0,0 +1,61 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+from pathlib import Path
+from typing import Dict
+
+from util.regex import get_frame_info_from_decoder_log_line
+
+
+def extract_bitstream_info(log_file: Path) -> Dict:
+    """
+    Extracts information from the decoder log file, i.e. POC, frame QP, frame type and temporal layer
+    :param log_file: VTM decoder log
+    :return: Dictionary, keyed by POC, that contains the per-frame information
+    """
+    info = {}
+    with open(log_file, "r") as stream:
+        for line in stream:
+            if line.startswith("POC"):
+                (
+                    poc,
+                    temporal_layer,
+                    frame_type,
+                    qp,
+                ) = get_frame_info_from_decoder_log_line(line)
+                info[poc] = {
+                    "temporal_layer": temporal_layer,
+                    "frame_type": frame_type,
+                    "QP": qp,
+                }
+
+    return info
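+
+
+# Example (illustrative values; keys are POCs):
+#   info = extract_bitstream_info(Path("log_dec.txt"))
+#   info[0]  # e.g. {"temporal_layer": 0, "frame_type": "I", "QP": 32}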
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/preprocessing.py b/training/training_scripts/NN_Post_Filtering/scripts/util/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..97708c9bb2c3ab95a4fd17622b9b18a184c133dd
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/preprocessing.py
@@ -0,0 +1,176 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import os
+from pathlib import Path
+
+from imageio import imread, imwrite
+
+from util.file_system import create_vtm_config_file, read_json_file, write_json_file
+from util.parsing import extract_bitstream_info
+from util.regex import (
+    DIV2K_ORG_IMG_LR_FILENAME,
+    get_info_from_video_stem,
+    get_lr_method_from_div2k_dir_name,
+    is_regex_in_text,
+    TARGET_PNG_FORMAT,
+    TARGET_YUV_FORMAT,
+)
+
+
+def convert_mp4_to_yuv(
+    input_dir: Path, output_dir: Path, yuv_format: str = TARGET_YUV_FORMAT
+) -> None:
+    """
+    Converts a set of mp4 videos to YUV. It also creates the configuration files for VTM encoding
+    :param input_dir: Directory that contains the input mp4 videos
+    :param output_dir: Directory to save the output YUV videos and configuration files
+    :param yuv_format: Format for the YUV (e.g. yuv420p10le)
+    """
+    mp4_files = input_dir.glob("**/*.mp4")
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    for mp4_file in mp4_files:
+        file_stem = mp4_file.stem
+        yuv_dir = output_dir / file_stem
+        yuv_dir.mkdir(parents=True, exist_ok=True)
+
+        yuv_file = yuv_dir / f"{file_stem}.yuv"
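+        # Decode the mp4 into a raw planar YUV file at the requested pixel format.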
+        command = f"ffmpeg -i {mp4_file} -pix_fmt {yuv_format} {yuv_file}"
+        os.system(command)
+
+        label, width, height, fps = get_info_from_video_stem(file_stem)
+
+        cfg_file = yuv_dir / f"{label}.cfg"
+        create_vtm_config_file(cfg_file, yuv_file, width, height, fps, 64)
+
+
+def convert_png_to_yuv(
+    input_dir: Path, output_dir: Path, yuv_format: str = TARGET_YUV_FORMAT
+) -> None:
+    """
+    Converts a set of PNGs to YUVs. It also creates the configuration files for VTM encoding
+    :param input_dir: Directory that contains the PNG images
+    :param output_dir: Directory to save the output YUV videos and configuration files
+    :param yuv_format: Format for the YUV (e.g. yuv420p10le)
+    """
+    org_imgs = input_dir.glob("**/*.png")
+
+    output_dir = Path(output_dir)
+
+    for org_img in org_imgs:
+        img = imread(org_img)
+        height, width, _ = img.shape
+
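+        # First pass: crop both dimensions down to multiples of 8; the source PNG is
+        # overwritten in place with the cropped image.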
+        new_width = width - (width % 8)
+        new_height = height - (height % 8)
+
+        if width != new_width or height != new_height:
+            img = img[:new_height, :new_width, :]
+            imwrite(org_img, img)
+
+        if is_regex_in_text(org_img.name, DIV2K_ORG_IMG_LR_FILENAME):
+            lr_method = get_lr_method_from_div2k_dir_name(org_img.parent.parent.name)
+            output_file = org_img.parent / f"{org_img.stem}_{lr_method}.png"
+            org_img.rename(output_file)
+
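+    # Second pass: convert each (cropped, possibly renamed) PNG into a single-image YUV
+    # and write the matching VTM configuration file.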
+    org_imgs = input_dir.glob("**/*.png")
+
+    for org_img in org_imgs:
+        img = imread(org_img)
+        height, width, _ = img.shape
+        file_stem = f"{org_img.stem}_{width}x{height}_30_{yuv_format}"
+        yuv_dir = output_dir / file_stem
+        yuv_dir.mkdir(parents=True, exist_ok=True)
+        dst_video = yuv_dir / f"{file_stem}.yuv"
+        command = f"ffmpeg -i {org_img} -pix_fmt {yuv_format} {dst_video}"
+        os.system(command)
+
+        cfg_file = yuv_dir / f"{org_img.stem}.cfg"
+        create_vtm_config_file(cfg_file, dst_video, width, height, 1, 1)
+
+
+def convert_yuv_to_png(
+    root_input_dir: Path,
+    root_output_dir: Path,
+    dataset: str,
+    png_format: str = TARGET_PNG_FORMAT,
+) -> None:
+    """
+    Converts each YUV video into per-frame, per-plane PNG images. The conversion is done with ffmpeg. The output
+    pattern is %05d_<plane>.png and the start number is 0.
+    For compressed videos (i.e. YUVs stored under a QP sub-directory), it also generates a JSON file with per-frame
+    information parsed from the decoder log
+    :param root_input_dir: Directory that contains the YUV videos
+    :param root_output_dir: Directory to save the generated PNGs
+    :param dataset: Dataset name
+    :param png_format: Output format for the PNG (e.g. gray16be)
+    """
+    yuv_files = list(root_input_dir.glob("**/*.yuv"))
+    channels = ["y", "u", "v"]
+    qps = ["22", "27", "32", "37", "42"]
+
+    if dataset == "JVET":
+        dataset_info = read_json_file(Path("resources/properties/jvet_labels.json"))
+
+    for yuv_file in yuv_files:
+        if yuv_file.parent.name in qps:
+            seq_name = yuv_file.parent.parent.name
+        else:
+            seq_name = yuv_file.stem
+
+        label, width, height, fps = get_info_from_video_stem(seq_name)
+
+        if dataset == "JVET":
+            label = dataset_info[seq_name]["label"]
+
+        if yuv_file.parent.name in qps:
+            img_output_dir = (
+                root_output_dir / f"{label}" / yuv_file.parent.name / "images"
+            )
+        else:
+            img_output_dir = root_output_dir / f"{label}" / "images"
+
+        img_output_dir.mkdir(parents=True, exist_ok=True)
+
+        if yuv_file.parent.name in qps:
+            decoder_log = yuv_file.parent / "log_dec.txt"
+            frame_info = extract_bitstream_info(decoder_log)
+            write_json_file(frame_info, img_output_dir.parent / "frames_info.json")
+
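+        # One ffmpeg call per plane: extract the Y/U/V plane and write it as numbered
+        # grayscale PNGs (%05d_<plane>.png, 16-bit for the default gray16be format).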
+        for ch in channels:
+            output_file = img_output_dir / f"%05d_{ch}.png"
+            command = (
+                f"ffmpeg -y -r {fps} -s {width}x{height} -pix_fmt {TARGET_YUV_FORMAT} -i {yuv_file} "
+                f'-vf "extractplanes={ch}" -pix_fmt {png_format} -start_number 0 {output_file}'
+            )
+            os.system(command)
diff --git a/training/training_scripts/NN_Post_Filtering/scripts/util/regex.py b/training/training_scripts/NN_Post_Filtering/scripts/util/regex.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a95e1dd6376136581d201f650f95aae12cc16c7
--- /dev/null
+++ b/training/training_scripts/NN_Post_Filtering/scripts/util/regex.py
@@ -0,0 +1,98 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#  * Redistributions of source code must retain the above copyright notice,
+#    this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+#    be used to endorse or promote products derived from this software without
+#    specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import re
+from typing import Tuple
+
+
+DIV2K_ORG_IMG_LR_FILENAME = r"\d+x\d*\.png"
+TARGET_PNG_FORMAT = "gray16be"
+TARGET_YUV_FORMAT = "yuv420p10le"
+
+
+def get_regex_groups_for_pattern(
+    text: str, pattern: str, expected_number_of_groups: int
+) -> re.Match:
+    """
+    Matches a regular expression against the beginning of a given text and checks that the expected number of
+    groups was captured
+    :param text: Input text
+    :param pattern: Regular expression
+    :param expected_number_of_groups: Expected number of capture groups
+    :return: Match object
+    """
+    match = re.match(pattern, text)
+    if match is None or len(match.groups()) != expected_number_of_groups:
+        raise ValueError(f"Couldn't find pattern {pattern} in '{text}'")
+    return match
+
+
+def is_regex_in_text(text: str, pattern: str) -> bool:
+    """
+    Returns True if the regular expression is found anywhere in the text
+    :param text: Input text
+    :param pattern: Regular expression
+    :return: Whether the regular expression is in the text
+    """
+    return bool(re.search(pattern, text))
+
+
+def get_frame_info_from_decoder_log_line(text: str) -> Tuple[int, int, str, int]:
+    """
+    Extracts video frame information from a single line of the VTM decoder log
+    :param text: Input text
+    :return: POC, temporal layer, slice type and QP
+    """
+    file_pattern = r"POC\s+(\d+)\s+LId:\s+\d+\s+TId:\s+(\d+)\s+\(\s+.*,\s+([I|P|B])-\w+,\s+QP\s+(\d+)"
+    match = get_regex_groups_for_pattern(text, file_pattern, 4)
+    return int(match.group(1)), int(match.group(2)), match.group(3), int(match.group(4))
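+
+# Example (hypothetical log line, constructed to match the pattern above):
+#   get_frame_info_from_decoder_log_line(
+#       "POC    0 LId:  0 TId: 0 ( CRA, I-SLICE, QP 32 )")  # -> (0, 0, 'I', 32)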
+
+
+def get_info_from_video_stem(stem: str) -> Tuple[str, int, int, int]:
+    """
+    Extracts video information (label, resolution and frame rate) from the stem of a video filename
+    :param stem: Stem of the input video
+    :return: Video label, width, height and fps
+    """
+    stem_pattern = r"(\w+[_\w]*|\d+x\d_\w+)_(\d+)x(\d+)p*_(\d+)[fps]*"
+    match = get_regex_groups_for_pattern(stem, stem_pattern, 4)
+    return match.group(1), int(match.group(2)), int(match.group(3)), int(match.group(4))
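+
+# Example (illustrative stem):
+#   get_info_from_video_stem("BlowingBubbles_416x240_50")  # -> ("BlowingBubbles", 416, 240, 50)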
+
+
+def get_lr_method_from_div2k_dir_name(dir_name: str) -> str:
+    """
+    Extracts the down-sampling method from a DIV2K directory name
+    :param dir_name: Directory name
+    :return: Down-sampling method
+    """
+    dir_name_pattern = r"\w+_\w+_\w+_(\w+)*"
+    match = get_regex_groups_for_pattern(dir_name, dir_name_pattern, 1)
+    return match.group(1)
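+
+
+# Example (illustrative DIV2K directory name):
+#   get_lr_method_from_div2k_dir_name("DIV2K_train_LR_bicubic")  # -> "bicubic"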