From cce7f95d65ebd5de54b8bad08bef4ac8ad016baa Mon Sep 17 00:00:00 2001
From: Fabrice URBAN <fabrice.urban@interdigital.com>
Date: Wed, 25 Sep 2024 15:06:48 +0000
Subject: [PATCH] Make RdCost members non-`static` to use correct weighting
 tables for each layer in multilayer encoding

---
 source/Lib/CommonLib/RdCost.cpp       | 69 ++++++++++++++-------------
 source/Lib/CommonLib/RdCost.h         | 43 +++++++++--------
 source/Lib/CommonLib/TypeDef.h        | 30 ++++++++----
 source/Lib/EncoderLib/EncCu.cpp       | 14 +++---
 source/Lib/EncoderLib/InterSearch.cpp |  8 ++--
 source/Lib/EncoderLib/IntraSearch.cpp | 20 ++++----
 6 files changed, 102 insertions(+), 82 deletions(-)

diff --git a/source/Lib/CommonLib/RdCost.cpp b/source/Lib/CommonLib/RdCost.cpp
index 52167df43..fcbde85d7 100644
--- a/source/Lib/CommonLib/RdCost.cpp
+++ b/source/Lib/CommonLib/RdCost.cpp
@@ -59,6 +59,7 @@ RdCost::~RdCost()
 }
 
 #if WCG_EXT
+EnumArray<DistFuncWtd, DFuncWtd> RdCost::m_distortionFuncWtd;   // out-of-class definition of the static weighted-function table; its entries are invoked with an explicit RdCost instance
 double RdCost::calcRdCost( uint64_t fracBits, Distortion distortion, bool useUnadjustedLambda )
 #else
 double RdCost::calcRdCost( uint64_t fracBits, Distortion distortion )
@@ -193,14 +194,14 @@ void RdCost::init()
   m_distortionFunc[DFunc::SAD_FULL_NBIT16N] = RdCost::xGetSAD_full;
 
 #if WCG_EXT
-  m_distortionFunc[DFunc::SSE_WTD]    = RdCost::xGetSSE_WTD;
-  m_distortionFunc[DFunc::SSE2_WTD]   = RdCost::xGetSSE2_WTD;
-  m_distortionFunc[DFunc::SSE4_WTD]   = RdCost::xGetSSE4_WTD;
-  m_distortionFunc[DFunc::SSE8_WTD]   = RdCost::xGetSSE8_WTD;
-  m_distortionFunc[DFunc::SSE16_WTD]  = RdCost::xGetSSE16_WTD;
-  m_distortionFunc[DFunc::SSE32_WTD]  = RdCost::xGetSSE32_WTD;
-  m_distortionFunc[DFunc::SSE64_WTD]  = RdCost::xGetSSE64_WTD;
-  m_distortionFunc[DFunc::SSE16N_WTD] = RdCost::xGetSSE16N_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE_WTD]    = &RdCost::xGetSSE_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE2_WTD]   = &RdCost::xGetSSE2_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE4_WTD]   = &RdCost::xGetSSE4_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE8_WTD]   = &RdCost::xGetSSE8_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE16_WTD]  = &RdCost::xGetSSE16_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE32_WTD]  = &RdCost::xGetSSE32_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE64_WTD]  = &RdCost::xGetSSE64_WTD;
+  m_distortionFuncWtd[DFuncWtd::SSE16N_WTD] = &RdCost::xGetSSE16N_WTD;
 #endif
 
   m_distortionFunc[DFunc::SAD_INTERMEDIATE_BITDEPTH] = RdCost::xGetSAD;
