Commit 2eae7e98 authored by Liqiang Wang's avatar Liqiang Wang

Remove the duplicate files and add the corresponding description.

parent 6220ae53
@@ -107,7 +107,7 @@ When the compression dataset generation is finished, the file structure should be
```
## Generate the dataset json
Run generate_dataset_json.sh to generate the training and test dataset json files.
Run generate_dataset_json.sh to generate the training and test dataset json files. Before running this script, copy concatenate_dataset.py from ../../../tools into the current directory.
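
A minimal sketch of these two steps, assuming the relative path quoted above and that generate_dataset_json.sh takes no additional arguments:
```
import shutil
import subprocess

# Copy the shared concatenation tool next to the generation script (path from the description above).
shutil.copy("../../../tools/concatenate_dataset.py", ".")

# Generate the training and test dataset json files.
subprocess.run(["sh", "generate_dataset_json.sh"], check=True)
```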
## Other
Because the neural-network-based filter is used when generating the second-stage dataset, generation is relatively slow when run on a CPU.
concatenate_dataset.py
import argparse
import glob
import sys
import json
import re
import os

parser = argparse.ArgumentParser(prog='concatenate dataset',
                                 usage='create a global dataset from all the json files in a given directory.',
                                 formatter_class=argparse.RawDescriptionHelpFormatter,
                                 epilog='''2 modes available:
concatenate_dataset.py --input_dir_json dir1 --input_dir_json dir2 --output_json pre_dataset.json
concatenate_dataset.py --input_json pre_dataset.json --input_dir_encoder direnc1 --input_dir_encoder direnc2 --output_json dataset.json''')
parser.add_argument("--input_dir_json", action="append", nargs='+', type=str, help="directory containing individual json files. Multiple options possible.")
parser.add_argument("--dataset_range", type=str, default='', help="train or test dataset range (such as 0-1000), use all data by default")
parser.add_argument("--input_json", action="store", nargs='?', type=str, help="input json database.")
parser.add_argument("--input_dir_encoder", action="append", nargs='+', type=str, help="directory containing individual encoder log files or encoder cfg files. Multiple options possible.")
parser.add_argument("--log_extension", default="log", action="store", nargs='?', type=str, help="encoder log extension")
parser.add_argument("--output_json", action="store", nargs='?', type=str, help="name of the output file with concatenated files", required=True)
args = parser.parse_args()
# mode 1: concatenate all individual datasets into 1 file, setting the dirname used to find the data
if args.input_dir_json is not None:
    header = {}
    lastheader = None
    db = []
    # flatten the repeated --input_dir_json options into a single list of directories
    flat = [d for d1 in args.input_dir_json for d in d1]
    for d in flat:
        files = sorted(glob.glob(d + '/*.json'))
        if args.dataset_range:
            range_start, range_end = map(int, args.dataset_range.split('-'))
            files = files[range_start:range_end]
        print("Processing directory {}: {} files".format(d, len(files)))
        for f in files:
            with open(f, "r") as file:
                content = file.read()
            dcontent = json.loads(content)
            # keep only the suffix_* keys as the common header and check consistency across files
            header = {}
            for key in dcontent:
                if "suffix_" in key:
                    header[key] = dcontent[key]
            if lastheader is not None and not lastheader == header:
                sys.exit("File {} does not contain the same data as other files".format(f))
            lastheader = header
            for data in dcontent['data']:
                if 'dirname' not in data:  # no dirname yet
                    data['dirname'] = d
                db.append(data)
    jout = header
    jout["data"] = db
    s = json.dumps(jout, indent=1)
    with open(args.output_json, "w") as file:
        file.write(s)
# mode 2: consolidate a dataset file by adding information on the original yuv taken from the encoder logs
if args.input_json is not None:
    db_logs = {}
    # flatten the repeated --input_dir_encoder options into a single list of directories
    flat = [d for d1 in args.input_dir_encoder for d in d1]
    for d in flat:
        files = glob.glob(d + '/*.' + args.log_extension)
        print("Processing directory {}: {} files".format(d, len(files)))
        for f in files:
            info = {"FrameSkip": 0, "TemporalSubsampleRatio": 1}  # defaults
            name = None
            with open(f, "r") as file:
                for line in file:
                    m = re.match(r"^Input\s*File\s*:\s*([^\s]+)", line)
                    if m:
                        info['InputFile'] = m.group(1)
                    m = re.match(r"^Bitstream\s*File\s*:\s*([^\s]+)", line)
                    if m:
                        name = os.path.basename(m.group(1))
                    m = re.match(r"^TemporalSubsampleRatio\s*:\s*([0-9]+)", line)
                    if m:
                        info['TemporalSubsampleRatio'] = m.group(1)
                    # m = re.match(r"^QP\s*:\s*([0-9]+)", line)
                    # if m:
                    #     info['QP'] = m.group(1)
                    m = re.match(r"^FrameSkip\s*:\s*([0-9]+)", line)
                    if m:
                        info['FrameSkip'] = m.group(1)
                    m = re.match(r"^Input\s+bit\s+depth\s*:\s*\(Y:([0-9]+),", line)
                    if m:
                        info['InputBitDepth'] = m.group(1)
                    m = re.match(r"^InputBitDepth\s*:\s*([0-9]+)", line)
                    if m:
                        info['InputBitDepth'] = m.group(1)
            if name is not None:
                if len(info) != 4:  # FrameSkip, TemporalSubsampleRatio, InputFile and InputBitDepth are all required
                    sys.exit("Not enough information extracted for bitstream {}".format(name))
                db_logs[name] = info
    print(db_logs)
    with open(args.input_json, "r") as file:
        content = file.read()
    dcontent = json.loads(content)
    for d in dcontent['data']:
        if d['bsname'] in db_logs:
            info = db_logs[d['bsname']]
            d['original_yuv'] = info['InputFile']
            d['original_temporal_subsample'] = int(info['TemporalSubsampleRatio'])
            d['original_frame_skip'] = int(info['FrameSkip'])
            # d['qp_base'] = int(info['QP'])
            d['original_bitdepth'] = int(info['InputBitDepth'])
    s = json.dumps(dcontent, indent=1)
    with open(args.output_json, "w") as file:
        file.write(s)
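
For reference, a small self-contained sketch of the log fields that mode 2 extracts; the sample lines below only imitate the layout of a VTM-style encoder log, and the file names and values are made up for illustration:
```
import re

# Synthetic lines imitating a VTM-style encoder log (illustrative values only).
sample_log = """Input          File                    : input_video.yuv
Bitstream      File                    : str.bin
FrameSkip                              : 0
TemporalSubsampleRatio                 : 1
Input bit depth                        : (Y:10, C:10)
"""

info = {"FrameSkip": 0, "TemporalSubsampleRatio": 1}  # same defaults as in mode 2 above
for line in sample_log.splitlines():
    for key, pattern in [("InputFile", r"^Input\s*File\s*:\s*([^\s]+)"),
                         ("FrameSkip", r"^FrameSkip\s*:\s*([0-9]+)"),
                         ("TemporalSubsampleRatio", r"^TemporalSubsampleRatio\s*:\s*([0-9]+)"),
                         ("InputBitDepth", r"^Input\s+bit\s+depth\s*:\s*\(Y:([0-9]+),")]:
        m = re.match(pattern, line)
        if m:
            info[key] = m.group(1)
print(info)  # e.g. {'FrameSkip': '0', 'TemporalSubsampleRatio': '1', 'InputFile': 'input_video.yuv', 'InputBitDepth': '10'}
```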
@@ -11,7 +11,7 @@ tensorboard --logdir=Tensorboard --port=6003
### The first training stage
Before launching the training process, copy data_loader.py from ../../../../data_loader into ./training_scripts.
Launch the first training stage with the following command:
```
sh ./training_scripts/start_training.sh
```

data_loader.py
import json
import math
import sys
from typing import NamedTuple
import numpy as np
import struct


class PatchInfo(NamedTuple):
    data_index: int
    frame_index: int
    patch_x0: int
    patch_y0: int


def readData(patch_size, border_size, norm, fn, off, nbb, ctype, h, w, x0, y0):
    # patch filled with implicit zero padding when the border falls outside the frame
    t = np.zeros((patch_size + 2 * border_size, patch_size + 2 * border_size), dtype='float32')
    with open(fn, "rb") as file:
        # requested crop in frame coordinates, the part of it that lies inside the frame,
        # and where that part lands inside the padded patch
        cropc = [y0 - border_size, y0 + patch_size + border_size, x0 - border_size, x0 + patch_size + border_size]
        srcc = [max(cropc[0], 0), min(cropc[1], h), max(cropc[2], 0), min(cropc[3], w)]
        dstc = [srcc[0] - cropc[0], srcc[1] - cropc[0], srcc[2] - cropc[2], srcc[3] - cropc[2]]
        data_h = srcc[1] - srcc[0]
        off += srcc[0] * w * nbb
        data = np.fromfile(file, dtype=ctype, count=data_h * w, offset=off).reshape(data_h, w)
        data = data[:, srcc[2]:srcc[3]]
        t[dstc[0]:dstc[1], dstc[2]:dstc[3]] = data
    return t.astype('float32') / norm


def readOne(patch_size, border_size, norm, fn, off, ctype):
    # read a single scalar (e.g. slice QP or slice type) and broadcast it to a constant patch
    with open(fn, "rb") as file:
        if ctype == 'int32':
            file.seek(off)
            v = float(struct.unpack("i", file.read(4))[0]) / norm
        else:
            sys.exit("readOne todo")
    t = np.full((patch_size + 2 * border_size, patch_size + 2 * border_size), v, dtype='float32')
    return t
def getTypeComp(comp):
    # Determine whether a component list addresses luma or chroma planes.
    # Returns False for a luma-only list, True for a chroma-only list, and
    # None when luma and chroma components are mixed (rejected by getPatchData).
    is_y_extracted = False
    is_u_v_extracted = False
    for comp_i in comp:
        if '_Y' in comp_i:
            is_y_extracted = True
        if '_U' in comp_i or '_V' in comp_i:
            is_u_v_extracted = True
    if is_y_extracted and is_u_v_extracted:
        return None
    return is_u_v_extracted
class DataLoader:
    # note: these are class-level attributes, shared by all DataLoader instances
    components = []
    database = None    # contains the whole database
    patch_info = None  # contains the address of each patch in the database: dataset index, frame index in the dataset, patch index in the frame
    suffix = {}        # suffix for each file

    # patch_size in luma samples
    def __init__(self, jsonfile, patch_size, poc_list, generate_type, qp_filter=-1, slice_type_filter=-1):
        self.generate_type = generate_type
        if self.generate_type == 0:
            self.normalizer_rec = 1023.0
            self.normalizer_pred = 1023.0
            self.normalizer_bs = 1023.0
            self.normalizer_cu_average = 1023.0
            self.normalizer_org8bits = 255.0
            self.normalizer_org10bits = 1023.0
            self.normalizer_qp = 1023.0
        else:
            self.normalizer_rec = 1024.0
            self.normalizer_pred = 1024.0
            self.normalizer_bs = 1024.0
            self.normalizer_cu_average = 1024.0
            self.normalizer_org8bits = 256.0
            self.normalizer_org10bits = 1024.0
            self.normalizer_qp = 64.0
        self.patch_size = patch_size
        self.patch_info = []
        with open(jsonfile, "r") as file:
            content = file.read()
        dcontent = json.loads(content)
        if qp_filter > 0 and 'suffix_qp' not in dcontent:
            sys.exit("Filtering on qp impossible: no qp data in the dataset")
        if slice_type_filter > 0 and 'suffix_slicetype' not in dcontent:
            sys.exit("Filtering on slice type impossible: no slice data in the dataset")
        if qp_filter > 0 or slice_type_filter > 0:
            sys.exit("todo")
        # the original frames are always available; the other components depend on the suffix_* keys in the json
        self.components.append("org_Y")
        self.components.append("org_U")
        self.components.append("org_V")
        if 'suffix_rec_after_dbf' in dcontent:
            self.suffix['rec_after_dbf'] = dcontent['suffix_rec_after_dbf']
            self.components.append("rec_after_dbf_Y")
            self.components.append("rec_after_dbf_U")
            self.components.append("rec_after_dbf_V")
        if 'suffix_rec_before_dbf' in dcontent:
            self.suffix['rec_before_dbf'] = dcontent['suffix_rec_before_dbf']
            self.components.append("rec_before_dbf_Y")
            self.components.append("rec_before_dbf_U")
            self.components.append("rec_before_dbf_V")
        if 'suffix_pred' in dcontent:
            self.suffix['pred'] = dcontent['suffix_pred']
            self.components.append("pred_Y")
            self.components.append("pred_U")
            self.components.append("pred_V")
        if 'suffix_bs' in dcontent:
            self.suffix['bs'] = dcontent['suffix_bs']
            self.components.append("bs_Y")
            self.components.append("bs_U")
            self.components.append("bs_V")
        if 'suffix_partition_cu_average' in dcontent:
            self.suffix['partition_cu_average'] = dcontent['suffix_partition_cu_average']
            self.components.append("partition_cu_average_Y")
            self.components.append("partition_cu_average_U")
            self.components.append("partition_cu_average_V")
        if 'suffix_qp' in dcontent:
            self.components.append("qp_slice")
            self.suffix['qp_slice'] = dcontent['suffix_qp']
        self.components.append("qp_base")  # always here
        if 'suffix_slicetype' in dcontent:
            self.components.append("slice_type")
            self.suffix['slice_type'] = dcontent['suffix_slicetype']
        self.database = dcontent['data']
        # create the array of patch addresses
        if self.generate_type == 0:
            psize = self.patch_size
            for didx in range(len(self.database)):
                d = self.database[didx]
                w = int(d['width'])
                h = int(d['height'])
                w -= w % psize
                h -= h % psize
                # skip one patch of border on each side
                nb_w = int(w // psize - 2)
                nb_h = int(h // psize - 2)
                id_ra = '_T2RA_'
                ra_flag = id_ra in d['bsname']
                for fidx in range(int(d['data_count'])):
                    # skip frames 0, 32 and 64 for random-access ('_T2RA_') bitstreams
                    if ra_flag and (fidx == 0 or fidx == 32 or fidx == 64):
                        continue
                    for y0 in range(nb_h):
                        for x0 in range(nb_w):
                            self.patch_info.append(PatchInfo(didx, fidx, 1 + x0, 1 + y0))
        else:
            for didx in range(len(self.database)):
                d = self.database[didx]
                nb_w = int(math.floor(float(d['width']) / patch_size))
                nb_h = int(math.floor(float(d['height']) / patch_size))
                frames = range(int(d['data_count'])) if not poc_list else poc_list
                for fidx in frames:
                    if fidx >= int(d['data_count']):
                        sys.exit("exceeds max number of frames ({})".format(d['data_count']))
                    for y0 in range(nb_h):
                        for x0 in range(nb_w):
                            self.patch_info.append(PatchInfo(didx, fidx, x0, y0))
    def nb_patches(self):
        return len(self.patch_info)

    def getPatchData(self, idx, comp, border_size=0):
        assert idx < len(self.patch_info)
        pinfo = self.patch_info[idx]
        d = self.database[pinfo.data_index]
        psize = self.patch_size
        bsize = border_size
        # print(pinfo, d)
        chroma_block = getTypeComp(comp)
        if chroma_block is None:
            raise AssertionError("The second argument of getPatchData contains strings ending with '_Y' and strings ending with '_U' or '_V', which is not allowed.")
        w = int(d['width'])
        h = int(d['height'])
        frame_size_Y = w * h
        if chroma_block:
            # chroma planes are stored at quarter resolution (4:2:0)
            psize //= 2
            bsize //= 2
            w //= 2
            h //= 2
        tsize = bsize + psize + bsize
        x0 = pinfo.patch_x0 * psize
        y0 = pinfo.patch_y0 * psize
        t = np.zeros((1, tsize, tsize, len(comp)), dtype='float32')
        for idx, c in enumerate(comp):
            assert c in self.components
            if 'org' in c:
                # original frames are read directly from the source yuv file
                fn = d['original_yuv']
                off_frame = d['original_frame_skip'] + pinfo.frame_index
                if d['original_bitdepth'] == 8:  # 8 bits
                    norm = self.normalizer_org8bits
                    b = 'uint8'
                    nbb = 1
                else:  # 10 bits
                    norm = self.normalizer_org10bits
                    b = 'uint16'
                    nbb = 2
                off = off_frame * (frame_size_Y * nbb * 3 // 2)
                if c == 'org_U':
                    off += frame_size_Y * nbb
                elif c == 'org_V':
                    off += frame_size_Y * nbb + (frame_size_Y * nbb) // 4
                v = readData(psize, bsize, norm, fn, off, nbb, b, h, w, x0, y0)
            elif 'rec_after_dbf' in c or 'rec_before_dbf' in c or 'pred' in c or 'partition_cu_average' in c or 'bs' in c:
                # dumped data use one file per component group, selected by its suffix (c[:-2] strips '_Y'/'_U'/'_V')
                fn = d['dirname'] + '/' + d['basename'] + self.suffix[c[:-2]]
                nbb = 2  # 16-bit data
                off = pinfo.frame_index * (frame_size_Y * nbb * 3 // 2)
                if '_U' in c:
                    off += frame_size_Y * nbb
                elif '_V' in c:
                    off += frame_size_Y * nbb + (frame_size_Y * nbb) // 4
                if 'rec_after_dbf' in c or 'rec_before_dbf' in c:
                    norm = self.normalizer_rec
                elif 'pred' in c:
                    norm = self.normalizer_pred
                elif 'bs' in c:
                    norm = self.normalizer_bs
                elif 'partition_cu_average' in c:
                    norm = self.normalizer_cu_average
                v = readData(psize, bsize, norm, fn, off, nbb, 'uint16', h, w, x0, y0)
            elif c == 'qp_slice':
                fn = d['dirname'] + '/' + d['basename'] + self.suffix['qp_slice']
                norm = self.normalizer_qp
                off = pinfo.frame_index * 4
                v = readOne(psize, bsize, norm, fn, off, 'int32')
            elif c == 'qp_base':
                norm = self.normalizer_qp
                f = float(d['qp_base']) / norm
                v = np.full((tsize, tsize), f, dtype='float32')
            elif c == 'slice_type':
                fn = d['dirname'] + '/' + d['basename'] + self.suffix['slice_type']
                norm = 1
                off = pinfo.frame_index * 4
                v = readOne(psize, bsize, norm, fn, off, 'int32')
            else:
                sys.exit("Unknown component {}".format(c))
            t[0, :, :, idx] = v
        return t
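
As a usage sketch of the loader above (the json file name, patch size, and component list are placeholders; the available components depend on which suffix_* keys are present in the dataset json):
```
# Minimal usage sketch; "dataset.json" must exist and follow the structure produced by concatenate_dataset.py.
loader = DataLoader("dataset.json", patch_size=128, poc_list=[], generate_type=1)
print("number of patches:", loader.nb_patches())

# Fetch one luma patch together with its co-located prediction and base-QP plane;
# the result has shape (1, patch_size + 2*border_size, patch_size + 2*border_size, len(comp)).
patch = loader.getPatchData(0, ["rec_before_dbf_Y", "pred_Y", "qp_base"], border_size=8)
print(patch.shape)  # (1, 144, 144, 3)
```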