From 3e5a4a4038c41d30cd99e4cfcb999cfba8616531 Mon Sep 17 00:00:00 2001
From: Yue Li <yue.li@bytedance.com>
Date: Fri, 3 Feb 2023 16:36:54 -0800
Subject: [PATCH] allow data_loader to extract patches from reference picture
 list 0 and reference picture list 1

---
 training/data_loader/data_loader.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/training/data_loader/data_loader.py b/training/data_loader/data_loader.py
index 612b179538..deebdcb4ec 100644
--- a/training/data_loader/data_loader.py
+++ b/training/data_loader/data_loader.py
@@ -44,6 +44,13 @@ class PatchInfo(NamedTuple):
     frame_index: int
     patch_x0: int
     patch_y0: int
+
+
+def getRefDistance(poc):
+    distance = [32, 16, 8, 4, 2, 1]
+    for d in distance:
+        if poc % d == 0:
+            return d
     
     
 def readData(patch_size,border_size,norm,fn,off,nbb,ctype,h,w,x0,y0):
@@ -121,6 +128,12 @@ class DataLoader:
          self.components.append("org_Y")
          self.components.append("org_U")
          self.components.append("org_V")
+         self.components.append("ref_list_0_Y")
+         self.components.append("ref_list_0_U")
+         self.components.append("ref_list_0_V")
+         self.components.append("ref_list_1_Y")
+         self.components.append("ref_list_1_U")
+         self.components.append("ref_list_1_V")
          if  'suffix_rec_after_dbf' in dcontent: 
              self.suffix['rec_after_dbf']=dcontent['suffix_rec_after_dbf']
              self.components.append("rec_after_dbf_Y")
@@ -251,6 +264,22 @@ class DataLoader:
                 elif 'partition_cu_average' in c :     norm = self.normalizer_cu_average
                                
                 v = readData(psize,bsize,norm,fn,off,nbb,'uint16',h,w,x0,y0)
+
+            elif 'ref_list_0' in c or 'ref_list_1' in c:
+                fn = d['dirname'] + '/' + d['basename'] + self.suffix['rec_before_dbf']
+                nbb = 2  # 16 bits data
+                if 'ref_list_0' in c:
+                    off = (pinfo.frame_index - getRefDistance(pinfo.frame_index)) * (frame_size_Y * nbb * 3 // 2)
+                else:
+                    off = (pinfo.frame_index + getRefDistance(pinfo.frame_index)) * (frame_size_Y * nbb * 3 // 2)
+                if '_U' in c:
+                    off += frame_size_Y * nbb
+                elif '_V' in c:
+                    off += frame_size_Y * nbb + (frame_size_Y * nbb) // 4
+
+                norm = self.normalizer_rec
+
+                v = readData(psize, bsize, norm, fn, off, nbb, 'uint16', h, w, x0, y0)
                 
             elif c == 'qp_slice':
                 fn=d['dirname']+'/'+d['basename']+self.suffix['qp_slice']
-- 
GitLab