diff --git a/source/Lib/CommonLib/RdCost.cpp b/source/Lib/CommonLib/RdCost.cpp
index fcbde85d71bb649a26d594e91d7577a10c9fd7be..11457317855c3151c7bfe0b98669be05ee0f11b1 100644
--- a/source/Lib/CommonLib/RdCost.cpp
+++ b/source/Lib/CommonLib/RdCost.cpp
@@ -347,7 +347,7 @@ void RdCost::setDistParam(DistParam &rcDP, const Pel *pOrg, const Pel *piRefY, p
 
 #if WCG_EXT
 Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
-                               DFuncWtd distFuncWtd, const CPelBuf &orgLuma)
+                               DFuncWtd distFuncWtd, const CPelBuf &orgLuma) const
 {
   DistParam cDtParam;
 
@@ -357,8 +357,6 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe
   cDtParam.bitDepth   = bitDepth;
   cDtParam.compID     = compID;
 
-  cDtParam.cShiftX = getComponentScaleX(compID,  m_cf);
-  cDtParam.cShiftY = getComponentScaleY(compID,  m_cf);
   if (isChroma(compID))
   {
     cDtParam.orgLuma  = orgLuma;
@@ -368,8 +366,20 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe
     cDtParam.orgLuma  = org;
   }
 
-  cDtParam.distFuncWtd = m_distortionFuncWtd[distFuncWtd + sizeOffset<false>(org.width)];
-  Distortion dist = cDtParam.distFuncWtd(this, cDtParam);
+  Distortion dist;
+  if (isChroma(compID) && (m_signalType == RESHAPE_SIGNAL_SDR || m_signalType == RESHAPE_SIGNAL_HLG))
+  {
+       cDtParam.distFunc = m_distortionFunc[DFunc::SSE + sizeOffset<false>(org.width)];
+       int64_t weight = m_chromaWeight;
+       dist = (weight * cDtParam.distFunc(cDtParam ) + (1 << MSE_WEIGHT_FRAC_BITS >> 1)) >> (MSE_WEIGHT_FRAC_BITS);
+  }
+  else
+  {
+    cDtParam.cShiftX = getComponentScaleX(compID,  m_cf);
+    cDtParam.cShiftY = getComponentScaleY(compID,  m_cf);
+    cDtParam.distFuncWtd = m_distortionFuncWtd[distFuncWtd + sizeOffset<false>(org.width)];
+    dist = cDtParam.distFuncWtd(this, cDtParam);
+  }
   if (isChroma(compID))
   {
     dist = (Distortion)(m_distortionWeight[MAP_CHROMA(compID)] * dist);
@@ -378,7 +388,7 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe
 }
 #endif
 Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
-                                   DFunc distFunc)
+                                   DFunc distFunc) const
 {
   DistParam cDtParam;
 
@@ -3131,7 +3141,7 @@ void RdCost::updateReshapeLumaLevelToWeightTable(SliceReshapeInfo &sliceReshape,
   }
 }
 
-Distortion RdCost::getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma)
+Distortion RdCost::getWeightedMSE(const int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma) const
 {
   CHECKD(org < 0, "Sample value must be positive");
 
@@ -3143,27 +3153,12 @@ Distortion RdCost::getWeightedMSE(int compIdx, const Pel org, const Pel cur, con
   Pel diff = org - cur;
 
   // use luma to get weight
-  int64_t weight = MSE_WEIGHT_ONE;
-  if (m_signalType == RESHAPE_SIGNAL_SDR || m_signalType == RESHAPE_SIGNAL_HLG)
-  {
-    if (compIdx == COMPONENT_Y)
-    {
-      weight = m_reshapeLumaLevelToWeightPLUT[orgLuma];
-    }
-    else
-    {
-      weight = m_chromaWeight;
-    }
-  }
-  else
-  {
-    weight = m_reshapeLumaLevelToWeightPLUT[orgLuma];
-  }
+  int64_t weight = m_reshapeLumaLevelToWeightPLUT[orgLuma];
 
   return (weight * (diff * diff) + (1 << MSE_WEIGHT_FRAC_BITS >> 1)) >> (MSE_WEIGHT_FRAC_BITS + shift);
 }
 
-Distortion RdCost::xGetSSE_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
@@ -3196,7 +3191,7 @@ Distortion RdCost::xGetSSE_WTD( const DistParam &rcDtParam )
   return (sum);
 }
 