@@ -346,11 +347,7 @@ void RdCost::setDistParam(DistParam &rcDP, const Pel *pOrg, const Pel *piRefY, p
 
 #if WCG_EXT
 Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
-                               DFunc distFunc, const CPelBuf *orgLuma)
-#else
-    Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
-                                   DFunc distFunc)
-#endif
+                               DFuncWtd distFuncWtd, const CPelBuf &orgLuma)
 {
   DistParam cDtParam;
 
@@ -360,21 +357,36 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe
   cDtParam.bitDepth   = bitDepth;
   cDtParam.compID     = compID;
 
-#if WCG_EXT
-  if( orgLuma )
+  cDtParam.cShiftX = getComponentScaleX(compID,  m_cf);
+  cDtParam.cShiftY = getComponentScaleY(compID,  m_cf);
+  if (isChroma(compID))
   {
-    cDtParam.cShiftX = getComponentScaleX(compID,  m_cf);
-    cDtParam.cShiftY = getComponentScaleY(compID,  m_cf);
-    if( isChroma(compID) )
-    {
-      cDtParam.orgLuma  = *orgLuma;
-    }
-    else
-    {
-      cDtParam.orgLuma  = org;
-    }
+    cDtParam.orgLuma  = orgLuma;
+  }
+  else
+  {
+    cDtParam.orgLuma  = org;
+  }
+
+  cDtParam.distFuncWtd = m_distortionFuncWtd[distFuncWtd + sizeOffset<false>(org.width)];
+  Distortion dist = cDtParam.distFuncWtd(this, cDtParam);
+  if (isChroma(compID))
+  {
+    dist = (Distortion)(m_distortionWeight[MAP_CHROMA(compID)] * dist);
   }
+  return dist;
+}
 #endif
+Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
+                                   DFunc distFunc)
+{
+  DistParam cDtParam;
+
+  cDtParam.org        = org;
+  cDtParam.cur        = cur;
+  cDtParam.step       = 1;
+  cDtParam.bitDepth   = bitDepth;
+  cDtParam.compID     = compID;
 
   cDtParam.distFunc = m_distortionFunc[distFunc + sizeOffset<false>(org.width)];
 
@@ -3004,13 +3016,6 @@ Distortion RdCost::xGetHADs( const DistParam &rcDtParam )
 
 
 #if WCG_EXT
-uint32_t RdCost::m_signalType   = RESHAPE_SIGNAL_NULL;
-int32_t  RdCost::m_chromaWeight = MSE_WEIGHT_ONE;
-int      RdCost::m_lumaBD       = 10;
-
-std::vector<int32_t> RdCost::m_reshapeLumaLevelToWeightPLUT;
-std::vector<double> RdCost::m_lumaLevelToWeightPLUT;
-
 void RdCost::saveUnadjustedLambda()
 {
   m_dLambda_unadjusted = m_dLambda;
diff --git a/source/Lib/CommonLib/RdCost.h b/source/Lib/CommonLib/RdCost.h
index 8a04014da..51ee41135 100644
--- a/source/Lib/CommonLib/RdCost.h
+++ b/source/Lib/CommonLib/RdCost.h
@@ -59,6 +59,10 @@ class EncCfg;
 
 using DistFunc = std::function<Distortion(const DistParam &)>;
 
+#if WCG_EXT
+class RdCost;   // forward declaration: the alias below names RdCost before the class is defined
+using DistFuncWtd = std::function<Distortion(RdCost *, const DistParam &)>;   // weighted distortion callable; the RdCost instance is passed explicitly as the first argument
+#endif
 // ====================================================================================================================
 // Class definition
 // ====================================================================================================================
@@ -71,6 +75,7 @@ public:
   CPelBuf               cur;
 #if WCG_EXT
   CPelBuf               orgLuma;
+  DistFuncWtd           distFuncWtd;
 #endif
   const Pel*            mask;
   ptrdiff_t             maskStride;
@@ -117,15 +122,16 @@ private:
   bool                    m_isLosslessRDCost;
 
 #if WCG_EXT
+  static EnumArray<DistFuncWtd, DFuncWtd> m_distortionFuncWtd;
   double                  m_dLambda_unadjusted; // TODO: check is necessary
   double                  m_distScaleUnadjusted;
 
-  static std::vector<int32_t> m_reshapeLumaLevelToWeightPLUT;   // scaled by MSE_WEIGHT_ONE
-  static std::vector<double>  m_lumaLevelToWeightPLUT;
+  std::vector<int32_t> m_reshapeLumaLevelToWeightPLUT;   // scaled by MSE_WEIGHT_ONE
+  std::vector<double>  m_lumaLevelToWeightPLUT;
 
-  static int32_t  m_chromaWeight;   // scaled by MSE_WEIGHT_ONE
-  static uint32_t m_signalType;
-  static int      m_lumaBD;
+  int32_t  m_chromaWeight = MSE_WEIGHT_ONE;   // scaled by MSE_WEIGHT_ONE; same default as the static definition removed from RdCost.cpp
+  uint32_t m_signalType = RESHAPE_SIGNAL_NULL;   // same default as the static definition removed from RdCost.cpp
+  int      m_lumaBD = 10;   // same default as the static definition removed from RdCost.cpp
 
   ChromaFormat            m_cf;
 #endif
@@ -382,15 +388,15 @@ private:
   static Distortion xGetSSE16N        ( const DistParam& pcDtParam );
 
 #if WCG_EXT
-  static Distortion getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma);
-  static Distortion xGetSSE_WTD       ( const DistParam& pcDtParam );
-  static Distortion xGetSSE2_WTD      ( const DistParam& pcDtParam );
-  static Distortion xGetSSE4_WTD      ( const DistParam& pcDtParam );
-  static Distortion xGetSSE8_WTD      ( const DistParam& pcDtParam );
-  static Distortion xGetSSE16_WTD     ( const DistParam& pcDtParam );
-  static Distortion xGetSSE32_WTD     ( const DistParam& pcDtParam );
-  static Distortion xGetSSE64_WTD     ( const DistParam& pcDtParam );
-  static Distortion xGetSSE16N_WTD    ( const DistParam& pcDtParam );
+  Distortion getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma);
+  Distortion xGetSSE_WTD       ( const DistParam& pcDtParam );
+  Distortion xGetSSE2_WTD      ( const DistParam& pcDtParam );
+  Distortion xGetSSE4_WTD      ( const DistParam& pcDtParam );
+  Distortion xGetSSE8_WTD      ( const DistParam& pcDtParam );
+  Distortion xGetSSE16_WTD     ( const DistParam& pcDtParam );
+  Distortion xGetSSE32_WTD     ( const DistParam& pcDtParam );
+  Distortion xGetSSE64_WTD     ( const DistParam& pcDtParam );
+  Distortion xGetSSE16N_WTD    ( const DistParam& pcDtParam );
 #endif
 
   static Distortion xGetSAD           ( const DistParam& pcDtParam );
