diff --git a/training/data_loader/data_loader.py b/training/data_loader/data_loader.py index 612b17953873acb52eeda22d05e00eacda7441c6..deebdcb4ec31371b930385d8ceb533dcf37b7544 100644 --- a/training/data_loader/data_loader.py +++ b/training/data_loader/data_loader.py @@ -44,6 +44,13 @@ class PatchInfo(NamedTuple): frame_index: int patch_x0: int patch_y0: int + + +def getRefDistance(poc): + distance = [32, 16, 8, 4, 2, 1] + for d in distance: + if poc % d == 0: + return d def readData(patch_size,border_size,norm,fn,off,nbb,ctype,h,w,x0,y0): @@ -121,6 +128,12 @@ class DataLoader: self.components.append("org_Y") self.components.append("org_U") self.components.append("org_V") + self.components.append("ref_list_0_Y") + self.components.append("ref_list_0_U") + self.components.append("ref_list_0_V") + self.components.append("ref_list_1_Y") + self.components.append("ref_list_1_U") + self.components.append("ref_list_1_V") if 'suffix_rec_after_dbf' in dcontent: self.suffix['rec_after_dbf']=dcontent['suffix_rec_after_dbf'] self.components.append("rec_after_dbf_Y") @@ -251,6 +264,22 @@ class DataLoader: elif 'partition_cu_average' in c : norm = self.normalizer_cu_average v = readData(psize,bsize,norm,fn,off,nbb,'uint16',h,w,x0,y0) + + elif 'ref_list_0' in c or 'ref_list_1' in c: + fn = d['dirname'] + '/' + d['basename'] + self.suffix['rec_before_dbf'] + nbb = 2 # 16 bits data + if 'ref_list_0' in c: + off = (pinfo.frame_index - getRefDistance(pinfo.frame_index)) * (frame_size_Y * nbb * 3 // 2) + else: + off = (pinfo.frame_index + getRefDistance(pinfo.frame_index)) * (frame_size_Y * nbb * 3 // 2) + if '_U' in c: + off += frame_size_Y * nbb + elif '_V' in c: + off += frame_size_Y * nbb + (frame_size_Y * nbb) // 4 + + norm = self.normalizer_rec + + v = readData(psize, bsize, norm, fn, off, nbb, 'uint16', h, w, x0, y0) elif c == 'qp_slice': fn=d['dirname']+'/'+d['basename']+self.suffix['qp_slice']