From cce7f95d65ebd5de54b8bad08bef4ac8ad016baa Mon Sep 17 00:00:00 2001 From: Fabrice URBAN <fabrice.urban@interdigital.com> Date: Wed, 25 Sep 2024 15:06:48 +0000 Subject: [PATCH] Make RdCost members non-`static` to use correct weighting tables for each layer in multilayer encoding --- source/Lib/CommonLib/RdCost.cpp | 69 ++++++++++++++------------- source/Lib/CommonLib/RdCost.h | 43 +++++++++-------- source/Lib/CommonLib/TypeDef.h | 30 ++++++++---- source/Lib/EncoderLib/EncCu.cpp | 14 +++--- source/Lib/EncoderLib/InterSearch.cpp | 8 ++-- source/Lib/EncoderLib/IntraSearch.cpp | 20 ++++---- 6 files changed, 102 insertions(+), 82 deletions(-) diff --git a/source/Lib/CommonLib/RdCost.cpp b/source/Lib/CommonLib/RdCost.cpp index 52167df43..fcbde85d7 100644 --- a/source/Lib/CommonLib/RdCost.cpp +++ b/source/Lib/CommonLib/RdCost.cpp @@ -59,6 +59,7 @@ RdCost::~RdCost() } #if WCG_EXT +EnumArray<DistFuncWtd, DFuncWtd> RdCost::m_distortionFuncWtd; double RdCost::calcRdCost( uint64_t fracBits, Distortion distortion, bool useUnadjustedLambda ) #else double RdCost::calcRdCost( uint64_t fracBits, Distortion distortion ) @@ -193,14 +194,14 @@ void RdCost::init() m_distortionFunc[DFunc::SAD_FULL_NBIT16N] = RdCost::xGetSAD_full; #if WCG_EXT - m_distortionFunc[DFunc::SSE_WTD] = RdCost::xGetSSE_WTD; - m_distortionFunc[DFunc::SSE2_WTD] = RdCost::xGetSSE2_WTD; - m_distortionFunc[DFunc::SSE4_WTD] = RdCost::xGetSSE4_WTD; - m_distortionFunc[DFunc::SSE8_WTD] = RdCost::xGetSSE8_WTD; - m_distortionFunc[DFunc::SSE16_WTD] = RdCost::xGetSSE16_WTD; - m_distortionFunc[DFunc::SSE32_WTD] = RdCost::xGetSSE32_WTD; - m_distortionFunc[DFunc::SSE64_WTD] = RdCost::xGetSSE64_WTD; - m_distortionFunc[DFunc::SSE16N_WTD] = RdCost::xGetSSE16N_WTD; + m_distortionFuncWtd[DFuncWtd::SSE_WTD] = &RdCost::xGetSSE_WTD; + m_distortionFuncWtd[DFuncWtd::SSE2_WTD] = &RdCost::xGetSSE2_WTD; + m_distortionFuncWtd[DFuncWtd::SSE4_WTD] = &RdCost::xGetSSE4_WTD; + m_distortionFuncWtd[DFuncWtd::SSE8_WTD] = &RdCost::xGetSSE8_WTD; + 
m_distortionFuncWtd[DFuncWtd::SSE16_WTD] = &RdCost::xGetSSE16_WTD; + m_distortionFuncWtd[DFuncWtd::SSE32_WTD] = &RdCost::xGetSSE32_WTD; + m_distortionFuncWtd[DFuncWtd::SSE64_WTD] = &RdCost::xGetSSE64_WTD; + m_distortionFuncWtd[DFuncWtd::SSE16N_WTD] = &RdCost::xGetSSE16N_WTD; #endif m_distortionFunc[DFunc::SAD_INTERMEDIATE_BITDEPTH] = RdCost::xGetSAD; @@ -346,11 +347,7 @@ void RdCost::setDistParam(DistParam &rcDP, const Pel *pOrg, const Pel *piRefY, p #if WCG_EXT Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, - DFunc distFunc, const CPelBuf *orgLuma) -#else - Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, - DFunc distFunc) -#endif + DFuncWtd distFuncWtd, const CPelBuf &orgLuma) { DistParam cDtParam; @@ -360,21 +357,36 @@ Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDe cDtParam.bitDepth = bitDepth; cDtParam.compID = compID; -#if WCG_EXT - if( orgLuma ) + cDtParam.cShiftX = getComponentScaleX(compID, m_cf); + cDtParam.cShiftY = getComponentScaleY(compID, m_cf); + if (isChroma(compID)) { - cDtParam.cShiftX = getComponentScaleX(compID, m_cf); - cDtParam.cShiftY = getComponentScaleY(compID, m_cf); - if( isChroma(compID) ) - { - cDtParam.orgLuma = *orgLuma; - } - else - { - cDtParam.orgLuma = org; - } + cDtParam.orgLuma = orgLuma; + } + else + { + cDtParam.orgLuma = org; + } + + cDtParam.distFuncWtd = m_distortionFuncWtd[distFuncWtd + sizeOffset<false>(org.width)]; + Distortion dist = cDtParam.distFuncWtd(this, cDtParam); + if (isChroma(compID)) + { + dist = (Distortion)(m_distortionWeight[MAP_CHROMA(compID)] * dist); } + return dist; +} #endif +Distortion RdCost::getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, + DFunc distFunc) +{ + DistParam cDtParam; + + cDtParam.org = org; + cDtParam.cur = cur; + cDtParam.step = 1; + cDtParam.bitDepth = bitDepth; + cDtParam.compID = 
compID; cDtParam.distFunc = m_distortionFunc[distFunc + sizeOffset<false>(org.width)]; @@ -3004,13 +3016,6 @@ Distortion RdCost::xGetHADs( const DistParam &rcDtParam ) #if WCG_EXT -uint32_t RdCost::m_signalType = RESHAPE_SIGNAL_NULL; -int32_t RdCost::m_chromaWeight = MSE_WEIGHT_ONE; -int RdCost::m_lumaBD = 10; - -std::vector<int32_t> RdCost::m_reshapeLumaLevelToWeightPLUT; -std::vector<double> RdCost::m_lumaLevelToWeightPLUT; - void RdCost::saveUnadjustedLambda() { m_dLambda_unadjusted = m_dLambda; diff --git a/source/Lib/CommonLib/RdCost.h b/source/Lib/CommonLib/RdCost.h index 8a04014da..51ee41135 100644 --- a/source/Lib/CommonLib/RdCost.h +++ b/source/Lib/CommonLib/RdCost.h @@ -59,6 +59,10 @@ class EncCfg; using DistFunc = std::function<Distortion(const DistParam &)>; +#if WCG_EXT +class RdCost; +using DistFuncWtd = std::function<Distortion(RdCost *, const DistParam &)>; +#endif // ==================================================================================================================== // Class definition // ==================================================================================================================== @@ -71,6 +75,7 @@ public: CPelBuf cur; #if WCG_EXT CPelBuf orgLuma; + DistFuncWtd distFuncWtd; #endif const Pel* mask; ptrdiff_t maskStride; @@ -117,15 +122,16 @@ private: bool m_isLosslessRDCost; #if WCG_EXT + static EnumArray<DistFuncWtd, DFuncWtd> m_distortionFuncWtd; double m_dLambda_unadjusted; // TODO: check is necessary double m_distScaleUnadjusted; - static std::vector<int32_t> m_reshapeLumaLevelToWeightPLUT; // scaled by MSE_WEIGHT_ONE - static std::vector<double> m_lumaLevelToWeightPLUT; + std::vector<int32_t> m_reshapeLumaLevelToWeightPLUT; // scaled by MSE_WEIGHT_ONE + std::vector<double> m_lumaLevelToWeightPLUT; - static int32_t m_chromaWeight; // scaled by MSE_WEIGHT_ONE - static uint32_t m_signalType; - static int m_lumaBD; + int32_t m_chromaWeight = MSE_WEIGHT_ONE; // scaled by MSE_WEIGHT_ONE + uint32_t 
m_signalType = RESHAPE_SIGNAL_NULL; + int m_lumaBD = 10; ChromaFormat m_cf; #endif @@ -382,15 +388,15 @@ private: static Distortion xGetSSE16N ( const DistParam& pcDtParam ); #if WCG_EXT - static Distortion getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma); - static Distortion xGetSSE_WTD ( const DistParam& pcDtParam ); - static Distortion xGetSSE2_WTD ( const DistParam& pcDtParam ); - static Distortion xGetSSE4_WTD ( const DistParam& pcDtParam ); - static Distortion xGetSSE8_WTD ( const DistParam& pcDtParam ); - static Distortion xGetSSE16_WTD ( const DistParam& pcDtParam ); - static Distortion xGetSSE32_WTD ( const DistParam& pcDtParam ); - static Distortion xGetSSE64_WTD ( const DistParam& pcDtParam ); - static Distortion xGetSSE16N_WTD ( const DistParam& pcDtParam ); + Distortion getWeightedMSE(int compIdx, const Pel org, const Pel cur, const uint32_t shift, const Pel orgLuma); + Distortion xGetSSE_WTD ( const DistParam& pcDtParam ); + Distortion xGetSSE2_WTD ( const DistParam& pcDtParam ); + Distortion xGetSSE4_WTD ( const DistParam& pcDtParam ); + Distortion xGetSSE8_WTD ( const DistParam& pcDtParam ); + Distortion xGetSSE16_WTD ( const DistParam& pcDtParam ); + Distortion xGetSSE32_WTD ( const DistParam& pcDtParam ); + Distortion xGetSSE64_WTD ( const DistParam& pcDtParam ); + Distortion xGetSSE16N_WTD ( const DistParam& pcDtParam ); #endif static Distortion xGetSAD ( const DistParam& pcDtParam ); @@ -469,12 +475,11 @@ private: public: #if WCG_EXT - Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, DFunc distFunc, - const CPelBuf *orgLuma = nullptr); -#else - Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, - DFunc distFunc); + Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, DFuncWtd distFuncWtd, + const CPelBuf &orgLuma); #endif + Distortion 
getDistPart(const CPelBuf &org, const CPelBuf &cur, int bitDepth, const ComponentID compID, + DFunc distFunc); Distortion getDistPart(const CPelBuf &org, const CPelBuf &cur, const Pel *mask, int bitDepth, const ComponentID compID, DFunc distFunc); diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 1685cdd50..c495d24db 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -613,16 +613,6 @@ enum class DFunc SAD_FULL_NBIT64, SAD_FULL_NBIT16N, - // Weighted SSE functions by size - SSE_WTD, - SSE2_WTD, - SSE4_WTD, - SSE8_WTD, - SSE16_WTD, - SSE32_WTD, - SSE64_WTD, - SSE16N_WTD, - SAD_INTERMEDIATE_BITDEPTH, SAD_WITH_MASK, @@ -640,6 +630,26 @@ static inline DFunc operator+(const DFunc &a, const DFuncDiff &b) { return static_cast<DFunc>(to_underlying(a) + to_underlying(b)); } +#if WCG_EXT +enum class DFuncWtd +{ + // Weighted SSE functions by size + SSE_WTD, + SSE2_WTD, + SSE4_WTD, + SSE8_WTD, + SSE16_WTD, + SSE32_WTD, + SSE64_WTD, + SSE16N_WTD, + + NUM +}; +static inline DFuncWtd operator+(const DFuncWtd &a, const DFuncDiff &b) +{ + return static_cast<DFuncWtd>(to_underlying(a) + to_underlying(b)); +} +#endif /// motion vector predictor direction used in AMVP enum MvpDir diff --git a/source/Lib/EncoderLib/EncCu.cpp b/source/Lib/EncoderLib/EncCu.cpp index 14ca7aaef..50ac07031 100644 --- a/source/Lib/EncoderLib/EncCu.cpp +++ b/source/Lib/EncoderLib/EncCu.cpp @@ -4000,12 +4000,12 @@ Distortion EncCu::getDistortionDb( CodingStructure &cs, CPelBuf org, CPelBuf rec tmpRecLuma.copyFrom( reco ); tmpRecLuma.rspSignal( m_pcReshape->getInvLUT() ); dist += m_pcRdCost->getDistPart(org, tmpRecLuma, cs.sps->getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } else { - dist += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD, - &orgLuma); + dist += m_pcRdCost->getDistPart(org, reco, 
cs.sps->getBitDepth(toChannelType(compID)), compID, + DFuncWtd::SSE_WTD, orgLuma); } } else if (m_pcEncCfg->getLmcs() && cs.slice->getLmcsEnabledFlag() && cs.slice->isIntra()) //intra slice @@ -4022,8 +4022,8 @@ Distortion EncCu::getDistortionDb( CodingStructure &cs, CPelBuf org, CPelBuf rec { if ((isChroma(compID) && m_pcEncCfg->getReshapeIntraCMD())) { - dist += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD, - &orgLuma); + dist += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID, + DFuncWtd::SSE_WTD, orgLuma); } else { @@ -4485,12 +4485,12 @@ void EncCu::xReuseCachedResult( CodingStructure *&tempCS, CodingStructure *&best tmpRecLuma.copyFrom(reco); tmpRecLuma.rspSignal(m_pcReshape->getInvLUT()); finalDistortion += m_pcRdCost->getDistPart(org, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } else { finalDistortion += m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } } else diff --git a/source/Lib/EncoderLib/InterSearch.cpp b/source/Lib/EncoderLib/InterSearch.cpp index 5dfc73576..6113795b4 100644 --- a/source/Lib/EncoderLib/InterSearch.cpp +++ b/source/Lib/EncoderLib/InterSearch.cpp @@ -10622,12 +10622,12 @@ void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &pa tmpRecLuma.copyFrom(reco); tmpRecLuma.rspSignal(m_pcReshape->getInvLUT()); distortion += m_pcRdCost->getDistPart(org, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } else { distortion += m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } } else @@ -10968,12 +10968,12 @@ void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &pa 
tmpRecLuma.copyFrom(reco); tmpRecLuma.rspSignal(m_pcReshape->getInvLUT()); finalDistortion += m_pcRdCost->getDistPart(org, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } else { finalDistortion += - m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD, &orgLuma); + m_pcRdCost->getDistPart(org, reco, sps.getBitDepth(toChannelType(compID)), compID, DFuncWtd::SSE_WTD, orgLuma); } } else diff --git a/source/Lib/EncoderLib/IntraSearch.cpp b/source/Lib/EncoderLib/IntraSearch.cpp index e63f5eb53..e80019ca0 100644 --- a/source/Lib/EncoderLib/IntraSearch.cpp +++ b/source/Lib/EncoderLib/IntraSearch.cpp @@ -1941,12 +1941,12 @@ void IntraSearch::PLTSearch(CodingStructure &cs, Partitioner& partitioner, Compo tmpRecLuma.copyFrom(reco); tmpRecLuma.rspSignal(m_pcReshape->getInvLUT()); distortion += m_pcRdCost->getDistPart(org, tmpRecLuma, cs.sps->getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } else { distortion += m_pcRdCost->getDistPart(org, reco, cs.sps->getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } } else @@ -3542,15 +3542,15 @@ void IntraSearch::xIntraCodingTUBlock(TransformUnit &tu, const ComponentID &comp PelBuf tmpRecLuma = m_tmpStorageCtu.getBuf(tmpArea1); tmpRecLuma.copyFrom(piReco); tmpRecLuma.rspSignal(m_pcReshape->getInvLUT()); - dist += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, DFunc::SSE_WTD, - &orgLuma); + dist += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, DFuncWtd::SSE_WTD, + orgLuma); } else { - dist += m_pcRdCost->getDistPart(piOrg, piReco, bitDepth, compID, DFunc::SSE_WTD, &orgLuma); + dist += m_pcRdCost->getDistPart(piOrg, piReco, bitDepth, compID, DFuncWtd::SSE_WTD, orgLuma); if( jointCbCr ) { - dist += 
m_pcRdCost->getDistPart(crOrg, crReco, bitDepth, COMPONENT_Cr, DFunc::SSE_WTD, &orgLuma); + dist += m_pcRdCost->getDistPart(crOrg, crReco, bitDepth, COMPONENT_Cr, DFuncWtd::SSE_WTD, orgLuma); } } } @@ -4859,12 +4859,12 @@ bool IntraSearch::xRecurIntraCodingACTQT(CodingStructure &cs, Partitioner &parti tmpRecLuma.copyFrom(piReco); tmpRecLuma.rspSignal(m_pcReshape->getInvLUT()); totalDist += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } else { totalDist += m_pcRdCost->getDistPart(piOrg, piReco, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } } else @@ -4976,12 +4976,12 @@ bool IntraSearch::xRecurIntraCodingACTQT(CodingStructure &cs, Partitioner &parti tmpRecLuma.copyFrom(piReco); tmpRecLuma.rspSignal(m_pcReshape->getInvLUT()); distTmp += m_pcRdCost->getDistPart(piOrg, tmpRecLuma, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } else { distTmp += m_pcRdCost->getDistPart(piOrg, piReco, sps.getBitDepth(toChannelType(compID)), compID, - DFunc::SSE_WTD, &orgLuma); + DFuncWtd::SSE_WTD, orgLuma); } } else -- GitLab