From cddbce8bca419a48dcf60be86f5f14b69678d4ab Mon Sep 17 00:00:00 2001
From: Renjie Chang <renjiechang@tencent.com>
Date: Thu, 2 Mar 2023 13:19:16 +0000
Subject: [PATCH] SR configuration for rate matching

---
 README.md                           |  5 +++++
 cfg/nn-based/nnsr_classA1.cfg       |  2 ++
 cfg/nn-based/nnsr_classA2.cfg       |  2 ++
 source/App/EncoderApp/EncApp.cpp    |  2 ++
 source/App/EncoderApp/EncAppCfg.cpp |  4 +++-
 source/App/EncoderApp/EncAppCfg.h   |  2 ++
 source/Lib/CommonLib/CommonDef.h    |  5 +++++
 source/Lib/EncoderLib/EncCfg.h      |  8 +++++++-
 source/Lib/EncoderLib/EncLib.cpp    | 26 +++++++++++++++++++++++---
 9 files changed, 51 insertions(+), 5 deletions(-)
 create mode 100644 cfg/nn-based/nnsr_classA1.cfg
 create mode 100644 cfg/nn-based/nnsr_classA2.cfg

diff --git a/README.md b/README.md
index b77b5b8a47..3d54053ab7 100644
--- a/README.md
+++ b/README.md
@@ -419,10 +419,15 @@ please add the following argument when running the VTM-11-NNVC encoder/decoder e
 where `path_to_directory_models_intra` is the path to the directory "models/intra" relatively to the directory from which the
 VTM-11-NNVC encoder/decoder executable is run.
 
+
 NN-based super resolution
 ------------------------------------------------------------------------
 To activate NN-based super resolution, use --NnsrOption=1. The default model path is set as "./models/super_resolution/".
 
+For rate matchinng, use following config file when testing class A1 or A2.
+[cfg/nn-based/nnsr_classA1.cfg](cfg/nn-based/nnsr_classA1.cfg)
+[cfg/nn-based/nnsr_classA2.cfg](cfg/nn-based/nnsr_classA2.cfg)
+
 
 Content-adaptive post-filter
 ------------------------------------------------------------------------
diff --git a/cfg/nn-based/nnsr_classA1.cfg b/cfg/nn-based/nnsr_classA1.cfg
new file mode 100644
index 0000000000..593f428836
--- /dev/null
+++ b/cfg/nn-based/nnsr_classA1.cfg
@@ -0,0 +1,2 @@
+NnsrQpOffset                   : 5
+NnsrQpOffsetOverrideByQp       : 22 5 27 5 32 5 37 5 42 5
diff --git a/cfg/nn-based/nnsr_classA2.cfg b/cfg/nn-based/nnsr_classA2.cfg
new file mode 100644
index 0000000000..36fb542cdf
--- /dev/null
+++ b/cfg/nn-based/nnsr_classA2.cfg
@@ -0,0 +1,2 @@
+NnsrQpOffset                   : 5
+NnsrQpOffsetOverrideByQp       : 22 10 27 6 32 6 37 6 42 6
diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp
index 09b55af765..6ed57bdb3f 100644
--- a/source/App/EncoderApp/EncApp.cpp
+++ b/source/App/EncoderApp/EncApp.cpp
@@ -263,6 +263,8 @@ void EncApp::xInitLibCfg()
 
 #if JVET_AC0196_NNSR
   m_cEncLib.setUseNnsr                                           (m_nnsrOption);
+  m_cEncLib.setNnsrQpOffset                                      (m_nnsrQpOffset);
+  m_cEncLib.setNnsrQpOffsetOverridebyQp                          (m_sNnsrQpOffsetOverrideByQp);
 #endif
 
   m_cEncLib.setProfile                                           ( m_profile);
diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp
index 694e1c9633..a35d50009d 100644
--- a/source/App/EncoderApp/EncAppCfg.cpp
+++ b/source/App/EncoderApp/EncAppCfg.cpp
@@ -1472,6 +1472,8 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] )
   ( "RPR",                                            m_rprEnabledFlag,                          true, "Reference Sample Resolution" )
 #if JVET_AC0196_NNSR
   ( "NnsrOption",                                     m_nnsrOption,                              false, "NN-based super resolution option")
+  ( "NnsrQpOffset",                                   m_nnsrQpOffset,                               5, "Base QP offset for Rate matching")
+  ( "NnsrQpOffsetOverrideByQp",                       m_sNnsrQpOffsetOverrideByQp,         string(""), "Override QP offset for Rate matching")
   ( "ScalingRatioHor",                                m_scalingRatioHor,                          2.0, "Scaling ratio in hor direction" )
   ( "ScalingRatioVer",                                m_scalingRatioVer,                          2.0, "Scaling ratio in ver direction" )
   ( "FractionNumFrames",                              m_fractionOfFrames,                         1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" )
@@ -1678,7 +1680,7 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] )
   const list<const char*>& argv_unhandled = po::scanArgv(opts, argc, (const char**) argv, err);
 
 #if JVET_AC0196_NNSR
