diff --git a/training/training_scripts/Nn_Filtering_Set_0/2_generate_compression_data/2_ReadMe.md b/training/training_scripts/Nn_Filtering_Set_0/2_generate_compression_data/2_ReadMe.md
index 0362cdaebb575010f182a401d1c40fa6679ac315..15d80ae8d56c71d03195aeab5a1a59dd045bd0f6 100644
--- a/training/training_scripts/Nn_Filtering_Set_0/2_generate_compression_data/2_ReadMe.md
+++ b/training/training_scripts/Nn_Filtering_Set_0/2_generate_compression_data/2_ReadMe.md
@@ -107,7 +107,7 @@ When the compression dataset generation is finished, the file structure should b
 ```
 
 ## Generate the dataset json
-Run generate_dataset_json.sh to generate the training and test dataset json files.
+Run generate_dataset_json.sh to generate the training and test dataset json files. Before running this script, copy concatenate_dataset.py from ../../../tools into the current directory.
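+For example (a sketch; invoking the script with sh follows the convention used elsewhere in these ReadMe files):
+```
+cp ../../../tools/concatenate_dataset.py .
+sh generate_dataset_json.sh
+```
+concatenate_dataset.py itself supports two modes (see its --help): merging the individual json files via --input_dir_json/--output_json, and consolidating the merged json with original yuv information from encoder logs via --input_json/--input_dir_encoder/--output_json.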
 
 ## Other
 Because the neural network based filter is used in the generation process of the second stage dataset, the speed is relatively slow if it is run on CPU.
diff --git a/training/training_scripts/Nn_Filtering_Set_0/2_generate_compression_data/concatenate_dataset.py b/training/training_scripts/Nn_Filtering_Set_0/2_generate_compression_data/concatenate_dataset.py
deleted file mode 100644
index 55cf09c994456693ea846483859eaa932c44df70..0000000000000000000000000000000000000000
--- a/training/training_scripts/Nn_Filtering_Set_0/2_generate_compression_data/concatenate_dataset.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import argparse
-import glob
-import sys
-import json
-import re
-import os
-
-parser = argparse.ArgumentParser(prog='concatenate dataset', usage='create a global dataset from all the json file in a given directory. ', 
-                                  formatter_class=argparse.RawDescriptionHelpFormatter,
-                                 epilog=
-'''2 modes available:
-   concatenate_dataset.py --input_dir dir1 --input_dir dir2 --output_json pre_dataset.json
-   concatenate_dataset.py --input_json pre_dataset.json --input_dir_encoder direnc1 --input_dir_encoder direnc2 --output_json dataset.json''')
-parser.add_argument("--input_dir_json", action="append", nargs='+', type=str, help="directory containing individual json files. Multiple options possible.")
-parser.add_argument("--dataset_range", type=str, default='', help="train or test dataset range (such as 0-1000), use all data by default")
-parser.add_argument("--input_json", action="store", nargs='?', type=str, help="input json database.")
-parser.add_argument("--input_dir_encoder", action="append", nargs='+', type=str, help="directory containing individual encoder log files or encoder cfg files. Multiple options possible.")
-parser.add_argument("--log_extension", default="log", action="store", nargs='?', type=str, help="encoder log extension")
-parser.add_argument("--output_json", action="store", nargs='?', type=str, help="name of the output file with concatenated files", required=True)
-args=parser.parse_args()
-
-# mode 1: concatenate all indiviual dataset into 1 file, setting the dirname to find the data
-if args.input_dir_json is not None:
-    header={}
-    lastheader=None
-    db=[]
-    flat=[d for d1 in args.input_dir_json for d in d1]
-    for d in flat:
-        files = sorted(glob.glob(d+'/*.json'))
-        if args.dataset_range:
-            temp = args.dataset_range.split('-')
-            range_start, range_end = list(map(lambda x: int(x), temp))
-            files = files[range_start:range_end]
-        print("Processing directory {}: {} files".format(d,len(files)))
-        for f in files:
-           with open(f, "r") as file:
-               content = file.read()
-               dcontent = json.loads(content)
-               header={}
-               for key in dcontent:
-                   if "suffix_" in key:
-                       header[key]=dcontent[key]
-               if lastheader is not None and not lastheader == header:
-                   sys.exit("File {} does not contain the same data as other files".format(f))
-               lastheader = header
-               for data in dcontent['data']:
-                   if 'dirname' not in data: # no dirname yet
-                      data['dirname']=d
-                   db.append(data)
-    
-    jout=header
-    jout["data"]=db
-    s = json.dumps(jout,indent=1)
-    with open(args.output_json, "w") as file:
-      file.write(s)
-
-
-# mode 2: consolidate a dataset file by adding information on original yuv from encoder logs information     
-if args.input_json is not None:
-    db_logs={}
-    flat=[d for d1 in args.input_dir_encoder for d in d1]
-    for d in flat:
-        files = glob.glob(d+'/*.'+args.log_extension)
-        print("Processing directory {}: {} files".format(d,len(files)))
-        for f in files:
-           with open(f, "r") as file:
-              info={"FrameSkip": 0, "TemporalSubsampleRatio": 1} # default              
-              name=None
-              for line in file:
-                  m = re.match("^Input\s*File\s*:\s*([^\s]+)", line)
-                  if m:
-                      info['InputFile']=m.group(1)
-                  m = re.match("^Bitstream\s*File\s*:\s*([^\s]+)", line)
-                  if m:
-                      name=os.path.basename(m.group(1))
-                  m = re.match("^TemporalSubsampleRatio\s*:\s*([0-9]+)", line)
-                  if m:
-                      info['TemporalSubsampleRatio']=m.group(1)
-#                  m = re.match("^QP\s*:\s*([0-9]+)", line)
- #                 if m:
-  #                    info['QP']=m.group(1)
-                  m = re.match("^FrameSkip\s*:\s*([0-9]+)", line)
-                  if m:
-                      info['FrameSkip']=m.group(1)
-                  m = re.match("^Input\s+bit\s+depth\s*:\s*\(Y:([0-9]+),", line)
-                  if m:
-                      info['InputBitDepth']=m.group(1)
-                  m = re.match("^InputBitDepth\s*:\s*([0-9]+)", line)
-                  if m:
-                       info['InputBitDepth']=m.group(1)
-              if name is not None:
-                  if len(info) != 4:
-                    sys.exit("Not enough information extracted for bitstream {}".format(name))
-                  db_logs[name]=info        
-    print(db_logs)
-    with open(args.input_json, "r") as file:
-      content = file.read()
-      dcontent = json.loads(content)
-      for d in dcontent['data']:
-          if d['bsname'] in db_logs:
-              info=db_logs[d['bsname']]
-              d['original_yuv']=info['InputFile']
-              d['original_temporal_subsample']=int(info['TemporalSubsampleRatio'])
-              d['original_frame_skip']=int(info['FrameSkip'])
-#              d['qp_base']=int(info['QP'])
-              d['original_bitdepth']=int(info['InputBitDepth'])
-      s = json.dumps(dcontent,indent=1)
-      with open(args.output_json, "w") as file:
-        file.write(s)
diff --git a/training/training_scripts/Nn_Filtering_Set_0/3_training_tasks/3_ReadMe.md b/training/training_scripts/Nn_Filtering_Set_0/3_training_tasks/3_ReadMe.md
index 6c0435f70f2a12114714a5246a59855bf6e21c92..c43617af5308f11566c1c8b9a6033c0eab6420bd 100644
--- a/training/training_scripts/Nn_Filtering_Set_0/3_training_tasks/3_ReadMe.md
+++ b/training/training_scripts/Nn_Filtering_Set_0/3_training_tasks/3_ReadMe.md
@@ -11,7 +11,7 @@ tensorboard --logdir=Tensorboard --port=6003
 
 
 ### The first training stage
-
+Before launching the training process, copy data_loader.py from ../../../../data_loader into ./training_scripts.
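+For example (a sketch, assuming the relative path above is resolved from this 3_training_tasks directory):
+```
+cp ../../../../data_loader/data_loader.py ./training_scripts/
+```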
 Launch the first training stage by the following command:
 ```
 sh ./training_scripts/start_training.sh
diff --git a/training/training_scripts/Nn_Filtering_Set_0/3_training_tasks/training_scripts/data_loader.py b/training/training_scripts/Nn_Filtering_Set_0/3_training_tasks/training_scripts/data_loader.py
deleted file mode 100644
index 150009859d1df0123f7ed6a26b69d950b40295e1..0000000000000000000000000000000000000000
--- a/training/training_scripts/Nn_Filtering_Set_0/3_training_tasks/training_scripts/data_loader.py
+++ /dev/null
@@ -1,240 +0,0 @@
-import json
-import math
-import sys
-from typing import NamedTuple
-import numpy as np
-import struct
-
-class PatchInfo(NamedTuple):
-    data_index: int
-    frame_index: int
-    patch_x0: int
-    patch_y0: int
-    
-    
-def readData(patch_size,border_size,norm,fn,off,nbb,ctype,h,w,x0,y0):
-    t = np.zeros((patch_size+2*border_size,patch_size+2*border_size),dtype='float32') # implicit zeros padding
-    with open(fn,"rb") as file:
-         cropc = [ y0-border_size, y0+patch_size+border_size, x0-border_size, x0+patch_size+border_size ]
-         srcc = [max(cropc[0], 0), min(cropc[1], h), max(cropc[2], 0), min(cropc[3], w)]
-         dstc = [srcc[0] - cropc[0], srcc[1] - cropc[0], srcc[2] - cropc[2], srcc[3] - cropc[2]]
-         data_h = srcc[1] - srcc[0]
-         off += srcc[0]*w*nbb
-         data = np.fromfile(file,dtype=ctype,count=data_h*w,offset=off).reshape(data_h,w)
-         data = data[:, srcc[2]:srcc[3]]
-         t[dstc[0]:dstc[1], dstc[2]:dstc[3]] = data
-         return t.astype('float32')/norm
-    
-def readOne(patch_size,border_size,norm,fn,off,ctype):
-    with open(fn,"rb") as file:
-         if ctype == 'int32':
-             file.seek(off)
-             v = float(struct.unpack("i",file.read(4))[0])/norm
-         else:
-             sys.exit("readOne todo")
-         t = np.full((patch_size+2*border_size,patch_size+2*border_size),v,dtype='float32') 
-         return t
-    
-def getTypeComp(comp):
-    is_y_extracted = False
-    is_u_v_extracted = False
-    for comp_i in comp:
-        if '_Y' in comp_i:
-            is_y_extracted = True
-        if '_U' in comp_i or '_V' in comp_i:
-            is_u_v_extracted = True
-    if is_y_extracted and is_u_v_extracted:
-        return None
-    return is_u_v_extracted
-
-
-class DataLoader:
-    components=[]
-    database=None # contains the whole database
-    patch_info=None # contains address of each patch in the database: dataset index, frame index in the dataset, patch_index in the frame
-    suffix={} # suffix for each file
-    
-    # patch_size in luma sample
-    def __init__(self, jsonfile, patch_size, poc_list, generate_type, qp_filter=-1, slice_type_filter=-1):
-        self.generate_type=generate_type
-        if self.generate_type == 0:
-            self.normalizer_rec  = 1023.0
-            self.normalizer_pred = 1023.0
-            self.normalizer_bs   = 1023.0
-            self.normalizer_cu_average = 1023.0
-            self.normalizer_org8bits = 255.0
-            self.normalizer_org10bits = 1023.0
-            self.normalizer_qp   = 1023.0
-        else:
-            self.normalizer_rec  = 1024.0
-            self.normalizer_pred = 1024.0
-            self.normalizer_bs   = 1024.0
-            self.normalizer_cu_average = 1024.0
-            self.normalizer_org8bits = 256.0
-            self.normalizer_org10bits = 1024.0
-            self.normalizer_qp   = 64.0
-        self.patch_size=patch_size
-        self.patch_info=[]
-        with open(jsonfile, "r") as file:
-         content = file.read()
-         dcontent = json.loads(content)
-         if qp_filter>0 and 'suffix_qp' not in dcontent:
-             sys.exit("Filtering on qp impossible: no qp data in the dataset")
-         if slice_type_filter>0 and 'suffix_slicetype' not in dcontent:
-             sys.exit("Filtering on slice type impossible: no slice data in the dataset")
-         if qp_filter>0 or slice_type_filter>0:
-             sys.exit("todo")
-         self.components.append("org_Y")
-         self.components.append("org_U")
-         self.components.append("org_V")
-         if  'suffix_rec_after_dbf' in dcontent: 
-             self.suffix['rec_after_dbf']=dcontent['suffix_rec_after_dbf']
-             self.components.append("rec_after_dbf_Y")
-             self.components.append("rec_after_dbf_U")
-             self.components.append("rec_after_dbf_V")
-         if  'suffix_rec_before_dbf' in dcontent: 
-             self.suffix['rec_before_dbf']=dcontent['suffix_rec_before_dbf']
-             self.components.append("rec_before_dbf_Y")
-             self.components.append("rec_before_dbf_U")
-             self.components.append("rec_before_dbf_V")
-         if  'suffix_pred' in dcontent: 
-             self.suffix['pred']=dcontent['suffix_pred']
-             self.components.append("pred_Y")
-             self.components.append("pred_U")
-             self.components.append("pred_V")
-         if  'suffix_bs' in dcontent: 
-             self.suffix['bs']=dcontent['suffix_bs']             
-             self.components.append("bs_Y")
-             self.components.append("bs_U")
-             self.components.append("bs_V")
-         if  'suffix_partition_cu_average' in dcontent: 
-             self.suffix['partition_cu_average']=dcontent['suffix_partition_cu_average']    
-             self.components.append("partition_cu_average_Y")
-             self.components.append("partition_cu_average_U")
-             self.components.append("partition_cu_average_V")
-         if  'suffix_qp' in dcontent: 
-             self.components.append("qp_slice")
-             self.suffix['qp_slice']=dcontent['suffix_qp']    
-         self.components.append("qp_base") # always here
-         if  'suffix_slicetype' in dcontent: 
-             self.components.append("slice_type")
-             self.suffix['slice_type']=dcontent['suffix_slicetype']    
-             
-         self.database=dcontent['data']
-         # create array of patches adress
-        
-        if self.generate_type == 0:
-            psize = self.patch_size
-            for didx in range(len(self.database)):
-                 d=self.database[didx]
-                 w = int(d['width'])
-                 h = int(d['height'])
-                 w -= w % psize
-                 h -= h % psize
-                 nb_w=int(w//psize - 2)
-                 nb_h=int(h//psize - 2)
-                 
-                 id_ra = '_T2RA_'
-                 ra_flag = True if id_ra in d['bsname'] else False
-                 for fidx in range(int(d['data_count'])):
-                    if ra_flag and (fidx == 0 or fidx == 32 or fidx == 64):
-                        continue
-                    for y0 in range(nb_h):
-                        for x0 in range(nb_w):
-                            self.patch_info.append(PatchInfo(didx,fidx,1+x0,1+y0))
-        else:
-            for didx in range(len(self.database)):
-                d=self.database[didx]
-                nb_w=int(math.floor(float(d['width'])/patch_size))
-                nb_h=int(math.floor(float(d['height'])/patch_size))
-                frames = range(int(d['data_count'])) if not poc_list else poc_list
-                for fidx in frames:
-                    if fidx >= d['data_count']:
-                        sys.exit("exceed max number of frames ({})".format(d['data_count']))
-                    for y0 in range(nb_h):
-                        for x0 in range(nb_w):
-                            self.patch_info.append(PatchInfo(didx,fidx,x0,y0))
-                     
-    def nb_patches(self):
-         return len(self.patch_info)
-     
-
-    def getPatchData(self,idx,comp,border_size=0):
-        assert(idx<len(self.patch_info))
-        pinfo=self.patch_info[idx]
-        d=self.database[pinfo.data_index]
-        psize=self.patch_size
-        bsize=border_size
-        # print(pinfo,d)
-        chroma_block=getTypeComp(comp)
-        if chroma_block is None:
-            raise AssertionError('The second argument of getPatchData contains strings ending with \'_Y\' and strings ending with \'_U\' or \'_V\', which is not allowed.')
-        w=int(d['width'])
-        h=int(d['height'])
-        frame_size_Y=w*h
-        if chroma_block:
-            psize//=2
-            bsize//=2
-            w//=2
-            h//=2
-        tsize=bsize+psize+bsize
-        x0 = pinfo.patch_x0*psize
-        y0 = pinfo.patch_y0*psize
-        t = np.zeros((1,tsize,tsize,len(comp)),dtype='float32')
-        
-        for idx, c in enumerate(comp):
-            assert(c in self.components)
-                           
-            if 'org' in c:
-                fn=d['original_yuv']
-                off_frame=d['original_frame_skip']+pinfo.frame_index
-                if d['original_bitdepth'] == 8: # 8bits
-                    norm=self.normalizer_org8bits
-                    b='uint8' 
-                    nbb = 1
-                else: # 10bits
-                    norm=self.normalizer_org10bits
-                    b='uint16'                
-                    nbb = 2
-                off = off_frame*(frame_size_Y*nbb*3//2)
-                if c == 'org_U': 
-                    off+=frame_size_Y*nbb                  
-                elif c == 'org_V': 
-                    off+=frame_size_Y*nbb+(frame_size_Y*nbb)//4
-                v = readData(psize,bsize,norm,fn,off,nbb,b,h,w,x0,y0)
-                
-            elif 'rec_after_dbf' in c or 'rec_before_dbf' in c or 'pred' in c or 'partition_cu_average' in c or 'bs' in c:
-                fn=d['dirname']+'/'+d['basename']+self.suffix[c[:-2]]
-                nbb=2 # 16 bits data
-                off=pinfo.frame_index*(frame_size_Y*nbb*3//2)
-                if '_U' in c: 
-                    off+=frame_size_Y*nbb
-                elif '_V' in c: 
-                    off+=frame_size_Y*nbb+(frame_size_Y*nbb)//4
-                if   'rec_after_dbf' in c or 'rec_before_dbf' in c: norm = self.normalizer_rec
-                elif 'pred' in c :          norm = self.normalizer_pred
-                elif 'bs' in c :            norm = self.normalizer_bs
-                elif 'partition_cu_average' in c :     norm = self.normalizer_cu_average
-                               
-                v = readData(psize,bsize,norm,fn,off,nbb,'uint16',h,w,x0,y0)
-                
-            elif c == 'qp_slice':
-                fn=d['dirname']+'/'+d['basename']+self.suffix['qp_slice']
-                norm=self.normalizer_qp
-                off=pinfo.frame_index*4
-                v = readOne(psize,bsize,norm,fn,off,'int32')
-
-            elif c == 'qp_base':
-                norm=self.normalizer_qp
-                f = float(d['qp_base'])/norm                
-                v = np.full((tsize,tsize),f,dtype='float32')                 
-            elif c == 'slice_type':
-                fn=d['dirname']+'/'+d['basename']+self.suffix['slice_type']
-                norm=1
-                off=pinfo.frame_index*4
-                v = readOne(psize,bsize,norm,fn,off,'int32')
-            else:
-                 sys.exit("Unkwown component {}".format(c))
-            t[0,:,:,idx]=v
-        return t
-