From 3e5a4a4038c41d30cd99e4cfcb999cfba8616531 Mon Sep 17 00:00:00 2001 From: Yue Li <yue.li@bytedance.com> Date: Fri, 3 Feb 2023 16:36:54 -0800 Subject: [PATCH] allow data_loader to extract patches from reference picture list 0 and reference picture list 1 --- training/data_loader/data_loader.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/training/data_loader/data_loader.py b/training/data_loader/data_loader.py index 612b179538..deebdcb4ec 100644 --- a/training/data_loader/data_loader.py +++ b/training/data_loader/data_loader.py @@ -44,6 +44,13 @@ class PatchInfo(NamedTuple): frame_index: int patch_x0: int patch_y0: int + + +def getRefDistance(poc): + distance = [32, 16, 8, 4, 2, 1] + for d in distance: + if poc % d == 0: + return d def readData(patch_size,border_size,norm,fn,off,nbb,ctype,h,w,x0,y0): @@ -121,6 +128,12 @@ class DataLoader: self.components.append("org_Y") self.components.append("org_U") self.components.append("org_V") + self.components.append("ref_list_0_Y") + self.components.append("ref_list_0_U") + self.components.append("ref_list_0_V") + self.components.append("ref_list_1_Y") + self.components.append("ref_list_1_U") + self.components.append("ref_list_1_V") if 'suffix_rec_after_dbf' in dcontent: self.suffix['rec_after_dbf']=dcontent['suffix_rec_after_dbf'] self.components.append("rec_after_dbf_Y") @@ -251,6 +264,22 @@ class DataLoader: elif 'partition_cu_average' in c : norm = self.normalizer_cu_average v = readData(psize,bsize,norm,fn,off,nbb,'uint16',h,w,x0,y0) + + elif 'ref_list_0' in c or 'ref_list_1' in c: + fn = d['dirname'] + '/' + d['basename'] + self.suffix['rec_before_dbf'] + nbb = 2 # 16 bits data + if 'ref_list_0' in c: + off = (pinfo.frame_index - getRefDistance(pinfo.frame_index)) * (frame_size_Y * nbb * 3 // 2) + else: + off = (pinfo.frame_index + getRefDistance(pinfo.frame_index)) * (frame_size_Y * nbb * 3 // 2) + if '_U' in c: + off += frame_size_Y * nbb + elif '_V' in c: + off += frame_size_Y * nbb + (frame_size_Y * nbb) // 4 + + norm = self.normalizer_rec + + v = readData(psize, bsize, norm, fn, off, nbb, 'uint16', h, w, x0, y0) elif c == 'qp_slice': fn=d['dirname']+'/'+d['basename']+self.suffix['qp_slice'] -- GitLab