diff --git a/source/Lib/CommonLib/RdCost.cpp b/source/Lib/CommonLib/RdCost.cpp index fcbde85d71bb649a26d594e91d7577a10c9fd7be..11457317855c3151c7bfe0b98669be05ee0f11b1 100644 --- a/source/Lib/CommonLib/RdCost.cpp +++ b/source/Lib/CommonLib/RdCost.cpp @@ -347,7 +347,7 @@ void RdCost::setDistParam(DistParam &rcDP, const Pel *pOrg, const Pel *piRefY, p #if WCG_EXT Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, - DFuncWtd distFuncWtd, const CPelBuf &orgLuma) + DFuncWtd distFuncWtd, const CPelBuf &orgLuma) const { DistParam cDtParam; @@ -357,8 +357,6 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe cDtParam.bitDepth = bitDepth; cDtParam.compID = compID; - cDtParam.cShiftX = getComponentScaleX(compID, m_cf); - cDtParam.cShiftY = getComponentScaleY(compID, m_cf); if (isChroma(compID)) { cDtParam.orgLuma = orgLuma; @@ -368,8 +366,20 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe cDtParam.orgLuma = org; } - cDtParam.distFuncWtd = m_distortionFuncWtd[distFuncWtd + sizeOffset<false>(org.width)]; - Distortion dist = cDtParam.distFuncWtd(this, cDtParam); + Distortion dist; + if (isChroma(compID) && (m_signalType == RESHAPE_SIGNAL_SDR || m_signalType == RESHAPE_SIGNAL_HLG)) + { + cDtParam.distFunc = m_distortionFunc[DFunc::SSE + sizeOffset<false>(org.width)]; + int64_t weight = m_chromaWeight; + dist = (weight * cDtParam.distFunc(cDtParam ) + (1 << MSE_WEIGHT_FRAC_BITS >> 1)) >> (MSE_WEIGHT_FRAC_BITS); + } + else + { + cDtParam.cShiftX = getComponentScaleX(compID, m_cf); + cDtParam.cShiftY = getComponentScaleY(compID, m_cf); + cDtParam.distFuncWtd = m_distortionFuncWtd[distFuncWtd + sizeOffset<false>(org.width)]; + dist = cDtParam.distFuncWtd(this, cDtParam); + } if (isChroma(compID)) { dist = (Distortion)(m_distortionWeight[MAP_CHROMA(compID)] * dist); @@ -378,7 +388,7 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe } #endif Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, - DFunc distFunc) + DFunc distFunc) const { DistParam cDtParam; @@ -3131,7 +3141,7 @@ void RdCost::updateReshapeLumaLevelToWeightTable(SliceReshapeInfo &sliceReshape, } } -Distortion RdCost::getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma) +Distortion RdCost::getWeightedMSE(const int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma) const { CHECKD(org < 0, "Sample value must be positive"); @@ -3143,27 +3153,12 @@ Distortion RdCost::getWeightedMSE(int compIdx, const Pel org, const Pel cur, con Pel diff = org - cur; // use luma to get weight - int64_t weight = MSE_WEIGHT_ONE; - if (m_signalType == RESHAPE_SIGNAL_SDR || m_signalType == RESHAPE_SIGNAL_HLG) - { - if (compIdx == COMPONENT_Y) - { - weight = m_reshapeLumaLevelToWeightPLUT[orgLuma]; - } - else - { - weight = m_chromaWeight; - } - } - else - { - weight = m_reshapeLumaLevelToWeightPLUT[orgLuma]; - } + int64_t weight = m_reshapeLumaLevelToWeightPLUT[orgLuma]; return (weight * (diff * diff) + (1 << MSE_WEIGHT_FRAC_BITS >> 1)) >> (MSE_WEIGHT_FRAC_BITS + shift); } -Distortion RdCost::xGetSSE_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { @@ -3196,7 +3191,7 @@ Distortion RdCost::xGetSSE_WTD( const DistParam &rcDtParam ) return (sum); } -Distortion RdCost::xGetSSE2_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE2_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { @@ -3231,7 +3226,7 @@ Distortion RdCost::xGetSSE2_WTD( const DistParam &rcDtParam ) return (sum); } -Distortion RdCost::xGetSSE4_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE4_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { @@ -3272,7 +3267,7 @@ Distortion RdCost::xGetSSE4_WTD( const DistParam &rcDtParam ) return (sum); } -Distortion RdCost::xGetSSE8_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE8_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { @@ -3324,7 +3319,7 @@ Distortion RdCost::xGetSSE8_WTD( const DistParam &rcDtParam ) return (sum); } -Distortion RdCost::xGetSSE16_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE16_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { @@ -3399,7 +3394,7 @@ Distortion RdCost::xGetSSE16_WTD( const DistParam &rcDtParam ) return (sum); } -Distortion RdCost::xGetSSE16N_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE16N_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { @@ -3477,7 +3472,7 @@ Distortion RdCost::xGetSSE16N_WTD( const DistParam &rcDtParam ) return (sum); } -Distortion RdCost::xGetSSE32_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE32_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { @@ -3601,7 +3596,7 @@ Distortion RdCost::xGetSSE32_WTD( const DistParam &rcDtParam ) return (sum); } -Distortion RdCost::xGetSSE64_WTD( const DistParam &rcDtParam ) +Distortion RdCost::xGetSSE64_WTD( const DistParam &rcDtParam ) const { if( rcDtParam.applyWeight ) { diff --git a/source/Lib/CommonLib/RdCost.h b/source/Lib/CommonLib/RdCost.h index 51ee4113532f64c3e5c366599106ce02bf34a759..f111375844b10c44e86102f9d5269808176bd741 100644 --- a/source/Lib/CommonLib/RdCost.h +++ b/source/Lib/CommonLib/RdCost.h @@ -61,7 +61,7 @@ using DistFunc = std::function<Distortion(const DistParam &)>; #if WCG_EXT class RdCost; -using DistFuncWtd = std::function<Distortion(RdCost *, const DistParam &)>; +using DistFuncWtd = std::function<Distortion(const RdCost *, const DistParam &)>; #endif // ==================================================================================================================== // Class definition @@ -388,15 +388,15 @@ private: static Distortion xGetSSE16N ( const DistParam& pcDtParam ); #if WCG_EXT - Distortion getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma); - Distortion xGetSSE_WTD ( const DistParam& pcDtParam ); - Distortion xGetSSE2_WTD ( const DistParam& pcDtParam ); - Distortion xGetSSE4_WTD ( const DistParam& pcDtParam ); - Distortion xGetSSE8_WTD ( const DistParam& pcDtParam ); - Distortion xGetSSE16_WTD ( const DistParam& pcDtParam ); - Distortion xGetSSE32_WTD ( const DistParam& pcDtParam ); - Distortion xGetSSE64_WTD ( const DistParam& pcDtParam ); - Distortion xGetSSE16N_WTD ( const DistParam& pcDtParam ); + Distortion getWeightedMSE(const int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma) const; + Distortion xGetSSE_WTD ( const DistParam& pcDtParam ) const; + Distortion xGetSSE2_WTD ( const DistParam& pcDtParam ) const; + Distortion xGetSSE4_WTD ( const DistParam& pcDtParam ) const; + Distortion xGetSSE8_WTD ( const DistParam& pcDtParam ) const; + Distortion xGetSSE16_WTD ( const DistParam& pcDtParam ) const; + Distortion xGetSSE32_WTD ( const DistParam& pcDtParam ) const; + Distortion xGetSSE64_WTD ( const DistParam& pcDtParam ) const; + Distortion xGetSSE16N_WTD ( const DistParam& pcDtParam ) const; #endif static Distortion xGetSAD ( const DistParam& pcDtParam ); @@ -476,10 +476,10 @@ public: #if WCG_EXT Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, DFuncWtd distFuncWtd, - const CPelBuf &orgLuma); + const CPelBuf &orgLuma) const; #endif Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, - DFunc distFunc); + DFunc distFunc) const; Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, const Pel *mask, int bitDepth, const ComponentID compID, DFunc distFunc);