From 48479ff13d737099be5d9e3e68e86939bdb23b0c Mon Sep 17 00:00:00 2001
From: Yue Li <yue.li@bytedance.com>
Date: Wed, 16 Nov 2022 08:58:38 +0000
Subject: [PATCH] refactor code of filter set #1: correct variable name, remove
 constant

---
 source/Lib/CommonLib/CommonDef.h          |  6 +--
 source/Lib/CommonLib/NNFilterSet1.cpp     | 64 +++++++++++++----------
 source/Lib/EncoderLib/EncNNFilterSet1.cpp |  4 +-
 3 files changed, 42 insertions(+), 32 deletions(-)

diff --git a/source/Lib/CommonLib/CommonDef.h b/source/Lib/CommonLib/CommonDef.h
index b2d3086e60..ed69ac29de 100644
--- a/source/Lib/CommonLib/CommonDef.h
+++ b/source/Lib/CommonLib/CommonDef.h
@@ -195,9 +195,9 @@ static const int QP_OFFSET_NUM                                    = 2;
 static const int NNQPOFFSET[QP_OFFSET_NUM]                = { -5, 5 };
 #endif
 #endif
-#if NN_FILTERING_SET_1 && NN_FIXED_POINT_IMPLEMENTATION
-static const int NN_INPUT_PRECISION=                                13;
-static const int NN_OUPUTPUT_PRECISION=                             13;
+#if NN_FILTERING_SET_1
+static const int NN_INPUT_PRECISION=                               13;
+static const int NN_OUTPUT_PRECISION=                              13;
 #endif
 
 
diff --git a/source/Lib/CommonLib/NNFilterSet1.cpp b/source/Lib/CommonLib/NNFilterSet1.cpp
index 03c3da2f35..dec53e08c6 100644
--- a/source/Lib/CommonLib/NNFilterSet1.cpp
+++ b/source/Lib/CommonLib/NNFilterSet1.cpp
@@ -227,8 +227,8 @@ void extractOutputsLuma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, P
 #if NN_FIXED_POINT_IMPLEMENTATION
   int log2InputScale = 10;
   int log2OutputScale = 10;
-  int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale;
-  int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale;
+  int shiftInput = NN_OUTPUT_PRECISION - log2InputScale;
+  int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale;
   int offset    = (1 << shiftOutput) / 2;
 #else
   double inputScale = 1024;