@@ -469,12 +475,11 @@ private:
 public:
 
 #if WCG_EXT
-  Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, DFunc distFunc,
-                         const CPelBuf *orgLuma = nullptr);
-#else
-  Distortion    getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
-                            DFunc distFunc);
+  Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, DFuncWtd distFuncWtd,
+                         const CPelBuf &orgLuma);
 #endif
+  Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
+                         DFunc distFunc);
 
   Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, const Pel *mask, int bitDepth,
                          const ComponentID compID, DFunc distFunc);
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 1685cdd50..c495d24db 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -613,16 +613,6 @@ enum class DFunc
   SAD_FULL_NBIT64,
   SAD_FULL_NBIT16N,
 
-  // Weighted SSE functions by size
-  SSE_WTD,
-  SSE2_WTD,
-  SSE4_WTD,
-  SSE8_WTD,
-  SSE16_WTD,
-  SSE32_WTD,
-  SSE64_WTD,
-  SSE16N_WTD,
-
   SAD_INTERMEDIATE_BITDEPTH,
 
   SAD_WITH_MASK,
@@ -640,6 +630,26 @@ static inline DFunc operator+(const DFunc &a, const DFuncDiff &b)
 {
   return static_cast<DFunc>(to_underlying(a) + to_underlying(b));
 }
+#if WCG_EXT
+enum class DFuncWtd   // selector for the weighted SSE distortion functions; kept separate from DFunc
+{
+  // Weighted SSE functions by size
+  SSE_WTD,      // entry passed by callers; RdCost::getDistPart adds a width-derived offset to it before the table lookup
+  SSE2_WTD,
+  SSE4_WTD,
+  SSE8_WTD,
+  SSE16_WTD,
+  SSE32_WTD,
+  SSE64_WTD,
+  SSE16N_WTD,
+
+  NUM           // presumably the enumerator count used to size EnumArray<DistFuncWtd, DFuncWtd> -- TODO confirm
+};
+static inline DFuncWtd operator+(const DFuncWtd &a, const DFuncDiff &b)   // offsets weighted-function selector a by b
+{
+  return static_cast<DFuncWtd>(to_underlying(a) + to_underlying(b));   // sum of the underlying values; no range check is performed here
+}
+#endif
 
 /// motion vector predictor direction used in AMVP
 enum MvpDir
