diff --git a/README.md b/README.md index 0145c06178e1cb1d7dbf76a7fc51860459c95d32..976f54484543dd8c7233a4cc14ec91a12dd0b080 100644 --- a/README.md +++ b/README.md @@ -399,3 +399,11 @@ To specify model paths, use e.g. following command lines. Note that model paths --NnlfSet1IntraLumaModel="models/NnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl" --NnlfSet1IntraChromaModel="models/NnlfSet1_ChromaCNNFilter_IntraSlice_int16.sadl" + +NN-based loop filter encoder optimization +---------------------------------------------- +To activate the NN-based loop filter encoder optimization, use --EncNnlfOpt=1. Note that this encoder optimization is disabled by default. + +To specify the model paths, use e.g. the following command-line options. Note that the model paths should be specified at the encoder. +--RdoCnnlfInterLumaModel="models/RdNnlfSet1_LumaCNNFilter_InterSlice_int16.sadl" +--RdoCnnlfIntraLumaModel="models/RdNnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl" diff --git a/models/RdNnlfSet1_LumaCNNFilter_InterSlice_int16.sadl b/models/RdNnlfSet1_LumaCNNFilter_InterSlice_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..92c6d6541aa00117a042dc4394c021b77dae3af8 Binary files /dev/null and b/models/RdNnlfSet1_LumaCNNFilter_InterSlice_int16.sadl differ diff --git a/models/RdNnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl b/models/RdNnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..f4bee004ef8bdd5740f7850306efa65640d272ec Binary files /dev/null and b/models/RdNnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl differ diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp index b34066ca6f2ddd8594ed3df313c3c0d256830be8..69d6bcf2798d1ddce51bc0c868c6b7ca6fd5fa0c 100644 --- a/source/App/EncoderApp/EncApp.cpp +++ b/source/App/EncoderApp/EncApp.cpp @@ -248,6 +248,13 @@ void EncApp::xInitLibCfg() m_cEncLib.setNnlfSet1InterChromaModelName (m_nnlfSet1InterChromaModelName); m_cEncLib.setNnlfSet1IntraLumaModelName 
(m_nnlfSet1IntraLumaModelName); m_cEncLib.setNnlfSet1IntraChromaModelName (m_nnlfSet1IntraChromaModelName); +#endif +#if JVET_AB0068_RD + m_cEncLib.setUseEncNnlfOpt (m_encNnlfOpt); + m_cEncLib.setRdoCnnlfInterLumaModelName (m_rdoCnnlfInterLumaModelName); + m_cEncLib.setRdoCnnlfInterChromaModelName (m_rdoCnnlfInterChromaModelName); + m_cEncLib.setRdoCnnlfIntraLumaModelName (m_rdoCnnlfIntraLumaModelName); + m_cEncLib.setRdoCnnlfIntraChromaModelName (m_rdoCnnlfIntraChromaModelName); #endif m_cEncLib.setProfile ( m_profile); m_cEncLib.setLevel ( m_levelTier, m_level); diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp index 14987766a1876dbe36bb2353b982f3432ff19981..875d6934f37ddbcea01329f765e9c5091cbcd9d8 100644 --- a/source/App/EncoderApp/EncAppCfg.cpp +++ b/source/App/EncoderApp/EncAppCfg.cpp @@ -1429,6 +1429,13 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) ( "NnlfSet1InterChromaModel", m_nnlfSet1InterChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter chroma model name") ( "NnlfSet1IntraLumaModel", m_nnlfSet1IntraLumaModelName, string("models/NnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra luma model name") ( "NnlfSet1IntraChromaModel", m_nnlfSet1IntraChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra chroma model name") +#endif +#if JVET_AB0068_RD + ( "EncNnlfOpt", m_encNnlfOpt, false, "Encoder optimization with NN-based loop filter") + ( "RdoCnnlfInterLumaModel", m_rdoCnnlfInterLumaModelName, string("models/RdNnlfSet1_LumaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter luma model name for RDO") + ( "RdoCnnlfInterChromaModel", m_rdoCnnlfInterChromaModelName, string("models/RdNnlfSet1_ChromaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter chroma model name for RDO") + ( "RdoCnnlfIntraLumaModel", m_rdoCnnlfIntraLumaModelName, string("models/RdNnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra luma 
model name for RDO") + ( "RdoCnnlfIntraChromaModel", m_rdoCnnlfIntraChromaModelName, string("models/RdNnlfSet1_ChromaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra chroma model name for RDO") #endif ( "RPR", m_rprEnabledFlag, true, "Reference Sample Resolution" ) ( "ScalingRatioHor", m_scalingRatioHor, 1.0, "Scaling ratio in hor direction" ) @@ -4039,6 +4046,9 @@ void EncAppCfg::xPrintParameter() #endif #if NN_FILTERING_SET_1 msg( VERBOSE, "NNLFSET1:%d ", (m_nnlfSet1)?(1):(0)); +#endif +#if JVET_AB0068_RD + msg( VERBOSE, "EncNnlfOpt:%d ", m_encNnlfOpt ? 1 : 0); #endif msg( VERBOSE, "SAO:%d ", (m_bUseSAO)?(1):(0)); msg( VERBOSE, "ALF:%d ", m_alf ? 1 : 0 ); diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h index f914e22191446c72e4657d9bbe7b69ef74414a0a..62c09f67e019d39b28f7702d74a947764165de21 100644 --- a/source/App/EncoderApp/EncAppCfg.h +++ b/source/App/EncoderApp/EncAppCfg.h @@ -95,6 +95,12 @@ protected: std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlf set1 model std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlf set1 model #endif +#if JVET_AB0068_RD + std::string m_rdoCnnlfInterLumaModelName; ///<inter luma cnnlf model + std::string m_rdoCnnlfInterChromaModelName; ///<inter chroma cnnlf model + std::string m_rdoCnnlfIntraLumaModelName; ///<intra luma cnnlf model + std::string m_rdoCnnlfIntraChromaModelName; ///<inra chroma cnnlf model +#endif // Lambda modifiers double m_adLambdaModifier[ MAX_TLAYER ]; ///< Lambda modifier array for each temporal layer @@ -729,6 +735,10 @@ protected: unsigned m_nnlfSet1MaxNumParams; #endif +#if JVET_AB0068_RD + bool m_encNnlfOpt; +#endif + bool m_rprEnabledFlag; double m_scalingRatioHor; double m_scalingRatioVer; diff --git a/source/Lib/CommonLib/CodingStructure.cpp b/source/Lib/CommonLib/CodingStructure.cpp index d7f0aba8e5de3495c4fe3135b6df27bb4f101378..47eaef7eb94edc240f9d536a513d7a3bf362a3f4 100644 --- a/source/Lib/CommonLib/CodingStructure.cpp +++ 
b/source/Lib/CommonLib/CodingStructure.cpp @@ -1177,6 +1177,10 @@ void CodingStructure::initSubStructure( CodingStructure& subStruct, const Channe subStruct.useDbCost = false; subStruct.costDbOffset = 0; +#if JVET_AB0068_RD + subStruct.useNnCost = false; + subStruct.costNnOffset = 0; +#endif for( uint32_t i = 0; i < subStruct.area.blocks.size(); i++ ) { @@ -1474,6 +1478,10 @@ void CodingStructure::initStructData( const int &QP, const bool &skipMotBuf ) lumaCost = MAX_DOUBLE; costDbOffset = 0; useDbCost = false; +#if JVET_AB0068_RD + costNnOffset = 0; + useNnCost = false; +#endif interHad = std::numeric_limits<Distortion>::max(); } diff --git a/source/Lib/CommonLib/CodingStructure.h b/source/Lib/CommonLib/CodingStructure.h index 8114fb225c03f5ff9b94443932fa6c20e066ad5f..f8fb1d4aa7d9747f7fa866bdeba9c51abdfca8ad 100644 --- a/source/Lib/CommonLib/CodingStructure.h +++ b/source/Lib/CommonLib/CodingStructure.h @@ -178,6 +178,10 @@ public: double cost; bool useDbCost; double costDbOffset; +#if JVET_AB0068_RD + bool useNnCost; + double costNnOffset; +#endif double lumaCost; uint64_t fracBits; Distortion dist; diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 566508253db1e9eec245d119923da7c96d3e1943..4fb0ed3a2c535e2281cb302b9f21694f2aae1278 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -94,6 +94,9 @@ using TypeSadl = float; #define BYPASS_INTER_SLICE 0 // only used for training data generation #endif + +#define JVET_AB0068_RD 1 // JVET-AB0068: EE1-1.6: RDO Considering Deep In-Loop Filtering + //########### place macros to be removed in next cycle below this line ############### #define JVET_V0056 1 // MCTF changes as presented in JVET-V0056 diff --git a/source/Lib/EncoderLib/CMakeLists.txt b/source/Lib/EncoderLib/CMakeLists.txt index 9b3cf90c6434d6ccf5c7ab14be4671dfec7edd6f..a6d6f09437c040617066d5a9af1289683ac970ec 100644 --- a/source/Lib/EncoderLib/CMakeLists.txt +++ 
b/source/Lib/EncoderLib/CMakeLists.txt @@ -61,6 +61,17 @@ if( CMAKE_COMPILER_IS_GNUCC ) # this is quite certainly a compiler problem set_property( SOURCE "EncCu.cpp" APPEND PROPERTY COMPILE_FLAGS "-Wno-array-bounds" ) endif() + +if( MSVC ) + set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ") +elseif( UNIX OR MINGW ) + if( NNLF_BUILD_WITH_AVX512 STREQUAL "1" ) + set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw") + else() + set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2") + endif() +endif() + # example: place header files in different folders source_group( "Natvis Files" FILES ${NATVIS_FILES} ) diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h index 61601cec4d7f67c9f7a3c9252ee26de2a463fae5..feb426be5c382c1975658c85eb35a2affbd2ddd7 100644 --- a/source/Lib/EncoderLib/EncCfg.h +++ b/source/Lib/EncoderLib/EncCfg.h @@ -163,6 +163,13 @@ protected: std::string m_nnlfSet1InterChromaModelName; ///<inter chroma nnlf set1 model std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlf set1 model std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlf set1 model +#endif +#if JVET_AB0068_RD + bool m_encNnlfOpt; + std::string m_rdoCnnlfInterLumaModelName; ///<inter luma cnnlf model + std::string m_rdoCnnlfInterChromaModelName; ///<inter chroma cnnlf model + std::string m_rdoCnnlfIntraLumaModelName; ///<intra luma cnnlf model + std::string m_rdoCnnlfIntraChromaModelName; ///<inra chroma cnnlf model #endif int m_iFrameRate; int m_FrameSkip; @@ -811,6 +818,18 @@ public: void setNnlfSet1IntraLumaModelName(std::string s) { m_nnlfSet1IntraLumaModelName = s; } void setNnlfSet1IntraChromaModelName(std::string s) { m_nnlfSet1IntraChromaModelName = s; } #endif +#if JVET_AB0068_RD + std::string getRdoCnnlfInterLumaModelName() { return m_rdoCnnlfInterLumaModelName; } + std::string 
getRdoCnnlfInterChromaModelName() { return m_rdoCnnlfInterChromaModelName; } + std::string getRdoCnnlfIntraLumaModelName() { return m_rdoCnnlfIntraLumaModelName; } + std::string getRdoCnnlfIntraChromaModelName() { return m_rdoCnnlfIntraChromaModelName; } + void setRdoCnnlfInterLumaModelName(std::string s) { m_rdoCnnlfInterLumaModelName = s; } + void setRdoCnnlfInterChromaModelName(std::string s) { m_rdoCnnlfInterChromaModelName = s; } + void setRdoCnnlfIntraLumaModelName(std::string s) { m_rdoCnnlfIntraLumaModelName = s; } + void setRdoCnnlfIntraChromaModelName(std::string s) { m_rdoCnnlfIntraChromaModelName = s; } + bool getUseEncNnlfOpt() { return m_encNnlfOpt; }; + void setUseEncNnlfOpt(bool b) { m_encNnlfOpt = b; }; +#endif void setProfile(Profile::Name profile) { m_profile = profile; } void setLevel(Level::Tier tier, Level::Name level) { m_levelTier = tier; m_level = level; } diff --git a/source/Lib/EncoderLib/EncCu.cpp b/source/Lib/EncoderLib/EncCu.cpp index b875909a93575deba535718a6210149211092cc9..77f2e285e313a987019023118fd349ffb3730936 100644 --- a/source/Lib/EncoderLib/EncCu.cpp +++ b/source/Lib/EncoderLib/EncCu.cpp @@ -234,6 +234,9 @@ void EncCu::init( EncLib* pcEncLib, const SPS& sps PARL_PARAM( const int tId ) ) #if ENABLE_SPLIT_PARALLELISM m_pcEncLib = pcEncLib; m_dataId = tId; +#endif +#if JVET_AB0068_RD + m_pcCNNFilter = pcEncLib->getGOPEncoder()->getEncCnnFilter(); #endif m_pcLoopFilter = pcEncLib->getLoopFilter(); m_GeoCostList.init(GEO_NUM_PARTITION_MODE, m_pcEncCfg->getMaxNumGeoCand()); @@ -1328,6 +1331,10 @@ void EncCu::xCheckModeSplit(CodingStructure *&tempCS, CodingStructure *&bestCS, tempCS->cost = MAX_DOUBLE; tempCS->costDbOffset = 0; tempCS->useDbCost = m_pcEncCfg->getUseEncDbOpt(); +#if JVET_AB0068_RD + tempCS->costNnOffset = 0; + tempCS->useNnCost = false; +#endif m_CurrCtx--; partitioner.exitCurrSplit(); xCheckBestMode( tempCS, bestCS, partitioner, encTestMode ); @@ -1370,6 +1377,10 @@ void EncCu::xCheckModeSplit(CodingStructure 
*&tempCS, CodingStructure *&bestCS, tempCS->cost = MAX_DOUBLE; tempCS->costDbOffset = 0; tempCS->useDbCost = m_pcEncCfg->getUseEncDbOpt(); +#if JVET_AB0068_RD + tempCS->costNnOffset = 0; + tempCS->useNnCost = false; +#endif m_CurrCtx--; partitioner.exitCurrSplit(); if( partitioner.chType == CHANNEL_TYPE_LUMA ) @@ -1559,6 +1570,85 @@ void EncCu::xCheckModeSplit(CodingStructure *&tempCS, CodingStructure *&bestCS, } } +#if JVET_AB0068_RD + UnitArea unitArea = clipArea(bestCS->area, *bestCS->picture); + + bool useEncNnOpt = false; + if (m_pcCNNFilter->m_NumModelsValid > 0 && partitioner.chType == CHANNEL_TYPE_LUMA && bestCS->cost != MAX_DOUBLE && tempCS->cost != MAX_DOUBLE) + { + int lumaWidth = unitArea.lwidth(); + int lumaHeight = unitArea.lheight(); + + if (lumaWidth <= 64 && lumaHeight <= 64) + { + useEncNnOpt = true; + } + + if (useEncNnOpt) + { + double ratio = 1.005; + if (bestCS->slice->isIntra()) + { + ratio = 1.05; + } + if (m_pcEncCfg->getIntraPeriod() == 1) + { + ratio = 1.01; + } + if ((bestCS->cost > tempCS->cost * ratio) || (tempCS->cost > bestCS->cost * ratio)) + { + useEncNnOpt = false; + } + } + } + + if (useEncNnOpt && (!bestCS->useNnCost || !tempCS->useNnCost)) + { + CodingStructure* nnCS0 = bestCS; + CodingStructure* nnCS1 = tempCS; + + if (bestCS->useNnCost && !tempCS->useNnCost) + { + nnCS0 = bestCS; + nnCS1 = tempCS; + } + else if (!bestCS->useNnCost && tempCS->useNnCost) + { + nnCS0 = tempCS; + nnCS1 = bestCS; + } + else + { + if (bestCS->cost >= tempCS->cost) + { + nnCS0 = bestCS; + nnCS1 = tempCS; + } + else + { + nnCS0 = tempCS; + nnCS1 = bestCS; + } + } + + xCheckCnnlf(*nnCS0, unitArea); + + bool bCheckNnCS1 = true; + if (nnCS0->useNnCost) + { + if ((nnCS0->cost + nnCS0->costNnOffset) > nnCS1->cost) + { + bCheckNnCS1 = false; + } + } + + if (bCheckNnCS1) + { + xCheckCnnlf(*nnCS1, unitArea); + } + } +#endif + // RD check for sub partitioned coding structure. 
xCheckBestMode( tempCS, bestCS, partitioner, encTestMode );
@@ -1578,6 +1668,98 @@ void EncCu::xCheckModeSplit(CodingStructure *&tempCS, CodingStructure *&bestCS,
   tempCS->prevQP[partitioner.chType] = oldPrevQp;
 }
 
+#if JVET_AB0068_RD
+void EncCu::xCheckCnnlf(CodingStructure& cs, UnitArea unitArea) // Runs the RDO CNN luma filter on unitArea of cs and stores the resulting RD-cost change in cs.costNnOffset (also sets cs.useNnCost).
+{
+  if (cs.useNnCost) // this cs has already been evaluated; nothing to do
+  {
+    return;
+  }
+
+  CHECK(cs.costNnOffset != 0, "The costNnOffset of cs should be zero when useNnCost is false");
+
+  Picture* pic = cs.picture;
+
+  Distortion orgDistAll = 0; // accumulated below but not read again in this function
+  Distortion orgDistComp[3] = { 0,0,0 };
+  for (uint32_t comp = 0; comp < 1; comp++) // comp < 1: only component 0 is measured
+  {
+    const ComponentID compID = ComponentID(comp);
+    if (!cs.cus.at(0)->blocks[comp].valid())
+    {
+      continue;
+    }
+    orgDistComp[comp] = getDistortionDb(cs, cs.getOrgBuf(unitArea.blocks[compID]), cs.getRecoBuf(unitArea.blocks[compID]), compID, unitArea.blocks[compID], false); // distortion of the current reconstruction against the original
+    orgDistAll += orgDistComp[comp];
+  }
+
+  CompArea tempLumaArea = unitArea.Y();
+  tempLumaArea.repositionTo(Position(0, 0)); // the filter's scratch buffers are addressed from (0,0), not from the picture position
+
+  PelBuf bufRec   = m_pcCNNFilter->m_tempRecBuf  [COMPONENT_Y].getBuf(tempLumaArea);
+  PelBuf bufPred  = m_pcCNNFilter->m_tempPredBuf [COMPONENT_Y].getBuf(tempLumaArea);
+  PelBuf bufSplit = m_pcCNNFilter->m_tempSplitBuf[COMPONENT_Y].getBuf(tempLumaArea);
+  bufRec.copyFrom(cs.getRecoBuf(unitArea.Y())); // copy the reconstruction into the scratch buffer
+  bufPred.copyFrom(cs.getPredBufCustom(unitArea.Y())); // copy the prediction into the scratch buffer
+
+  bool lmcsFlag = m_pcEncCfg->getLmcs() && cs.sps->getUseLmcs() && cs.slice->getLmcsEnabledFlag() && m_pcReshape->getSliceReshaperInfo().getUseSliceReshaper();
+  if (lmcsFlag)
+  {
+    lmcsFlag = m_pcReshape->getCTUFlag() || cs.slice->isIntra(); // additionally requires the reshaper CTU flag or an intra slice
+  }
+
+  uint64_t cuLength = cs.cus.size();
+  for (uint64_t n = 0; n < cuLength; n++) // per-CU pass over the scratch buffers
+  {
+    CodingUnit* cu = cs.cus.at(n);
+
+    if (!cu->Y().valid())
+    {
+      continue;
+    }
+
+    const Position subCuPos = cu->lumaPos() - unitArea.lumaPos(); // CU position relative to the top-left of unitArea
+    const Size subCuSize = cu->lumaSize();
+    PelBuf subBufRec   = bufRec.  subBuf(subCuPos, subCuSize);
+    PelBuf subBufPred  = bufPred. subBuf(subCuPos, subCuSize);
+    PelBuf subBufSplit = bufSplit.subBuf(subCuPos, subCuSize);
+
+    if (lmcsFlag) // samples are passed through the reshaper's inverse LUT - presumably back to the non-reshaped domain; confirm
+    {
+      if (((cu->predMode == MODE_INTRA || cu->predMode == MODE_IBC) && cu->chType != CHANNEL_TYPE_CHROMA) || (cu->predMode == MODE_INTER && m_pcReshape->getCTUFlag() && cu->firstPU->ciipFlag))
+      {
+        subBufPred.rspSignal(m_pcReshape->getInvLUT());
+      }
+      subBufRec.rspSignal(m_pcReshape->getInvLUT());
+    }
+    subBufSplit.fill(subBufRec.computeAvg()); // the CU's area of the split buffer gets one constant value: subBufRec.computeAvg()
+  }
+
+  for (int modelIdx = 0; modelIdx < min(1, m_pcCNNFilter->m_NumModelsValid); modelIdx++) // min(1, ...): at most one model is evaluated
+  {
+    m_pcCNNFilter->cnnFilterLumaBlockRd_ext(pic, unitArea, 0, 0, 0, 0, modelIdx, pic->slices[0]->getSliceType() != I_SLICE); // all four ext* margins are 0; last argument selects the inter model unless the first slice is an I slice
+
+    PelBuf bufDst = m_pcCNNFilter->m_tempBuf[modelIdx].getBuf(unitArea).get(COMPONENT_Y); // filtered luma written by the call above
+
+    Distortion cnnDistAll = 0; // accumulated below but not read again in this function
+    Distortion cnnDistComp[3] = { 0,0,0 };
+    for (uint32_t comp = 0; comp < 1; comp++)
+    {
+      const ComponentID compID = ComponentID(comp);
+      if (!cs.cus.at(0)->blocks[comp].valid())
+      {
+        continue;
+      }
+      cnnDistComp[comp] = getDistortionDb(cs, cs.getOrgBuf(unitArea.blocks[compID]), bufDst, compID, unitArea.blocks[compID], true); // distortion of the filtered block against the original
+      cnnDistAll += cnnDistComp[comp];
+    }
+
+    cs.useNnCost = true;
+    cs.costNnOffset = min(double(cs.costNnOffset), (m_pcRdCost->calcRdCost(0, cnnDistComp[COMPONENT_Y]) - m_pcRdCost->calcRdCost(0, orgDistComp[COMPONENT_Y]))); // cost(filtered) - cost(unfiltered); min() with the incoming offset (0 per the CHECK above) keeps only non-positive offsets
+  }
+}
+#endif
+
 bool EncCu::xCheckRDCostIntra(CodingStructure *&tempCS, CodingStructure *&bestCS, Partitioner &partitioner, const EncTestMode& encTestMode, bool adaptiveColorTrans)
 {
   double bestInterCost = m_modeCtrl->getBestInterCost();
@@ -4168,7 +4350,11 @@ Distortion EncCu::getDistortionDb( CodingStructure &cs, CPelBuf org, CPelBuf rec
   Distortion dist = 0;
 #if WCG_EXT
   m_pcRdCost->setChromaFormat(cs.sps->getChromaFormatIdc());
+#if JVET_AB0068_RD
+  CPelBuf orgLuma = cs.picture->getOrigBuf(clipArea(cs.area.blocks[COMPONENT_Y], cs.picture->blocks[COMPONENT_Y]));
+#else
   CPelBuf orgLuma = cs.picture->getOrigBuf( cs.area.blocks[COMPONENT_Y] 
); +#endif if (m_pcEncCfg->getLumaLevelToDeltaQPMapping().isEnabled() || ( m_pcEncCfg->getLmcs() && (cs.slice->getLmcsEnabledFlag() && m_pcReshape->getCTUFlag()))) { @@ -4503,6 +4689,9 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS if( bestCost == bestCS->cost ) //The first EMT pass didn't become the bestCS, so we clear the TUs generated { tempCS->clearTUs(); +#if JVET_AB0068_RD + cu = tempCS->getCU(partitioner.chType); +#endif } else if( false == swapped ) { @@ -4526,6 +4715,10 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS tempCS->cost = MAX_DOUBLE; cu->skip = false; +#if JVET_AB0068_RD + tempCS->useNnCost = false; + tempCS->costNnOffset = 0; +#endif //set SBT info cu->setSbtIdx( sbtIdx ); cu->setSbtPos( sbtPos ); diff --git a/source/Lib/EncoderLib/EncCu.h b/source/Lib/EncoderLib/EncCu.h index 07d4848fa5423df234e300b69f652635cbd80326..6cb93403b3418336bd74b75fc3fc36c3008475ac 100644 --- a/source/Lib/EncoderLib/EncCu.h +++ b/source/Lib/EncoderLib/EncCu.h @@ -55,6 +55,9 @@ #include "InterSearch.h" #include "RateCtrl.h" #include "EncModeCtrl.h" +#if JVET_AB0068_RD +#include "EncoderLib/EncNNFilterSet1.h" +#endif //! \ingroup EncoderLib //! 
\{ @@ -181,6 +184,9 @@ private: TrQuant* m_pcTrQuant; RdCost* m_pcRdCost; EncSlice* m_pcSliceEncoder; +#if JVET_AB0068_RD + EncNNFilterSet1* m_pcCNNFilter; +#endif LoopFilter* m_pcLoopFilter; CABACWriter* m_CABACEstimator; @@ -257,6 +263,9 @@ protected: xCheckBestMode ( CodingStructure *&tempCS, CodingStructure *&bestCS, Partitioner &pm, const EncTestMode& encTestmode ); void xCheckModeSplit ( CodingStructure *&tempCS, CodingStructure *&bestCS, Partitioner &pm, const EncTestMode& encTestMode, const ModeType modeTypeParent, bool &skipInterPass ); +#if JVET_AB0068_RD + void xCheckCnnlf ( CodingStructure& cs, UnitArea unitArea ); +#endif bool xCheckRDCostIntra(CodingStructure *&tempCS, CodingStructure *&bestCS, Partitioner &pm, const EncTestMode& encTestMode, bool adaptiveColorTrans); diff --git a/source/Lib/EncoderLib/EncGOP.cpp b/source/Lib/EncoderLib/EncGOP.cpp index b4e2f66721e4a7a21a044faab8f118b596e2e801..d2bd5f5693b434bf0479a1c04700d4093f2c6197 100644 --- a/source/Lib/EncoderLib/EncGOP.cpp +++ b/source/Lib/EncoderLib/EncGOP.cpp @@ -204,6 +204,9 @@ void EncGOP::destroy() } #if NN_FILTERING_SET_1 m_pcNNFilterSet1.destroy(); +#if JVET_AB0068_RD + m_pcNNFilterSet1.destroyEnc(); +#endif #endif } @@ -232,6 +235,12 @@ void EncGOP::init ( EncLib* pcEncLib ) #if NN_FILTERING_SET_1 m_pcNNFilterSet1.create(m_pcCfg->getSourceWidth(), m_pcCfg->getSourceHeight(), m_pcCfg->getChromaFormatIdc(), m_pcCfg->getNnlfSet1MaxNumParams()); +#if JVET_AB0068_RD + if (m_pcCfg->getUseEncNnlfOpt()) + { + m_pcNNFilterSet1.createEnc(m_pcCfg->getMaxCUWidth(), m_pcCfg->getMaxCUHeight(), m_pcCfg->getChromaFormatIdc()); + } +#endif #endif #if WCG_EXT @@ -2820,6 +2829,9 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, { m_pcSliceEncoder->setJointCbCrModes(*pcPic->cs, Position(0, 0), pcPic->cs->area.lumaSize()); } +#if JVET_AB0068_RD + m_pcNNFilterSet1.initEnc(m_pcEncLib->getUseEncNnlfOpt() ? 
1 : 0, m_pcEncLib->getRdoCnnlfInterLumaModelName(), m_pcEncLib->getRdoCnnlfInterChromaModelName(), m_pcEncLib->getRdoCnnlfIntraLumaModelName(), m_pcEncLib->getRdoCnnlfIntraChromaModelName()); +#endif if( encPic ) // now compress (trial encode) the various slice segments (slices, and dependent slices) { diff --git a/source/Lib/EncoderLib/EncGOP.h b/source/Lib/EncoderLib/EncGOP.h index ce428c792e65126f11aa58cfda409b035e2f1018..ecd77be4d79fe5c29df315aa003b6cef1b161e7f 100644 --- a/source/Lib/EncoderLib/EncGOP.h +++ b/source/Lib/EncoderLib/EncGOP.h @@ -250,6 +250,10 @@ public: void setLastLTRefPoc(int iLastLTRefPoc) { m_lastLTRefPoc = iLastLTRefPoc; } int getLastLTRefPoc() const { return m_lastLTRefPoc; } +#if JVET_AB0068_RD + EncNNFilterSet1* getEncCnnFilter() { return &m_pcNNFilterSet1; } +#endif + void printOutSummary( uint32_t uiNumAllPicCoded, bool isField, const bool printMSEBasedSNR, const bool printSequenceMSE, const bool printMSSSIM, const bool printHexPsnr, const bool printRprPSNR, const BitDepths &bitDepths ); #if W0038_DB_OPT diff --git a/source/Lib/EncoderLib/EncModeCtrl.cpp b/source/Lib/EncoderLib/EncModeCtrl.cpp index 85850014726ae9484f59d198da580a868100ac79..8cda4c889972529ed981a2c29ae5b0677f091773 100644 --- a/source/Lib/EncoderLib/EncModeCtrl.cpp +++ b/source/Lib/EncoderLib/EncModeCtrl.cpp @@ -2097,7 +2097,35 @@ bool EncModeCtrlMTnoRQT::useModeResult( const EncTestMode& encTestmode, CodingSt } // for now just a simple decision based on RD-cost or choose tempCS if bestCS is not yet coded +#if JVET_AB0068_RD + bool useTemp = false; + if (tempCS->features[ENC_FT_RD_COST] != MAX_DOUBLE) + { + if (cuECtx.bestCS) + { + double extraTempCost = 0; + double extraBestCost = 0; + if (tempCS->useNnCost && cuECtx.bestCS->useNnCost) + { + extraTempCost = tempCS->costNnOffset; + extraBestCost = cuECtx.bestCS->costNnOffset; + } + else if (tempCS->useDbCost) + { + extraTempCost = tempCS->costDbOffset; + extraBestCost = cuECtx.bestCS->costDbOffset; + } + useTemp = 
((tempCS->features[ENC_FT_RD_COST] + extraTempCost) < (cuECtx.bestCS->features[ENC_FT_RD_COST] + extraBestCost)); + } + else + { + useTemp = true; + } + } + if (useTemp) +#else if( tempCS->features[ENC_FT_RD_COST] != MAX_DOUBLE && ( !cuECtx.bestCS || ( ( tempCS->features[ENC_FT_RD_COST] + ( tempCS->useDbCost ? tempCS->costDbOffset : 0 ) ) < ( cuECtx.bestCS->features[ENC_FT_RD_COST] + ( tempCS->useDbCost ? cuECtx.bestCS->costDbOffset : 0 ) ) ) ) ) +#endif { cuECtx.bestCS = tempCS; cuECtx.bestCU = tempCS->cus[0]; diff --git a/source/Lib/EncoderLib/EncNNFilterSet1.cpp b/source/Lib/EncoderLib/EncNNFilterSet1.cpp index aba35199c5b28fd0edb382bf458c80b4150881e6..72c4b9d77060c971b0b22c7f481d5ef4d53020fb 100644 --- a/source/Lib/EncoderLib/EncNNFilterSet1.cpp +++ b/source/Lib/EncoderLib/EncNNFilterSet1.cpp @@ -50,6 +50,9 @@ #include <ctime> #include <algorithm> #include <numeric> +#include <fstream> +#include <sadl/model.h> +using namespace std; //! \ingroup CommonLib //! \{ @@ -64,6 +67,307 @@ EncNNFilterSet1::EncNNFilterSet1() EncNNFilterSet1::~EncNNFilterSet1() { } + +#if JVET_AB0068_RD +void EncNNFilterSet1::createEnc(const int maxCuWidth, const int maxCuHeight, const ChromaFormat format) +{ + if (m_tempRecBuf[0].getOrigin(0) == NULL) + { + m_tempRecBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); + } + if (!m_tempPredBuf[0].getOrigin(0)) + { + m_tempPredBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); + } + if (!m_tempSplitBuf[0].getOrigin(0)) + { + m_tempSplitBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); + } +} +void EncNNFilterSet1::destroyEnc() +{ + m_tempRecBuf[0].destroy(); + m_tempPredBuf[0].destroy(); + m_tempSplitBuf[0].destroy(); +} +void EncNNFilterSet1::initEnc(int numModels, std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma) +{ + m_NumModelsValid = numModels; + m_interLumaRd = interLuma; + m_interChromaRd = interChroma; + m_intraLumaRd = intraLuma; + m_intraChromaRd = 
intraChroma; +} + +template<typename T> +struct ModelData { + sadl::Model<T> model; + vector<sadl::Tensor<T>> inputs; + int hor,ver; + bool luma,inter; +}; + +template<typename T> +static std::vector<ModelData<T>> initSpace() { + std::vector<ModelData<T>> v; + v.reserve(5); + return v; +} + +#if NN_FIXED_POINT_IMPLEMENTATION +static std::vector<ModelData<int16_t>> modelsRd = initSpace<int16_t>(); +#else +static std::vector<ModelData<float>> modelsRd = initSpace<float>(); +#endif + +template<typename T> +static ModelData<T> &getModelRd(int ver, int hor, bool luma, bool inter, const std::string modelName) { + ModelData<T> *ptr = nullptr; + for(auto &m: modelsRd) + { + if (m.luma == luma && m.inter == inter) + { + ptr = &m; + break; + } + } + if (ptr == nullptr) + { + if (modelsRd.size() == modelsRd.capacity()) + { + std::cout << "[ERROR] RDO increase cache" << std::endl; + exit(-1); + } + modelsRd.resize(modelsRd.size()+1); + ptr = &modelsRd.back(); + ModelData<T> &m = *ptr; + ifstream file(modelName, ios::binary); + m.model.load(file); + m.luma = luma; + m.inter = inter; + m.ver = 0; + m.hor = 0; + } + ModelData<T> &m = *ptr; + if (m.ver != ver || m.hor != hor) + { + m.inputs = m.model.getInputsTemplate(); + if (luma) + { + for(auto &t: m.inputs) + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); + t.resize(dims); + } + } + else if (inter) // inter chroma + { + int inputId = 0; + for(auto &t: m.inputs) + { + if (t.dims()[3] == 1 && inputId != 3) // luma + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); + t.resize(dims); + } + else if (t.dims()[3] == 1 && inputId == 3) // qp + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 1})); + t.resize(dims); + } + else + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 2})); + t.resize(dims); + } + inputId ++; + } + } + else // intra chroma + { + int inputId = 0; + for(auto &t: m.inputs) + { + if (t.dims()[3] == 1 && inputId != 4 ) // 
luma + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); + t.resize(dims); + } + else if (t.dims()[3] == 1 && inputId == 4) // QP + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 1})); + t.resize(dims); + } + else + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 2})); + t.resize(dims); + } + inputId ++; + } + } + if (!m.model.init(m.inputs)) + { + cerr << "[ERROR] RDO issue during initialization" << endl; + exit(-1); + } + m.ver = ver; + m.hor = hor; + } + return m; +} + +template<typename T> +void EncNNFilterSet1::prepareInputsLumaRd (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int qp, bool inter) +{ + double inputScale = 1024; +#if NN_FIXED_POINT_IMPLEMENTATION + int shiftInput = NN_INPUT_PRECISION; +#else + int shiftInput = 0; +#endif + + CompArea tempLumaArea = inferArea.Y(); + tempLumaArea.repositionTo(Position(0, 0)); + + PelBuf bufRec = m_tempRecBuf[0].getBuf(tempLumaArea); + PelBuf bufPred = m_tempPredBuf[0].getBuf(tempLumaArea); + PelBuf bufPartition = m_tempSplitBuf[0].getBuf(tempLumaArea); + + sadl::Tensor<T>* inputRec, *inputPred, *inputPartition, *inputQp; + if (inter) + { + inputRec = &inputs[0]; + inputPred = &inputs[1]; + inputQp = &inputs[2]; + } + else + { + inputRec = &inputs[0]; + inputPred = &inputs[1]; + inputPartition = &inputs[2]; + inputQp = &inputs[3]; + } + + int hor = inferArea.lwidth(); + int ver = inferArea.lheight(); + + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + (*inputRec)(0, yy, xx, 0) = (T)(bufRec.at(xx, yy) / inputScale * (1 << shiftInput)); + (*inputPred)(0, yy, xx, 0) = (T)(bufPred.at(xx, yy) / inputScale * (1 << shiftInput)); + (*inputQp)(0, yy, xx, 0) = (T)(qp / 64.0 * (1 << shiftInput)); + if (!inter) + { + (*inputPartition)(0, yy, xx, 0) = (T)(bufPartition.at(xx, yy) / inputScale * (1 << shiftInput)); + } + } + } +} + +template<typename T> +void EncNNFilterSet1::extractOutputsLumaRd (Picture* pic, 
sadl::Model<T> &m, PelStorage& tempBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom) +{ +#if NN_FIXED_POINT_IMPLEMENTATION + int log2InputScale = 10; + int log2OutputScale = 10; + int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale; + int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale; + int offset = (1 << shiftOutput) / 2; +#else + double inputScale = 1024; + double outputScale = 1024; +#endif + auto output = m.result(0); + PelBuf bufDst = tempBuf.getBuf(inferArea).get(COMPONENT_Y); + + int hor = inferArea.lwidth(); + int ver = inferArea.lheight(); + +#if 1 + CompArea tempLumaArea = inferArea.Y(); + tempLumaArea.repositionTo(Position(0, 0)); + + PelBuf bufRec = m_tempRecBuf[0].getBuf(tempLumaArea); +#else +#if FUSE_NN_AND_LF + PelBuf bufRec = pic->getUnfilteredRecBuf(inferArea).get(COMPONENT_Y); +#else + PelBuf bufRec = pic->getRecoBuf(inferArea).get(COMPONENT_Y); +#endif +#endif + + for (int c = 0; c < 4; c++) // output includes 4 sub images + { + for (int y = 0; y < ver >> 1; y++) + { + for (int x = 0; x < hor >> 1; x++) + { + int yy = (y << 1) + c / 2; + int xx = (x << 1) + c % 2; + if (xx < extLeft || yy < extTop || xx >= hor - extRight || yy >= ver - extBottom) + { + continue; + } +#if NN_FIXED_POINT_IMPLEMENTATION + int out = ( output(0, y, x, c) + (bufRec.at(xx, yy) << shiftInput) + offset) >> shiftOutput; +#else + int out = ( output(0, y, x, c) + bufRec.at(xx, yy) / inputScale ) * outputScale + 0.5; +#endif + bufDst.at(xx, yy) = Pel(Clip3<int>( 0, 1023, out)); + } + } + } +} + +template<typename T> void EncNNFilterSet1::cnnFilterLumaBlockRd(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter) +{ + //at::init_num_threads(); // use all available threads + + const int border_to_skip = 0; + if (border_to_skip>0) sadl::Tensor<float>::skip_border = true; + + // get model + ModelData<T> &m = getModelRd<T>(inferArea.lheight(), inferArea.lwidth(), true, inter, inter 
? m_interLumaRd : m_intraLumaRd); + sadl::Model<T> &model = m.model; + + // get inputs + vector<sadl::Tensor<T>> &inputs = m.inputs; + + int seqQp = pic->slices[0]->getPPS()->getPicInitQPMinus26() + 26; + int sliceQp = pic->slices[0]->getSliceQp(); + int delta = inter ? paramIdx * 5 : paramIdx * 2; + if ( pic->slices[0]->getTLayer() >= 4 && paramIdx >= 2 ) + { + delta = 5 - delta; + } + int qp = inter ? seqQp - delta : sliceQp - delta; + + prepareInputsLumaRd<T>(pic, inferArea, inputs, qp, inter); + + // inference + if (!model.apply(inputs)) + { + cerr << "[ERROR] RDO issue during luma model inference" << endl; + exit(-1); + } + + // get outputs + extractOutputsLumaRd<T>(pic, model, m_tempBuf[paramIdx], inferArea, extLeft, extRight, extTop, extBottom); +} + +void EncNNFilterSet1::cnnFilterLumaBlockRd_ext(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter) +{ +#if NN_FIXED_POINT_IMPLEMENTATION + cnnFilterLumaBlockRd<int16_t>(pic, inferArea, extLeft, extRight, extTop, extBottom, paramIdx, inter); +#else + cnnFilterLumaBlockRd<float>(pic, inferArea, extLeft, extRight, extTop, extBottom, paramIdx, inter); +#endif +} +#endif + void EncNNFilterSet1::initCABACEstimator( CABACEncoder* cabacEncoder, CtxCache* ctxCache, Slice* pcSlice ) { m_CABACEstimator = cabacEncoder->getCABACEstimator( pcSlice->getSPS() ); diff --git a/source/Lib/EncoderLib/EncNNFilterSet1.h b/source/Lib/EncoderLib/EncNNFilterSet1.h index a7f9990d71e5c3d15d18e2045c43050e7ce71447..3a5b4524c2585e8ae8652faa1ba80fe33ca8f6b5 100644 --- a/source/Lib/EncoderLib/EncNNFilterSet1.h +++ b/source/Lib/EncoderLib/EncNNFilterSet1.h @@ -44,6 +44,7 @@ #include "Reshape.h" #include "CABACWriter.h" #include "CommonLib/NNFilterSet1.h" +#include <sadl/model.h> //! \ingroup CommonLib //! 
\{ @@ -73,6 +74,34 @@ public: #else void calcRDCost(Picture *pic,std::vector<PelStorage>& tempBuf, int numParams, double* minCost); #endif + +#if JVET_AB0068_RD + int m_NumModelsValid; + + std::string m_interLumaRd; + std::string m_interChromaRd; + std::string m_intraLumaRd; + std::string m_intraChromaRd; + + PelStorage m_tempRecBuf[MAX_NUM_COMPONENT]; + PelStorage m_tempPredBuf[MAX_NUM_COMPONENT]; + PelStorage m_tempSplitBuf[MAX_NUM_COMPONENT]; + void createEnc(const int maxCuWidth, const int maxCuHeight, const ChromaFormat format); + void destroyEnc(); + void initEnc(int numModels, std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma); + +private: + template<typename T> + void prepareInputsLumaRd(Picture* pic, UnitArea inferArea, std::vector<sadl::Tensor<T>>& inputs, int qp, bool inter); + template<typename T> + void extractOutputsLumaRd(Picture* pic, sadl::Model<T>& m, PelStorage& tempBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom); + template<typename T> + void cnnFilterLumaBlockRd(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter); + +public: + void cnnFilterLumaBlockRd_ext(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter); +#endif + }; //! \}