@@ -277,10 +277,10 @@ template<typename T>
 void extractOutputsChroma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, PelStorage& tempScaledBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, bool inter)
 {
 #if NN_FIXED_POINT_IMPLEMENTATION
-int log2InputScale = 10;
+  int log2InputScale = 10;
   int log2OutputScale = 10;
-  int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale;
-  int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale;
+  int shiftInput = NN_OUTPUT_PRECISION - log2InputScale;
+  int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale;
   int offset    = (1 << shiftOutput) / 2;
 #else
   double inputScale = 1024;
@@ -338,6 +338,11 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL
 {
   //at::init_num_threads(); // use all available threads
   
+  double inputScale = 1024;
+  double qpScale = 64;
+  int log2InputScale = 10;
+  int log2QpScale = 6;
+  
   const int border_to_skip = 0;
   if (border_to_skip>0) sadl::Tensor<float>::skip_border = true;
   
@@ -360,10 +365,10 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL
   std::vector<InputData> listInputData;
   if (inter)
   {
-    InputData inputRec = {NN_INPUT_REC, 0, 1024, 3, true, false};
-    InputData inputPred = {NN_INPUT_PRED, 1, 1024, 3, true, false};
-    InputData inputBs = {NN_INPUT_BS, 2, 1024, 3, true, false};
-    InputData inputQp = {NN_INPUT_LOCAL_QP, 3, 64, 7, true, false};
+    InputData inputRec = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputPred = {NN_INPUT_PRED, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputBs = {NN_INPUT_BS, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputQp = {NN_INPUT_LOCAL_QP, 3, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false};
     listInputData.push_back(inputRec);
     listInputData.push_back(inputPred);
     listInputData.push_back(inputBs);
@@ -371,15 +376,15 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL
   }
   else
   {
-    InputData inputRec = {NN_INPUT_REC, 0, 1024, 3, true, false};
-    InputData inputPred = {NN_INPUT_PRED, 1, 1024, 3, true, false};
+    InputData inputRec = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputPred = {NN_INPUT_PRED, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
 #if JVET_AB0053_NO_PART_NO_ATTN
-    InputData inputBs = {NN_INPUT_BS, 2, 1024, 3, true, false};
-    InputData inputQp = {NN_INPUT_LOCAL_QP, 3, 64, 7, true, false};
+    InputData inputBs = {NN_INPUT_BS, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputQp = {NN_INPUT_LOCAL_QP, 3, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false};
 #else
-    InputData inputPartition = {NN_INPUT_PARTITION, 2, 1024, 3, true, false};
-    InputData inputBs = {NN_INPUT_BS, 3, 1024, 3, true, false};
-    InputData inputQp = {NN_INPUT_LOCAL_QP, 4, 64, 7, true, false};
+    InputData inputPartition = {NN_INPUT_PARTITION, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputBs = {NN_INPUT_BS, 3, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputQp = {NN_INPUT_LOCAL_QP, 4, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false};
 #endif
     listInputData.push_back(inputRec);
     listInputData.push_back(inputPred);
@@ -404,6 +409,11 @@ void NNFilterSet1::cnnFilterChromaBlock(Picture* pic, UnitArea inferArea, int ex
 
   //at::init_num_threads();
   
+  double inputScale = 1024;
+  double qpScale = 64;
+  int log2InputScale = 10;
+  int log2QpScale = 6;
+  
   const int border_to_skip = 0;
   if (border_to_skip>0) sadl::Tensor<float>::skip_border = true;
   
@@ -427,11 +437,11 @@ void NNFilterSet1::cnnFilterChromaBlock(Picture* pic, UnitArea inferArea, int ex
   
   if (inter)
   {
-    InputData inputRecCrossComponent = {NN_INPUT_REC, 0, 1024, 3, true, false};
-    InputData inputRec = {NN_INPUT_REC, 1, 1024, 3, false, true};
-    InputData inputPred = {NN_INPUT_PRED, 2, 1024, 3, false, true};
-    InputData inputBs = {NN_INPUT_BS, 3, 1024, 3, false, true};
-    InputData inputQp = {NN_INPUT_LOCAL_QP, 4, 64, 7, false, true};
+    InputData inputRecCrossComponent = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputRec = {NN_INPUT_REC, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true};
+    InputData inputPred = {NN_INPUT_PRED, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true};
+    InputData inputBs = {NN_INPUT_BS, 3, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true};
+    InputData inputQp = {NN_INPUT_LOCAL_QP, 4, qpScale, NN_INPUT_PRECISION - log2QpScale, false, true};
     listInputData.push_back(inputRecCrossComponent);
     listInputData.push_back(inputRec);
     listInputData.push_back(inputPred);
@@ -440,12 +450,12 @@ void NNFilterSet1::cnnFilterChromaBlock(Picture* pic, UnitArea inferArea, int ex
   }
   else
   {
-    InputData inputRecCrossComponent = {NN_INPUT_REC, 0, 1024, 3, true, false};
-    InputData inputRec = {NN_INPUT_REC, 1, 1024, 3, false, true};
-    InputData inputPred = {NN_INPUT_PRED, 2, 1024, 3, false, true};
-    InputData inputPartition = {NN_INPUT_PARTITION, 3, 1024, 3, false, true};
-    InputData inputBs = {NN_INPUT_BS, 4, 1024, 3, false, true};
-    InputData inputQp = {NN_INPUT_LOCAL_QP, 5, 64, 7, false, true};
+    InputData inputRecCrossComponent = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false};
+    InputData inputRec = {NN_INPUT_REC, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true};
+    InputData inputPred = {NN_INPUT_PRED, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true};
+    InputData inputPartition = {NN_INPUT_PARTITION, 3, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true};
+    InputData inputBs = {NN_INPUT_BS, 4, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true};
+    InputData inputQp = {NN_INPUT_LOCAL_QP, 5, qpScale, NN_INPUT_PRECISION - log2QpScale, false, true};
     listInputData.push_back(inputRecCrossComponent);
     listInputData.push_back(inputRec);
     listInputData.push_back(inputPred);
diff --git a/source/Lib/EncoderLib/EncNNFilterSet1.cpp b/source/Lib/EncoderLib/EncNNFilterSet1.cpp
index 72c4b9d770..c6c8100856 100644
--- a/source/Lib/EncoderLib/EncNNFilterSet1.cpp
+++ b/source/Lib/EncoderLib/EncNNFilterSet1.cpp
@@ -273,8 +273,8 @@ void EncNNFilterSet1::extractOutputsLumaRd (Picture* pic, sadl::Model<T> &m, Pel
 #if NN_FIXED_POINT_IMPLEMENTATION
   int log2InputScale = 10;
   int log2OutputScale = 10;
-  int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale;
-  int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale;
+  int shiftInput = NN_OUTPUT_PRECISION - log2InputScale;
+  int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale;
   int offset    = (1 << shiftOutput) / 2;
 #else
   double inputScale = 1024;
-- 
GitLab