diff --git a/source/Lib/EncoderLib/EncCu.cpp b/source/Lib/EncoderLib/EncCu.cpp
index 14ca7aaef..50ac07031 100644
--- a/source/Lib/EncoderLib/EncCu.cpp
+++ b/source/Lib/EncoderLib/EncCu.cpp
@@ -4000,12 +4000,12 @@ Distortion EncCu::getDistortionDb( CodingStructure &cs, CPelBuf org, CPelBuf rec
       tmpRecLuma.copyFrom( reco );
       tmpRecLuma.rspSignal( m_pcReshape->getInvLUT() );
       dist += m_pcRdCost->getDistPart(org, tmpRecLuma, cs.sps->getBitDepth(toChannelType(compID)), compID,
-                                      DFunc::SSE_WTD, &orgLuma);
+                                      DFuncWtd::SSE_WTD, orgLuma);
     }
     else
     {
-      dist += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD,
-                                      &orgLuma);
+      dist += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID,
+                                      DFuncWtd::SSE_WTD, orgLuma);
     }
   }
   else if (m_pcEncCfg->getLmcs() && cs.slice->getLmcsEnabledFlag() && cs.slice->isIntra()) //intra slice
@@ -4022,8 +4022,8 @@ Distortion EncCu::getDistortionDb( CodingStructure &cs, CPelBuf org, CPelBuf rec
     {
       if ((isChroma(compID) && m_pcEncCfg->getReshapeIntraCMD()))
       {
-        dist += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD,
-                                        &orgLuma);
+        dist += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID,
+                                        DFuncWtd::SSE_WTD, orgLuma);
       }
       else
       {
@@ -4485,12 +4485,12 @@ void EncCu::xReuseCachedResult( CodingStructure *&tempCS, CodingStructure *&best
             tmpRecLuma.copyFrom(reco);
             tmpRecLuma.rspSignal(m_pcReshape->getInvLUT());
             finalDistortion += m_pcRdCost->getDistPart(org, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID,
-                                                       DFunc::SSE_WTD, &orgLuma);
+                                                       DFuncWtd::SSE_WTD, orgLuma);
           }
           else
           {
             finalDistortion += m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID,
-                                                       DFunc::SSE_WTD, &orgLuma);
+                                                       DFuncWtd::SSE_WTD, orgLuma);
           }
         }
         else
diff --git a/source/Lib/EncoderLib/InterSearch.cpp b/source/Lib/EncoderLib/InterSearch.cpp
index 5dfc73576..6113795b4 100644
--- a/source/Lib/EncoderLib/InterSearch.cpp
+++ b/source/Lib/EncoderLib/InterSearch.cpp
@@ -10622,12 +10622,12 @@ void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &pa
           tmpRecLuma.copyFrom(reco);
           tmpRecLuma.rspSignal(m_pcReshape->getInvLUT());
           distortion += m_pcRdCost->getDistPart(org, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID,
-                                                DFunc::SSE_WTD, &orgLuma);
+                                                DFuncWtd::SSE_WTD, orgLuma);
         }
         else
         {
           distortion += m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID,
-                                                DFunc::SSE_WTD, &orgLuma);
+                                                DFuncWtd::SSE_WTD, orgLuma);
         }
       }
       else
@@ -10968,12 +10968,12 @@ void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &pa
         tmpRecLuma.copyFrom(reco);
         tmpRecLuma.rspSignal(m_pcReshape->getInvLUT());
         finalDistortion += m_pcRdCost->getDistPart(org, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID,
-                                                   DFunc::SSE_WTD, &orgLuma);
+                                                   DFuncWtd::SSE_WTD, orgLuma);
       }
       else
       {
         finalDistortion +=
-          m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD, &orgLuma);
+          m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID, DFuncWtd::SSE_WTD, orgLuma);
       }
     }
     else