-Distortion RdCost::xGetSSE2_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE2_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
@@ -3231,7 +3226,7 @@ Distortion RdCost::xGetSSE2_WTD( const DistParam &rcDtParam )
   return (sum);
 }
 
-Distortion RdCost::xGetSSE4_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE4_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
@@ -3272,7 +3267,7 @@ Distortion RdCost::xGetSSE4_WTD( const DistParam &rcDtParam )
   return (sum);
 }
 
-Distortion RdCost::xGetSSE8_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE8_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
@@ -3324,7 +3319,7 @@ Distortion RdCost::xGetSSE8_WTD( const DistParam &rcDtParam )
   return (sum);
 }
 
-Distortion RdCost::xGetSSE16_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE16_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
@@ -3399,7 +3394,7 @@ Distortion RdCost::xGetSSE16_WTD( const DistParam &rcDtParam )
   return (sum);
 }
 
-Distortion RdCost::xGetSSE16N_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE16N_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
@@ -3477,7 +3472,7 @@ Distortion RdCost::xGetSSE16N_WTD( const DistParam &rcDtParam )
   return (sum);
 }
 
-Distortion RdCost::xGetSSE32_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE32_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
@@ -3601,7 +3596,7 @@ Distortion RdCost::xGetSSE32_WTD( const DistParam &rcDtParam )
   return (sum);
 }
 
-Distortion RdCost::xGetSSE64_WTD( const DistParam &rcDtParam )
+Distortion RdCost::xGetSSE64_WTD( const DistParam &rcDtParam ) const
 {
   if( rcDtParam.applyWeight )
   {
diff --git a/source/Lib/CommonLib/RdCost.h b/source/Lib/CommonLib/RdCost.h
index 51ee4113532f64c3e5c366599106ce02bf34a759..f111375844b10c44e86102f9d5269808176bd741 100644
--- a/source/Lib/CommonLib/RdCost.h
+++ b/source/Lib/CommonLib/RdCost.h
@@ -61,7 +61,7 @@ using DistFunc = std::function<Distortion(const DistParam &)>;
 
 #if WCG_EXT
 class RdCost;
-using DistFuncWtd = std::function<Distortion(RdCost *, const DistParam &)>;
+using DistFuncWtd = std::function<Distortion(const RdCost *, const DistParam &)>;
 #endif
 // ====================================================================================================================
 // Class definition
@@ -388,15 +388,15 @@ private:
   static Distortion xGetSSE16N        ( const DistParam& pcDtParam );
 
 #if WCG_EXT
-  Distortion getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma);
-  Distortion xGetSSE_WTD       ( const DistParam& pcDtParam );
-  Distortion xGetSSE2_WTD      ( const DistParam& pcDtParam );
-  Distortion xGetSSE4_WTD      ( const DistParam& pcDtParam );
-  Distortion xGetSSE8_WTD      ( const DistParam& pcDtParam );
-  Distortion xGetSSE16_WTD     ( const DistParam& pcDtParam );
-  Distortion xGetSSE32_WTD     ( const DistParam& pcDtParam );
-  Distortion xGetSSE64_WTD     ( const DistParam& pcDtParam );
-  Distortion xGetSSE16N_WTD    ( const DistParam& pcDtParam );
+  Distortion getWeightedMSE(const int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma) const;
+  Distortion xGetSSE_WTD       ( const DistParam& pcDtParam ) const;
+  Distortion xGetSSE2_WTD      ( const DistParam& pcDtParam ) const;
+  Distortion xGetSSE4_WTD      ( const DistParam& pcDtParam ) const;
+  Distortion xGetSSE8_WTD      ( const DistParam& pcDtParam ) const;
+  Distortion xGetSSE16_WTD     ( const DistParam& pcDtParam ) const;
+  Distortion xGetSSE32_WTD     ( const DistParam& pcDtParam ) const;
+  Distortion xGetSSE64_WTD     ( const DistParam& pcDtParam ) const;
+  Distortion xGetSSE16N_WTD    ( const DistParam& pcDtParam ) const;
 #endif
 
   static Distortion xGetSAD           ( const DistParam& pcDtParam );
@@ -476,10 +476,10 @@ public:
 
 #if WCG_EXT
   Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, DFuncWtd distFuncWtd,
-                         const CPelBuf &orgLuma);
+                         const CPelBuf &orgLuma) const;
 #endif
   Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID,
-                         DFunc distFunc);
+                         DFunc distFunc) const;
 
   Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, const Pel *mask, int bitDepth,
                          const ComponentID compID, DFunc distFunc);