-  m_nnsrOption = m_nnsrOption ? true : false;
+  m_nnsrOption = (m_nnsrOption && m_iSourceWidth >= NNSR_ENABLING_WIDTH && m_iSourceHeight >= NNSR_ENABLING_HEIGHT) ? true : false;
   m_scalingRatioHor = m_nnsrOption ? 2.0 : 1.0;
   m_scalingRatioVer = m_nnsrOption ? 2.0 : 1.0;
   m_upscaledOutput = (m_nnsrOption && (m_upscaledOutput != 0)) ? 2 : 0;
diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h
index a4c6a1c58d..eda98c8ff6 100644
--- a/source/App/EncoderApp/EncAppCfg.h
+++ b/source/App/EncoderApp/EncAppCfg.h
@@ -783,6 +783,8 @@ protected:
 
 #if JVET_AC0196_NNSR
   bool        m_nnsrOption;
+  int         m_nnsrQpOffset;
+  std::string m_sNnsrQpOffsetOverrideByQp;
 #endif
 #if JVET_AC0055_NN_POST_FILTERING
   bool m_nnpf;
diff --git a/source/Lib/CommonLib/CommonDef.h b/source/Lib/CommonLib/CommonDef.h
index 73a4b1fd2f..945ae33f84 100644
--- a/source/Lib/CommonLib/CommonDef.h
+++ b/source/Lib/CommonLib/CommonDef.h
@@ -203,6 +203,11 @@ static const int MINIMUM_TID_ENABLING_TEMPORAL_INPUTS =             3; // JVET-A
 #endif
 #endif
 
+#if JVET_AC0196_NNSR
+static const int NNSR_ENABLING_WIDTH     =                       3840;
+static const int NNSR_ENABLING_HEIGHT    =                       2160;
+#endif
+
 #if JVET_AC0055_NN_POST_FILTERING
 static const int MAX_NUM_NN_POST_FILTERS =                          8;
 static const int NNPF_BLOCK_SIZE         =                         64;
diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h
index 9136c0a804..85403fd4ef 100644
--- a/source/Lib/EncoderLib/EncCfg.h
+++ b/source/Lib/EncoderLib/EncCfg.h
@@ -176,7 +176,9 @@ protected:
   std::string m_rdoCnnlfIntraLumaModelNameNNFilter1;          ///< intra luma nnlf set1 model
 #endif
 #if JVET_AC0196_NNSR
-  bool m_nnsrOption;
+  bool        m_nnsrOption;
+  int         m_nnsrQpOffset;
+  std::string m_sNnsrQpOffsetOverridebyQp;
 #endif
   int       m_iFrameRate;
   int       m_FrameSkip;
@@ -914,6 +916,10 @@ public:
 #if JVET_AC0196_NNSR
   bool             getUseNnsr()                                   { return m_nnsrOption;                   };
   void             setUseNnsr(bool b)                             { m_nnsrOption = b;                      };
+  int              getNnsrQpOffset()                              { return m_nnsrQpOffset;                   };
+  void             setNnsrQpOffset(int b)                         { m_nnsrQpOffset = b;                      };
+  std::string      getNnsrQpOffsetOverridebyQp()                  { return m_sNnsrQpOffsetOverridebyQp; };
+  void             setNnsrQpOffsetOverridebyQp(std::string b)     { m_sNnsrQpOffsetOverridebyQp = b; };
 #endif
 
   void setProfile(Profile::Name profile) { m_profile = profile; }
diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp
index edd4984f3d..94c2a9af12 100644
--- a/source/Lib/EncoderLib/EncLib.cpp
+++ b/source/Lib/EncoderLib/EncLib.cpp
@@ -809,7 +809,8 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu
       ppsID = m_iGOPRprPpsId;
     }
 #elif JVET_AC0196_NNSR
-    ppsID = ENC_PPS_ID_RPR;
+    if (m_resChangeInClvsEnabled)
+      ppsID = ENC_PPS_ID_RPR;
 #else
     if( m_resChangeInClvsEnabled && m_intraPeriod == -1 )
     {
@@ -2532,8 +2533,27 @@ int EncCfg::getQPForPicture(const uint32_t gopIndex, const Slice *pSlice) const
 #if ADAPTIVE_RPR
     if (pSlice->getPPS()->getPPSId() == ENC_PPS_ID_RPR)
     {
-      // adjust QP for rate matching. 
-      qp -= 0;
+      // adjust QP for rate matching.
+      int baseQp = qp;
+      constexpr int maxQp = 64;
+      std::array<int, maxQp> qpOffsets;
+      qpOffsets.fill(m_nnsrQpOffset);
+
+      // rate adaptation for AI
+      if (getIntraPeriod() == 1)
+      {
+        std::istringstream iss(m_sNnsrQpOffsetOverridebyQp);
+        int iqp, qpOffset;
+        while (iss >> iqp >> qpOffset)
+        {
+          if (iqp < 0 || iqp >= maxQp)
+          {
+            std::cerr << "Error in m_sNnsrQpOffsetOverridebyQp format" << std::endl;
+          }
+          qpOffsets[iqp] = qpOffset;
+        }
+      }
+      qp -= qpOffsets[baseQp];
     }
 #endif
     
-- 
GitLab