diff --git a/source/Lib/EncoderLib/IntraSearch.cpp b/source/Lib/EncoderLib/IntraSearch.cpp
index e63f5eb53..e80019ca0 100644
--- a/source/Lib/EncoderLib/IntraSearch.cpp
+++ b/source/Lib/EncoderLib/IntraSearch.cpp
@@ -1941,12 +1941,12 @@ void IntraSearch::PLTSearch(CodingStructure &cs, Partitioner& partitioner, Compo
         tmpRecLuma.copyFrom(reco);
         tmpRecLuma.rspSignal(m_pcReshape->getInvLUT());
         distortion += m_pcRdCost->getDistPart(org, tmpRecLuma, cs.sps->getBitDepth(toChannelType(compID)), compID,
-                                              DFunc::SSE_WTD, &orgLuma);
+                                              DFuncWtd::SSE_WTD, orgLuma);
       }
       else
       {
         distortion += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID,
-                                              DFunc::SSE_WTD, &orgLuma);
+                                              DFuncWtd::SSE_WTD, orgLuma);
       }
     }
     else
@@ -3542,15 +3542,15 @@ void IntraSearch::xIntraCodingTUBlock(TransformUnit &tu, const ComponentID &comp
       PelBuf   tmpRecLuma = m_tmpStorageCtu.getBuf(tmpArea1);
       tmpRecLuma.copyFrom(piReco);
       tmpRecLuma.rspSignal(m_pcReshape->getInvLUT());
-      dist += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD,
-                                      &orgLuma);
+      dist += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, DFuncWtd::SSE_WTD,
+                                      orgLuma);
     }
     else
     {
-      dist += m_pcRdCost->getDistPart(piOrg, piReco, bitDepth, compID, DFunc::SSE_WTD, &orgLuma);
+      dist += m_pcRdCost->getDistPart(piOrg, piReco, bitDepth, compID, DFuncWtd::SSE_WTD, orgLuma);
       if( jointCbCr )
       {
-        dist += m_pcRdCost->getDistPart(crOrg, crReco, bitDepth, COMPONENT_Cr, DFunc::SSE_WTD, &orgLuma);
+        dist += m_pcRdCost->getDistPart(crOrg, crReco, bitDepth, COMPONENT_Cr, DFuncWtd::SSE_WTD, orgLuma);
       }
     }
   }
@@ -4859,12 +4859,12 @@ bool IntraSearch::xRecurIntraCodingACTQT(CodingStructure &cs, Partitioner &parti
           tmpRecLuma.copyFrom(piReco);
           tmpRecLuma.rspSignal(m_pcReshape->getInvLUT());
           totalDist += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID,
-                                               DFunc::SSE_WTD, &orgLuma);
+                                               DFuncWtd::SSE_WTD, orgLuma);
         }
         else
         {
           totalDist += m_pcRdCost->getDistPart(piOrg, piReco, sps.getBitDepth(toChannelType(compID)), compID,
-                                               DFunc::SSE_WTD, &orgLuma);
+                                               DFuncWtd::SSE_WTD, orgLuma);
         }
       }
       else
@@ -4976,12 +4976,12 @@ bool IntraSearch::xRecurIntraCodingACTQT(CodingStructure &cs, Partitioner &parti
                 tmpRecLuma.copyFrom(piReco);
                 tmpRecLuma.rspSignal(m_pcReshape->getInvLUT());
                 distTmp += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID,
-                                                   DFunc::SSE_WTD, &orgLuma);
+                                                   DFuncWtd::SSE_WTD, orgLuma);
               }
               else
               {
                 distTmp += m_pcRdCost->getDistPart(piOrg, piReco, sps.getBitDepth(toChannelType(compID)), compID,
-                                                   DFunc::SSE_WTD, &orgLuma);
+                                                   DFuncWtd::SSE_WTD, orgLuma);
               }
             }
             else
-- 
GitLab