diff --git a/.gitattributes b/.gitattributes index a8cad345af0c3a01cd1c87fa235971d741c7afc9..24aeffc996c59a225939e1c9c513ce6e6d2bf8a9 100644 --- a/.gitattributes +++ b/.gitattributes @@ -21,3 +21,7 @@ models/intra/graph_output_8_8_int16.sadl filter=lfs diff=lfs merge=lfs -text models/intra/graph_output_16_16_float.sadl filter=lfs diff=lfs merge=lfs -text models/NnlfSet0_model_float.sadl filter=lfs diff=lfs merge=lfs -text models/NnlfSet0_model_int16.sadl filter=lfs diff=lfs merge=lfs -text +models/RDO_I_y_model_int16.sadl filter=lfs diff=lfs merge=lfs -text +models/RDO_I_y_model_float.sadl filter=lfs diff=lfs merge=lfs -text +models/RDO_B_y_model_int16.sadl filter=lfs diff=lfs merge=lfs -text +models/RDO_B_y_model_float.sadl filter=lfs diff=lfs merge=lfs -text diff --git a/models/RDO_B_y_model_float.sadl b/models/RDO_B_y_model_float.sadl new file mode 100644 index 0000000000000000000000000000000000000000..d83a12ff5c608965c443ec788225287e1e518f77 --- /dev/null +++ b/models/RDO_B_y_model_float.sadl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51dee64687c334f052849eff756dd1b4882fc2ae9f8dae9d5e45eabd9801d31e +size 1122005 diff --git a/models/RDO_B_y_model_int16.sadl b/models/RDO_B_y_model_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..100c7ac12b15eff8545cda6ddffcc17dd0827b35 --- /dev/null +++ b/models/RDO_B_y_model_int16.sadl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:67071c988d1984cdbfb80fd04866fe82933cf9f5f4d9cb6efebf3921adffd93a +size 563299 diff --git a/models/RDO_I_y_model_float.sadl b/models/RDO_I_y_model_float.sadl new file mode 100644 index 0000000000000000000000000000000000000000..c664456b24ce2d6442cfb50edaed5c7c1b5184d6 --- /dev/null +++ b/models/RDO_I_y_model_float.sadl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f7341f79613f95232f20250d00d04b744df6f9699a530f2dfda6f1345d85142 +size 1121700 diff --git a/models/RDO_I_y_model_int16.sadl b/models/RDO_I_y_model_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..5d8ed496bebc22db73b94f6a93b7c63244bd0b42 --- /dev/null +++ b/models/RDO_I_y_model_int16.sadl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:81916f5bfeb0c0d99f891dabe225a4524de28722f335048c4b00ee47df262c7d +size 563122 diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp index d39de5428f18e5fcaad1e513f1b6dfb81c7d46ea..ac4aa489f3fbd6cdd18821bd3adc5321614023fa 100644 --- a/source/App/EncoderApp/EncApp.cpp +++ b/source/App/EncoderApp/EncApp.cpp @@ -252,12 +252,13 @@ void EncApp::xInitLibCfg() m_cEncLib.setNnlfSet1AlternativeInterLumaModelName (m_nnlfSet1AlternativeInterLumaModelName); #endif #endif -#if JVET_AB0068_RD - m_cEncLib.setUseEncNnlfOpt (m_encNnlfOpt); - m_cEncLib.setRdoCnnlfInterLumaModelName (m_rdoCnnlfInterLumaModelName); - m_cEncLib.setRdoCnnlfInterChromaModelName (m_rdoCnnlfInterChromaModelName); - m_cEncLib.setRdoCnnlfIntraLumaModelName (m_rdoCnnlfIntraLumaModelName); - m_cEncLib.setRdoCnnlfIntraChromaModelName (m_rdoCnnlfIntraChromaModelName); + +#if JVET_AC0328_NNLF_RDO + m_cEncLib.setUseEncNnlfOpt (m_encNnlfOpt); + m_cEncLib.setRdoCnnlfInterLumaModelNameNNFilter0 (m_rdoCnnlfInterLumaModelNameNNFilter0); + m_cEncLib.setRdoCnnlfIntraLumaModelNameNNFilter0 (m_rdoCnnlfIntraLumaModelNameNNFilter0); + m_cEncLib.setRdoCnnlfInterLumaModelNameNNFilter1 (m_rdoCnnlfInterLumaModelNameNNFilter1); + m_cEncLib.setRdoCnnlfIntraLumaModelNameNNFilter1 (m_rdoCnnlfIntraLumaModelNameNNFilter1); #endif m_cEncLib.setProfile ( m_profile); m_cEncLib.setLevel ( m_levelTier, m_level); diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp index 3829eeb278e03b851943588109ad3c675788aaa2..4c86585bf2124688b249afccddd527b34b5395c5 100644 --- a/source/App/EncoderApp/EncAppCfg.cpp +++ b/source/App/EncoderApp/EncAppCfg.cpp @@ -1448,12 +1448,18 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) ( "NnlfSet1Multiframe", m_nnlfSet1Multiframe, false, "Input multiple frames in NN-based loop filter set 1" ) #endif #endif -#if JVET_AB0068_RD - ( "EncNnlfOpt", m_encNnlfOpt, false, "Encoder optimization with NN-based loop filter") - ( "RdoCnnlfInterLumaModel", m_rdoCnnlfInterLumaModelName, string("models/RdNnlfSet1_LumaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter luma model name for RDO") - ( "RdoCnnlfInterChromaModel", m_rdoCnnlfInterChromaModelName, string("models/RdNnlfSet1_ChromaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter chroma model name for RDO") - ( "RdoCnnlfIntraLumaModel", m_rdoCnnlfIntraLumaModelName, string("models/RdNnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra luma model name for RDO") - ( "RdoCnnlfIntraChromaModel", m_rdoCnnlfIntraChromaModelName, string("models/RdNnlfSet1_ChromaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra chroma model name for RDO") + +#if JVET_AC0328_NNLF_RDO + ( "EncNnlfOpt", m_encNnlfOpt, false, "Encoder optimization with NN-based loop filter") +#if NN_FIXED_POINT_IMPLEMENTATION + ( "RdoCnnlfInterLumaModelNNFilter0", m_rdoCnnlfInterLumaModelNameNNFilter0, default_model_path + string("RDO_B_y_model_int16.sadl"), "Cnnlf inter luma model name") + ( "RdoCnnlfIntraLumaModelNNFilter0", m_rdoCnnlfIntraLumaModelNameNNFilter0, default_model_path + string("RDO_I_y_model_int16.sadl"), "Cnnlf intra luma model name") +#else + ( "RdoCnnlfInterLumaModelNNFilter0", m_rdoCnnlfInterLumaModelNameNNFilter0, default_model_path + string("RDO_B_y_model_float.sadl"), "Cnnlf inter luma model name") + ( "RdoCnnlfIntraLumaModelNNFilter0", m_rdoCnnlfIntraLumaModelNameNNFilter0, default_model_path + string("RDO_I_y_model_float.sadl"), "Cnnlf intra luma model name") +#endif + ( "RdoCnnlfInterLumaModelNNFilter1", m_rdoCnnlfInterLumaModelNameNNFilter1, string("models/RdNnlfSet1_LumaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter luma model name for RDO") + ( "RdoCnnlfIntraLumaModelNNFilter1", m_rdoCnnlfIntraLumaModelNameNNFilter1, string("models/RdNnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra luma model name for RDO") #endif ( "RPR", m_rprEnabledFlag, true, "Reference Sample Resolution" ) ( "ScalingRatioHor", m_scalingRatioHor, 1.0, "Scaling ratio in hor direction" ) @@ -2332,8 +2338,8 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) } #endif -#if JVET_AB0068_RD - m_encNnlfOpt = m_nnlfSet1 ? m_encNnlfOpt : false; +#if JVET_AC0328_NNLF_RDO + m_encNnlfOpt = (m_nnlfSet0 || m_nnlfSet1) ? m_encNnlfOpt : false; #endif if ( m_alf ) @@ -4085,7 +4091,7 @@ void EncAppCfg::xPrintParameter() #if NN_FILTERING_SET_1 msg( VERBOSE, "NNLFSET1:%d ", (m_nnlfSet1)?(1):(0)); #endif -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO msg( VERBOSE, "EncNnlfOpt:%d ", m_encNnlfOpt ? 1 : 0); #endif msg( VERBOSE, "SAO:%d ", (m_bUseSAO)?(1):(0)); diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h index ba685eced6e7143cd8b1f926333d03d13bf9fe25..9098bd8198b2d90be3e6f159694a870cbfc52f10 100644 --- a/source/App/EncoderApp/EncAppCfg.h +++ b/source/App/EncoderApp/EncAppCfg.h @@ -98,11 +98,12 @@ protected: std::string m_nnlfSet1AlternativeInterLumaModelName; ///<alternative inter luma nnlf set1 model #endif #endif -#if JVET_AB0068_RD - std::string m_rdoCnnlfInterLumaModelName; ///<inter luma cnnlf model - std::string m_rdoCnnlfInterChromaModelName; ///<inter chroma cnnlf model - std::string m_rdoCnnlfIntraLumaModelName; ///<intra luma cnnlf model - std::string m_rdoCnnlfIntraChromaModelName; ///<inra chroma cnnlf model + +#if JVET_AC0328_NNLF_RDO + std::string m_rdoCnnlfInterLumaModelNameNNFilter0; ///< inter luma nnlf set0 model + std::string m_rdoCnnlfIntraLumaModelNameNNFilter0; ///< intra luma nnlf set0 model + std::string m_rdoCnnlfInterLumaModelNameNNFilter1; ///< inter luma nnlf set1 model + std::string m_rdoCnnlfIntraLumaModelNameNNFilter1; ///< intra luma nnlf set1 model #endif // Lambda modifiers @@ -741,7 +742,7 @@ protected: #endif #endif -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO bool m_encNnlfOpt; #endif diff --git a/source/Lib/CommonLib/CodingStructure.cpp b/source/Lib/CommonLib/CodingStructure.cpp index 47eaef7eb94edc240f9d536a513d7a3bf362a3f4..a079d4c023cb9cbc6e3632308d5218d67081aef8 100644 --- a/source/Lib/CommonLib/CodingStructure.cpp +++ b/source/Lib/CommonLib/CodingStructure.cpp @@ -1177,7 +1177,7 @@ void CodingStructure::initSubStructure( CodingStructure& subStruct, const Channe subStruct.useDbCost = false; subStruct.costDbOffset = 0; -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO subStruct.useNnCost = false; subStruct.costNnOffset = 0; #endif @@ -1478,7 +1478,7 @@ void CodingStructure::initStructData( const int &QP, const bool &skipMotBuf ) lumaCost = MAX_DOUBLE; costDbOffset = 0; useDbCost = false; -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO costNnOffset = 0; useNnCost = false; #endif diff --git a/source/Lib/CommonLib/CodingStructure.h b/source/Lib/CommonLib/CodingStructure.h index 3c6a15d1f8e9a67cfde3552fa1fe9c557c6cce11..b5edead8f40e4140cbc7dd6ca0750b22345ab3e8 100644 --- a/source/Lib/CommonLib/CodingStructure.h +++ b/source/Lib/CommonLib/CodingStructure.h @@ -181,7 +181,7 @@ public: double cost; bool useDbCost; double costDbOffset; -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO bool useNnCost; double costNnOffset; #endif diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index e58f7f34361ddb78b4f72a7e069789d20413d8be..6103e8aa96cef102f6631a26c4fdbea205654994 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -106,6 +106,7 @@ using TypeSadl = float; #define JVET_AB0083_QPADJ 1 // JVET-AB0083: EE1-1.8: More refinements on NN based in-loop filter with a single model (Test 1) #endif +#define JVET_AC0328_NNLF_RDO 1 // JVET-AC0328: EE1-1.2: encoder-only optimization for NN based in-loop filter with a single model // nn filter set 1 #define NN_FILTERING_SET_1 1 @@ -121,8 +122,6 @@ using TypeSadl = float; #define JVET_AC0177_FLIP_INPUT 1 // JVET-AC0177: flip input and output of NN filter model #endif -#define JVET_AB0068_RD 1 // JVET-AB0068: EE1-1.6: RDO Considering Deep In-Loop Filtering - //########### place macros to be removed in next cycle below this line ############### #define JVET_V0056 1 // MCTF changes as presented in JVET-V0056 diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h index 3ff3f7d294ecdc66dcaa600da7fa4f383796c287..baf754b7ec9c978d89d15a831ba98e5e4022a5f3 100644 --- a/source/Lib/EncoderLib/EncCfg.h +++ b/source/Lib/EncoderLib/EncCfg.h @@ -167,13 +167,15 @@ protected: std::string m_nnlfSet1AlternativeInterLumaModelName; ///<alternative inter luma nnlf set1 model #endif #endif -#if JVET_AB0068_RD + +#if JVET_AC0328_NNLF_RDO bool m_encNnlfOpt; - std::string m_rdoCnnlfInterLumaModelName; ///<inter luma cnnlf model - std::string m_rdoCnnlfInterChromaModelName; ///<inter chroma cnnlf model - std::string m_rdoCnnlfIntraLumaModelName; ///<intra luma cnnlf model - std::string m_rdoCnnlfIntraChromaModelName; ///<inra chroma cnnlf model + std::string m_rdoCnnlfInterLumaModelNameNNFilter0; ///< inter luma nnlf set0 model + std::string m_rdoCnnlfIntraLumaModelNameNNFilter0; ///< intra luma nnlf set0 model + std::string m_rdoCnnlfInterLumaModelNameNNFilter1; ///< inter luma nnlf set1 model + std::string m_rdoCnnlfIntraLumaModelNameNNFilter1; ///< intra luma nnlf set1 model #endif + int m_iFrameRate; int m_FrameSkip; uint32_t m_temporalSubsampleRatio; @@ -853,19 +855,21 @@ public: void setNnlfSet1AlternativeInterLumaModelName(std::string s) { m_nnlfSet1AlternativeInterLumaModelName = s; } #endif #endif -#if JVET_AB0068_RD - std::string getRdoCnnlfInterLumaModelName() { return m_rdoCnnlfInterLumaModelName; } - std::string getRdoCnnlfInterChromaModelName() { return m_rdoCnnlfInterChromaModelName; } - std::string getRdoCnnlfIntraLumaModelName() { return m_rdoCnnlfIntraLumaModelName; } - std::string getRdoCnnlfIntraChromaModelName() { return m_rdoCnnlfIntraChromaModelName; } - void setRdoCnnlfInterLumaModelName(std::string s) { m_rdoCnnlfInterLumaModelName = s; } - void setRdoCnnlfInterChromaModelName(std::string s) { m_rdoCnnlfInterChromaModelName = s; } - void setRdoCnnlfIntraLumaModelName(std::string s) { m_rdoCnnlfIntraLumaModelName = s; } - void setRdoCnnlfIntraChromaModelName(std::string s) { m_rdoCnnlfIntraChromaModelName = s; } + +#if JVET_AC0328_NNLF_RDO + std::string getRdoCnnlfInterLumaModelNameNNFilter0() { return m_rdoCnnlfInterLumaModelNameNNFilter0; } + std::string getRdoCnnlfIntraLumaModelNameNNFilter0() { return m_rdoCnnlfIntraLumaModelNameNNFilter0; } + void setRdoCnnlfInterLumaModelNameNNFilter0(std::string s) { m_rdoCnnlfInterLumaModelNameNNFilter0 = s; } + void setRdoCnnlfIntraLumaModelNameNNFilter0(std::string s) { m_rdoCnnlfIntraLumaModelNameNNFilter0 = s; } + + std::string getRdoCnnlfInterLumaModelNameNNFilter1() { return m_rdoCnnlfInterLumaModelNameNNFilter1; } + std::string getRdoCnnlfIntraLumaModelNameNNFilter1() { return m_rdoCnnlfIntraLumaModelNameNNFilter1; } + void setRdoCnnlfInterLumaModelNameNNFilter1(std::string s) { m_rdoCnnlfInterLumaModelNameNNFilter1 = s; } + void setRdoCnnlfIntraLumaModelNameNNFilter1(std::string s) { m_rdoCnnlfIntraLumaModelNameNNFilter1 = s; } bool getUseEncNnlfOpt() { return m_encNnlfOpt; }; void setUseEncNnlfOpt(bool b) { m_encNnlfOpt = b; }; #endif - + void setProfile(Profile::Name profile) { m_profile = profile; } void setLevel(Level::Tier tier, Level::Name level) { m_levelTier = tier; m_level = level; } bool getFrameOnlyConstraintFlag() const { return m_frameOnlyConstraintFlag; } @@ -2042,6 +2046,9 @@ public: #if NN_FILTERING_SET_0 void setUseNnlfSet0(bool b) { m_nnlfSet0 = b; } +#if JVET_AC0328_NNLF_RDO + bool getUseNnlfSet0() const { return m_nnlfSet0; } +#endif #endif #if NN_FILTERING_SET_1 diff --git a/source/Lib/EncoderLib/EncCu.cpp b/source/Lib/EncoderLib/EncCu.cpp index cea3f9d7d02736c1a2b0d4815cdf824cea12711a..4dfa4f692d476d1b3a3ad56402c0945ee855eec4 100644 --- a/source/Lib/EncoderLib/EncCu.cpp +++ b/source/Lib/EncoderLib/EncCu.cpp @@ -235,8 +235,8 @@ void EncCu::init( EncLib* pcEncLib, const SPS& sps PARL_PARAM( const int tId ) ) m_pcEncLib = pcEncLib; m_dataId = tId; #endif -#if JVET_AB0068_RD - m_pcCNNFilter = pcEncLib->getGOPEncoder()->getEncCnnFilter(); +#if JVET_AC0328_NNLF_RDO + m_pcCNNLFEncoder = pcEncLib->getGOPEncoder()->getEncCnnFilter(); #endif m_pcLoopFilter = pcEncLib->getLoopFilter(); m_GeoCostList.init(GEO_NUM_PARTITION_MODE, m_pcEncCfg->getMaxNumGeoCand()); @@ -1335,7 +1335,7 @@ void EncCu::xCheckModeSplit(CodingStructure *&tempCS, CodingStructure *&bestCS, tempCS->cost = MAX_DOUBLE; tempCS->costDbOffset = 0; tempCS->useDbCost = m_pcEncCfg->getUseEncDbOpt(); -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO tempCS->costNnOffset = 0; tempCS->useNnCost = false; #endif @@ -1381,7 +1381,7 @@ void EncCu::xCheckModeSplit(CodingStructure *&tempCS, CodingStructure *&bestCS, tempCS->cost = MAX_DOUBLE; tempCS->costDbOffset = 0; tempCS->useDbCost = m_pcEncCfg->getUseEncDbOpt(); -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO tempCS->costNnOffset = 0; tempCS->useNnCost = false; #endif @@ -1574,11 +1574,11 @@ void EncCu::xCheckModeSplit(CodingStructure *&tempCS, CodingStructure *&bestCS, } } -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO UnitArea unitArea = clipArea(bestCS->area, *bestCS->picture); bool useEncNnOpt = false; - if (m_pcCNNFilter->m_NumModelsValid > 0 && partitioner.chType == CHANNEL_TYPE_LUMA && bestCS->cost != MAX_DOUBLE && tempCS->cost != MAX_DOUBLE) + if (m_pcCNNLFEncoder->m_NumModelsValid > 0 && partitioner.chType == CHANNEL_TYPE_LUMA && bestCS->cost != MAX_DOUBLE && tempCS->cost != MAX_DOUBLE) { int lumaWidth = unitArea.lwidth(); int lumaHeight = unitArea.lheight(); @@ -1672,7 +1672,7 @@ void EncCu::xCheckModeSplit(CodingStructure *&tempCS, CodingStructure *&bestCS, tempCS->prevQP[partitioner.chType] = oldPrevQp; } -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO void EncCu::xCheckCnnlf(CodingStructure& cs, UnitArea unitArea) { if (cs.useNnCost) @@ -1684,7 +1684,6 @@ void EncCu::xCheckCnnlf(CodingStructure& cs, UnitArea unitArea) Picture* pic = cs.picture; - Distortion orgDistAll = 0; Distortion orgDistComp[3] = { 0,0,0 }; for (uint32_t comp = 0; comp < 1; comp++) { @@ -1694,15 +1693,14 @@ void EncCu::xCheckCnnlf(CodingStructure& cs, UnitArea unitArea) continue; } orgDistComp[comp] = getDistortionDb(cs, cs.getOrgBuf(unitArea.blocks[compID]), cs.getRecoBuf(unitArea.blocks[compID]), compID, unitArea.blocks[compID], false); - orgDistAll += orgDistComp[comp]; } CompArea tempLumaArea = unitArea.Y(); tempLumaArea.repositionTo(Position(0, 0)); - PelBuf bufRec = m_pcCNNFilter->m_tempRecBuf [COMPONENT_Y].getBuf(tempLumaArea); - PelBuf bufPred = m_pcCNNFilter->m_tempPredBuf [COMPONENT_Y].getBuf(tempLumaArea); - PelBuf bufSplit = m_pcCNNFilter->m_tempSplitBuf[COMPONENT_Y].getBuf(tempLumaArea); + PelBuf bufRec = m_pcCNNLFEncoder->m_tempRecBuf [COMPONENT_Y].getBuf(tempLumaArea); + PelBuf bufPred = m_pcCNNLFEncoder->m_tempPredBuf [COMPONENT_Y].getBuf(tempLumaArea); + PelBuf bufSplit = m_pcCNNLFEncoder->m_tempSplitBuf[COMPONENT_Y].getBuf(tempLumaArea); bufRec.copyFrom(cs.getRecoBuf(unitArea.Y())); bufPred.copyFrom(cs.getPredBufCustom(unitArea.Y())); @@ -1736,31 +1734,26 @@ void EncCu::xCheckCnnlf(CodingStructure& cs, UnitArea unitArea) } subBufRec.rspSignal(m_pcReshape->getInvLUT()); } - subBufSplit.fill(subBufRec.computeAvg()); + if (pic->cs->sps->getNnlfSet1EnabledFlag()) + subBufSplit.fill(subBufRec.computeAvg()); } - for (int modelIdx = 0; modelIdx < min(1, m_pcCNNFilter->m_NumModelsValid); modelIdx++) - { - m_pcCNNFilter->cnnFilterLumaBlockRd_ext(pic, unitArea, 0, 0, 0, 0, modelIdx, pic->slices[0]->getSliceType() != I_SLICE); - - PelBuf bufDst = m_pcCNNFilter->m_tempBuf[modelIdx].getBuf(unitArea).get(COMPONENT_Y); + m_pcCNNLFEncoder->cnnFilterLumaBlockRd_ext(pic, unitArea, pic->slices[0]->getSliceType() != I_SLICE); + PelBuf bufDst = m_pcCNNLFEncoder->m_tempBuf.getBuf(unitArea).get(COMPONENT_Y); - Distortion cnnDistAll = 0; - Distortion cnnDistComp[3] = { 0,0,0 }; - for (uint32_t comp = 0; comp < 1; comp++) + Distortion cnnDistComp[3] = { 0,0,0 }; + for (uint32_t comp = 0; comp < 1; comp++) + { + const ComponentID compID = ComponentID(comp); + if (!cs.cus.at(0)->blocks[comp].valid()) { - const ComponentID compID = ComponentID(comp); - if (!cs.cus.at(0)->blocks[comp].valid()) - { - continue; - } - cnnDistComp[comp] = getDistortionDb(cs, cs.getOrgBuf(unitArea.blocks[compID]), bufDst, compID, unitArea.blocks[compID], true); - cnnDistAll += cnnDistComp[comp]; + continue; } - - cs.useNnCost = true; - cs.costNnOffset = min(double(cs.costNnOffset), (m_pcRdCost->calcRdCost(0, cnnDistComp[COMPONENT_Y]) - m_pcRdCost->calcRdCost(0, orgDistComp[COMPONENT_Y]))); + cnnDistComp[comp] = getDistortionDb(cs, cs.getOrgBuf(unitArea.blocks[compID]), bufDst, compID, unitArea.blocks[compID], true); } + + cs.useNnCost = true; + cs.costNnOffset = min(double(cs.costNnOffset), (m_pcRdCost->calcRdCost(0, cnnDistComp[COMPONENT_Y]) - m_pcRdCost->calcRdCost(0, orgDistComp[COMPONENT_Y]))); } #endif @@ -4530,7 +4523,7 @@ Distortion EncCu::getDistortionDb( CodingStructure &cs, CPelBuf org, CPelBuf rec Distortion dist = 0; #if WCG_EXT m_pcRdCost->setChromaFormat(cs.sps->getChromaFormatIdc()); -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO CPelBuf orgLuma = cs.picture->getOrigBuf(clipArea(cs.area.blocks[COMPONENT_Y], cs.picture->blocks[COMPONENT_Y])); #else CPelBuf orgLuma = cs.picture->getOrigBuf( cs.area.blocks[COMPONENT_Y] ); @@ -4869,7 +4862,7 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS if( bestCost == bestCS->cost ) //The first EMT pass didn't become the bestCS, so we clear the TUs generated { tempCS->clearTUs(); -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO cu = tempCS->getCU(partitioner.chType); #endif } @@ -4895,7 +4888,7 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS tempCS->cost = MAX_DOUBLE; cu->skip = false; -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO tempCS->useNnCost = false; tempCS->costNnOffset = 0; #endif diff --git a/source/Lib/EncoderLib/EncCu.h b/source/Lib/EncoderLib/EncCu.h index 6cb93403b3418336bd74b75fc3fc36c3008475ac..78444f2c97674c5cf8640a75245d2fb1c477e48e 100644 --- a/source/Lib/EncoderLib/EncCu.h +++ b/source/Lib/EncoderLib/EncCu.h @@ -55,8 +55,8 @@ #include "InterSearch.h" #include "RateCtrl.h" #include "EncModeCtrl.h" -#if JVET_AB0068_RD -#include "EncoderLib/EncNNFilterSet1.h" +#if JVET_AC0328_NNLF_RDO +#include "EncNNFilter.h" #endif //! \ingroup EncoderLib //! \{ @@ -184,8 +184,8 @@ private: TrQuant* m_pcTrQuant; RdCost* m_pcRdCost; EncSlice* m_pcSliceEncoder; -#if JVET_AB0068_RD - EncNNFilterSet1* m_pcCNNFilter; +#if JVET_AC0328_NNLF_RDO + EncNNFilter* m_pcCNNLFEncoder; #endif LoopFilter* m_pcLoopFilter; @@ -263,7 +263,7 @@ protected: xCheckBestMode ( CodingStructure *&tempCS, CodingStructure *&bestCS, Partitioner &pm, const EncTestMode& encTestmode ); void xCheckModeSplit ( CodingStructure *&tempCS, CodingStructure *&bestCS, Partitioner &pm, const EncTestMode& encTestMode, const ModeType modeTypeParent, bool &skipInterPass ); -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO void xCheckCnnlf ( CodingStructure& cs, UnitArea unitArea ); #endif diff --git a/source/Lib/EncoderLib/EncGOP.cpp b/source/Lib/EncoderLib/EncGOP.cpp index 996fc82098d416776852387fe4bd8284a1a7b747..32fb9ccceec367f7a13131411ea77b53b4b57f6c 100644 --- a/source/Lib/EncoderLib/EncGOP.cpp +++ b/source/Lib/EncoderLib/EncGOP.cpp @@ -202,11 +202,11 @@ void EncGOP::destroy() delete m_picOrig; m_picOrig = NULL; } +#if JVET_AC0328_NNLF_RDO + m_cCNNLFEncoder.destroyEnc(); +#endif #if NN_FILTERING_SET_1 m_pcNNFilterSet1.destroy(); -#if JVET_AB0068_RD - m_pcNNFilterSet1.destroyEnc(); -#endif #endif } @@ -222,6 +222,10 @@ void EncGOP::init ( EncLib* pcEncLib ) m_pcSAO = pcEncLib->getSAO(); m_pcALF = pcEncLib->getALF(); +#if JVET_AC0328_NNLF_RDO + m_cCNNLFEncoder.createEnc(m_pcCfg->getSourceWidth(), m_pcCfg->getSourceHeight(), m_pcCfg->getMaxCUWidth(), m_pcCfg->getMaxCUHeight(), m_pcCfg->getChromaFormatIdc(), m_pcCfg->getNnlfSet1MaxNumParams()); +#endif + #if NN_FILTERING_SET_0 m_pcCNNLF = pcEncLib->getCNNLF(); #endif @@ -235,12 +239,6 @@ void EncGOP::init ( EncLib* pcEncLib ) #if NN_FILTERING_SET_1 m_pcNNFilterSet1.create(m_pcCfg->getSourceWidth(), m_pcCfg->getSourceHeight(), m_pcCfg->getChromaFormatIdc(), m_pcCfg->getNnlfSet1MaxNumParams()); -#if JVET_AB0068_RD - if (m_pcCfg->getUseEncNnlfOpt()) - { - m_pcNNFilterSet1.createEnc(m_pcCfg->getMaxCUWidth(), m_pcCfg->getMaxCUHeight(), m_pcCfg->getChromaFormatIdc()); - } -#endif #endif #if WCG_EXT @@ -2829,8 +2827,11 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, { m_pcSliceEncoder->setJointCbCrModes(*pcPic->cs, Position(0, 0), pcPic->cs->area.lumaSize()); } -#if JVET_AB0068_RD - m_pcNNFilterSet1.initEnc(m_pcEncLib->getUseEncNnlfOpt() ? 1 : 0, m_pcEncLib->getRdoCnnlfInterLumaModelName(), m_pcEncLib->getRdoCnnlfInterChromaModelName(), m_pcEncLib->getRdoCnnlfIntraLumaModelName(), m_pcEncLib->getRdoCnnlfIntraChromaModelName()); +#if JVET_AC0328_NNLF_RDO + if (m_pcEncLib->getUseNnlfSet0()) + m_cCNNLFEncoder.initEnc(m_pcEncLib->getUseEncNnlfOpt() ? 1 : 0, m_pcEncLib->getUseNnlfSet0(), m_pcEncLib->getRdoCnnlfInterLumaModelNameNNFilter0(), m_pcEncLib->getRdoCnnlfIntraLumaModelNameNNFilter0()); + else + m_cCNNLFEncoder.initEnc(m_pcEncLib->getUseEncNnlfOpt() ? 1 : 0, m_pcEncLib->getUseNnlfSet0(), m_pcEncLib->getRdoCnnlfInterLumaModelNameNNFilter1(), m_pcEncLib->getRdoCnnlfIntraLumaModelNameNNFilter1()); #endif if( encPic ) // now compress (trial encode) the various slice segments (slices, and dependent slices) diff --git a/source/Lib/EncoderLib/EncGOP.h b/source/Lib/EncoderLib/EncGOP.h index f6cc19ed03339caabe9cd078fd73f038a3bf562b..8145980355b207399d879787b4d4dc972b2ecfa4 100644 --- a/source/Lib/EncoderLib/EncGOP.h +++ b/source/Lib/EncoderLib/EncGOP.h @@ -73,6 +73,10 @@ #include <chrono> #endif +#if JVET_AC0328_NNLF_RDO +#include "EncNNFilter.h" +#endif + #if NN_FILTERING_SET_0 #include "EncNNFilterSet0.h" #endif @@ -165,6 +169,9 @@ private: EncSampleAdaptiveOffset* m_pcSAO; EncAdaptiveLoopFilter* m_pcALF; +#if JVET_AC0328_NNLF_RDO + EncNNFilter m_cCNNLFEncoder; +#endif #if NN_FILTERING_SET_0 EncNNFilterSet0* m_pcCNNLF; #endif @@ -256,8 +263,8 @@ public: void setLastLTRefPoc(int iLastLTRefPoc) { m_lastLTRefPoc = iLastLTRefPoc; } int getLastLTRefPoc() const { return m_lastLTRefPoc; } -#if JVET_AB0068_RD - EncNNFilterSet1* getEncCnnFilter() { return &m_pcNNFilterSet1; } +#if JVET_AC0328_NNLF_RDO + EncNNFilter* getEncCnnFilter() { return &m_cCNNLFEncoder; } #endif void printOutSummary( uint32_t uiNumAllPicCoded, bool isField, const bool printMSEBasedSNR, const bool printSequenceMSE, diff --git a/source/Lib/EncoderLib/EncModeCtrl.cpp b/source/Lib/EncoderLib/EncModeCtrl.cpp index 8cda4c889972529ed981a2c29ae5b0677f091773..6db41edb1339b5a79668176ab1ecc9760222a4bb 100644 --- a/source/Lib/EncoderLib/EncModeCtrl.cpp +++ b/source/Lib/EncoderLib/EncModeCtrl.cpp @@ -2097,7 +2097,7 @@ bool EncModeCtrlMTnoRQT::useModeResult( const EncTestMode& encTestmode, CodingSt } // for now just a simple decision based on RD-cost or choose tempCS if bestCS is not yet coded -#if JVET_AB0068_RD +#if JVET_AC0328_NNLF_RDO bool useTemp = false; if (tempCS->features[ENC_FT_RD_COST] != MAX_DOUBLE) { diff --git a/source/Lib/EncoderLib/EncNNFilter.cpp b/source/Lib/EncoderLib/EncNNFilter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..093250f1c2036fa8ba2028a5df20bba0dc2a138e --- /dev/null +++ b/source/Lib/EncoderLib/EncNNFilter.cpp @@ -0,0 +1,372 @@ +/* The copyright in this software is being made available under the BSD + * License, included below. This software may be subject to other third party + * and contributor rights, including patent rights, and no such rights are + * granted under this license. + * + * Copyright (c) 2010-2018, ITU/ISO/IEC + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF + * THE POSSIBILITY OF SUCH DAMAGE. + */ + + /** \file EncNNFilter.cpp + \brief cnn loop filter on the encoder side class + */ + +#include "EncNNFilter.h" + +#if JVET_AC0328_NNLF_RDO +void EncNNFilter::createEnc(const int picWidth, const int picHeight, const int maxCuWidth, const int maxCuHeight, const ChromaFormat format, const int nnlfSet1NumParams) +{ + if (m_tempBuf.getOrigin(0) == NULL) + { + m_tempBuf.create(format, Area(0, 0, picWidth, picHeight)); + } + if (m_tempRecBuf[0].getOrigin(0) == NULL) + { + m_tempRecBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); + } + if (!m_tempPredBuf[0].getOrigin(0)) + { + m_tempPredBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); + } + if (!m_tempSplitBuf[0].getOrigin(0)) + { + m_tempSplitBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); + } +} +void EncNNFilter::destroyEnc() +{ + m_tempBuf.destroy(); + m_tempRecBuf[0].destroy(); + m_tempPredBuf[0].destroy(); + m_tempSplitBuf[0].destroy(); +} + +void EncNNFilter::initEnc(int numModels, bool nnlfSet0Flag, std::string interLuma, std::string intraLuma) +{ + m_NumModelsValid = numModels; + if (nnlfSet0Flag) + { + m_interLumaRdNNFilter0 = interLuma; + m_intraLumaRdNNFilter0 = intraLuma; + } + else + { + m_interLumaRdNNFilter1 = interLuma; + m_intraLumaRdNNFilter1 = intraLuma; + } +} + +template<typename T> +struct ModelData { + sadl::Model<T> model; + vector<sadl::Tensor<T>> inputs; + int hor,ver; + bool luma,inter; +}; + +template<typename T> +static std::vector<ModelData<T>> initSpace() { + std::vector<ModelData<T>> v; + v.reserve(5); + return v; +} + +#if NN_FIXED_POINT_IMPLEMENTATION +static std::vector<ModelData<int16_t>> modelsRd = initSpace<int16_t>(); +#else +static std::vector<ModelData<float>> modelsRd = initSpace<float>(); +#endif + +template<typename T> +static ModelData<T> &getModelRd(int ver, int hor, bool luma, bool inter, const std::string modelName) { + ModelData<T> *ptr = nullptr; + for(auto &m: modelsRd) + { + if (m.luma == luma && m.inter == inter) + { + ptr = &m; + break; + } + } + if (ptr == nullptr) + { + if (modelsRd.size() == modelsRd.capacity()) + { + std::cout << "[ERROR] RDO increase cache" << std::endl; + exit(-1); + } + modelsRd.resize(modelsRd.size()+1); + ptr = &modelsRd.back(); + ModelData<T> &m = *ptr; + ifstream file(modelName, ios::binary); + m.model.load(file); + m.luma = luma; + m.inter = inter; + m.ver = 0; + m.hor = 0; + } + ModelData<T> &m = *ptr; + if (m.ver != ver || m.hor != hor) + { + m.inputs = m.model.getInputsTemplate(); + if (luma) + { + for(auto &t: m.inputs) + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); + t.resize(dims); + } + } + else if (inter) // inter chroma + { + int inputId = 0; + for(auto &t: m.inputs) + { + if (t.dims()[3] == 1 && inputId != 3) // luma + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); + t.resize(dims); + } + else if (t.dims()[3] == 1 && inputId == 3) // qp + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 1})); + t.resize(dims); + } + else + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 2})); + t.resize(dims); + } + inputId ++; + } + } + else // intra chroma + { + int inputId = 0; + for(auto &t: m.inputs) + { + if (t.dims()[3] == 1 && inputId != 4 ) // luma + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); + t.resize(dims); + } + else if (t.dims()[3] == 1 && inputId == 4) // QP + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 1})); + t.resize(dims); + } + else + { + sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 2})); + t.resize(dims); + } + inputId ++; + } + } + if (!m.model.init(m.inputs)) + { + cerr << "[ERROR] RDO issue during initialization" << endl; + exit(-1); + } + m.ver = ver; + m.hor = hor; + } + return m; +} + +template<typename T> +void EncNNFilter::prepareInputsLumaRd (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int sliceQp, int baseQp, bool inter) +{ + double inputScale = 1024; + bool nnlfSet0Flag = pic->cs->sps->getNnlfSet0EnabledFlag(); +#if NN_FIXED_POINT_IMPLEMENTATION + int shiftInput = nnlfSet0Flag ? 11 : NN_INPUT_PRECISION; +#else + int shiftInput = 0; +#endif + + CompArea tempLumaArea = inferArea.Y(); + tempLumaArea.repositionTo(Position(0, 0)); + + PelBuf bufRec = m_tempRecBuf[0].getBuf(tempLumaArea); + PelBuf bufPred = m_tempPredBuf[0].getBuf(tempLumaArea); + PelBuf bufPartition = m_tempSplitBuf[0].getBuf(tempLumaArea); + + sadl::Tensor<T>* inputRec, *inputPred, *inputSliceQp, *inputBaseQp, *inputPartition, *inputQp; + if (inter) + { + inputRec = &inputs[0]; + inputPred = &inputs[1]; + if (nnlfSet0Flag) + { + inputSliceQp = &inputs[2]; + inputBaseQp = &inputs[3]; + } + else + { + inputQp = &inputs[2]; + } + } + else + { + inputRec = &inputs[0]; + inputPred = &inputs[1]; + if (nnlfSet0Flag) + { + inputBaseQp = &inputs[2]; + } + else + { + inputPartition = &inputs[2]; + inputQp = &inputs[3]; + } + } + + int hor = inferArea.lwidth(); + int ver = inferArea.lheight(); + + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + (*inputRec)(0, yy, xx, 0) = (T)(bufRec.at(xx, yy) / inputScale * (1 << shiftInput)); + (*inputPred)(0, yy, xx, 0) = (T)(bufPred.at(xx, yy) / inputScale * (1 << shiftInput)); + if (nnlfSet0Flag) + { + if (inter) + { + (*inputSliceQp)(0, yy, xx, 0) = (T)(sliceQp / inputScale * (1 << shiftInput)); + } + (*inputBaseQp)(0, yy, xx, 0) = (T)(baseQp / inputScale * (1 << shiftInput)); + } + else + { + (*inputQp)(0, yy, xx, 0) = (T)(baseQp / 64.0 * (1 << shiftInput)); + if (!inter) + { + (*inputPartition)(0, yy, xx, 0) = (T)(bufPartition.at(xx, yy) / inputScale * (1 << shiftInput)); + } + } + } + } +} + +template<typename T> +void EncNNFilter::extractOutputsLumaRd (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, UnitArea inferArea) +{ +#if NN_FIXED_POINT_IMPLEMENTATION + int log2InputScale = 10; + int log2OutputScale = 10; + int shiftInput = pic->cs->sps->getNnlfSet0EnabledFlag() ? 11 - log2InputScale : NN_OUTPUT_PRECISION - log2OutputScale; + int shiftOutput = pic->cs->sps->getNnlfSet0EnabledFlag() ? 11 - log2InputScale : NN_OUTPUT_PRECISION - log2OutputScale; + int offset = (1 << shiftOutput) / 2; +#else + double inputScale = 1024; + double outputScale = 1024; +#endif + auto output = m.result(0); + PelBuf bufDst = tempBuf.getBuf(inferArea).get(COMPONENT_Y); + + int hor = inferArea.lwidth(); + int ver = inferArea.lheight(); + + CompArea tempLumaArea = inferArea.Y(); + tempLumaArea.repositionTo(Position(0, 0)); + + PelBuf bufRec = m_tempRecBuf[0].getBuf(tempLumaArea); + + for (int c = 0; c < 4; c++) // output includes 4 sub images + { + for (int y = 0; y < ver >> 1; y++) + { + for (int x = 0; x < hor >> 1; x++) + { + int yy = (y << 1) + c / 2; + int xx = (x << 1) + c % 2; +#if NN_FIXED_POINT_IMPLEMENTATION + int out = ( output(0, y, x, c) + (bufRec.at(xx, yy) << shiftInput) + offset) >> shiftOutput; +#else + int out = ( output(0, y, x, c) + bufRec.at(xx, yy) / inputScale ) * outputScale + 0.5; +#endif + bufDst.at(xx, yy) = Pel(Clip3<int>( 0, 1023, out)); + } + } + } +} + +template<typename T> +void EncNNFilter::cnnFilterLumaBlockRd(Picture* pic, UnitArea inferArea, bool inter) +{ + // get model + bool nnlfSet0Flag = pic->cs->sps->getNnlfSet0EnabledFlag(); + string m_interLumaRd, m_intraLumaRd; + m_interLumaRd = nnlfSet0Flag ? m_interLumaRdNNFilter0 : m_interLumaRdNNFilter1; + m_intraLumaRd = nnlfSet0Flag ? m_intraLumaRdNNFilter0 : m_intraLumaRdNNFilter1; + + ModelData<T> &m = getModelRd<T>(inferArea.lheight(), inferArea.lwidth(), true, inter, inter ? m_interLumaRd : m_intraLumaRd); + sadl::Model<T> &model = m.model; + + // get inputs + vector<sadl::Tensor<T>> &inputs = m.inputs; + + int baseQp = pic->slices[0]->getPPS()->getPicInitQPMinus26() + 26; + int sliceQp = pic->slices[0]->getSliceQp(); + int paramIdx = 0; + int delta = inter ? paramIdx * 5 : paramIdx * 2; + if ( pic->slices[0]->getTLayer() >= 4 && paramIdx >= 2 ) + { + delta = 5 - delta; + } + int qp = inter ? baseQp - delta : sliceQp - delta; + + if (nnlfSet0Flag) + prepareInputsLumaRd<T>(pic, inferArea, inputs, sliceQp, baseQp, inter); + else + prepareInputsLumaRd<T>(pic, inferArea, inputs, sliceQp, qp, inter); + + // inference + if (!model.apply(inputs)) + { + cerr << "[ERROR] RDO issue during luma model inference" << endl; + exit(-1); + } + + // get outputs + extractOutputsLumaRd<T>(pic, model, m_tempBuf, inferArea); +} + +void EncNNFilter::cnnFilterLumaBlockRd_ext(Picture* pic, UnitArea inferArea, bool inter) +{ +#if NN_FIXED_POINT_IMPLEMENTATION + cnnFilterLumaBlockRd<int16_t>(pic, inferArea, inter); +#else + cnnFilterLumaBlockRd<float>(pic, inferArea, inter); +#endif +} + +#endif + diff --git a/source/Lib/EncoderLib/EncNNFilter.h b/source/Lib/EncoderLib/EncNNFilter.h new file mode 100644 index 0000000000000000000000000000000000000000..07d039b77f58f71690bbcac38e716a6392faa966 --- /dev/null +++ b/source/Lib/EncoderLib/EncNNFilter.h @@ -0,0 +1,84 @@ +/* The copyright in this software is being made available under the BSD + * License, included below. This software may be subject to other third party + * and contributor rights, including patent rights, and no such rights are + * granted under this license. + * + * Copyright (c) 2010-2018, ITU/ISO/IEC + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * * Neither the name of the ITU/ISO/IEC nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS + * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF + * THE POSSIBILITY OF SUCH DAMAGE. + */ + +/** \file EncNNFilterSet0.h + \brief estimation part of cnn loop filter class (header) + */ + +/** \file EncNNFilter.h + \brief cnn loop filter on the encoder side class (header) + */ + +#ifndef __ENCNNFILTER__ +#define __ENCNNFILTER__ + +#include "CommonLib/CommonDef.h" + +#if JVET_AC0328_NNLF_RDO + +#include "Picture.h" +#include <sadl/model.h> +#include <fstream> +using namespace std; + +class EncNNFilter +{ +public: + std::string m_interLumaRdNNFilter0; + std::string m_intraLumaRdNNFilter0; + + int m_NumModelsValid; + std::string m_interLumaRdNNFilter1; + std::string m_intraLumaRdNNFilter1; + + PelStorage m_tempBuf; + PelStorage m_tempRecBuf [MAX_NUM_COMPONENT]; + PelStorage m_tempPredBuf [MAX_NUM_COMPONENT]; + PelStorage m_tempSplitBuf[MAX_NUM_COMPONENT]; + void createEnc(const int picWidth, const int picHeight, const int maxCuWidth, const int maxCuHeight, const ChromaFormat format, const int nnlfSet1NumParams); + void destroyEnc(); + void initEnc(int numModels, bool nnlfSet0Flag, std::string interLuma, std::string intraLuma); + +private: + template<typename T> + void prepareInputsLumaRd(Picture* pic, UnitArea inferArea, std::vector<sadl::Tensor<T>>& inputs, int sliceQp, int baseQp, bool inter); + template<typename T> + void extractOutputsLumaRd(Picture* pic, sadl::Model<T>& m, PelStorage& tempBuf, UnitArea inferArea); + template<typename T> + void cnnFilterLumaBlockRd(Picture* pic, UnitArea inferArea, bool inter); + +public: + void cnnFilterLumaBlockRd_ext(Picture* pic, UnitArea inferArea, bool inter); +}; +#endif +#endif diff --git a/source/Lib/EncoderLib/EncNNFilterSet1.cpp b/source/Lib/EncoderLib/EncNNFilterSet1.cpp index c6c8100856feb6c6fd771a70ed80a81629d6aa5f..3bf646862170b11f79bcd40609df14a64851155c 100644 --- a/source/Lib/EncoderLib/EncNNFilterSet1.cpp +++ b/source/Lib/EncoderLib/EncNNFilterSet1.cpp @@ -67,307 +67,6 @@ EncNNFilterSet1::EncNNFilterSet1() EncNNFilterSet1::~EncNNFilterSet1() { } - -#if JVET_AB0068_RD -void EncNNFilterSet1::createEnc(const int maxCuWidth, const int maxCuHeight, const ChromaFormat format) -{ - if (m_tempRecBuf[0].getOrigin(0) == NULL) - { - m_tempRecBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); - } - if (!m_tempPredBuf[0].getOrigin(0)) - { - m_tempPredBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); - } - if (!m_tempSplitBuf[0].getOrigin(0)) - { - m_tempSplitBuf[0].create(format, Area(0, 0, maxCuWidth, maxCuHeight)); - } -} -void EncNNFilterSet1::destroyEnc() -{ - m_tempRecBuf[0].destroy(); - m_tempPredBuf[0].destroy(); - m_tempSplitBuf[0].destroy(); -} -void EncNNFilterSet1::initEnc(int numModels, std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma) -{ - m_NumModelsValid = numModels; - m_interLumaRd = interLuma; - m_interChromaRd = interChroma; - m_intraLumaRd = intraLuma; - m_intraChromaRd = intraChroma; -} - -template<typename T> -struct ModelData { - sadl::Model<T> model; - vector<sadl::Tensor<T>> inputs; - int hor,ver; - bool luma,inter; -}; - -template<typename T> -static std::vector<ModelData<T>> initSpace() { - std::vector<ModelData<T>> v; - v.reserve(5); - return v; -} - -#if NN_FIXED_POINT_IMPLEMENTATION -static std::vector<ModelData<int16_t>> modelsRd = initSpace<int16_t>(); -#else -static std::vector<ModelData<float>> modelsRd = initSpace<float>(); -#endif - -template<typename T> -static ModelData<T> &getModelRd(int ver, int hor, bool luma, bool inter, const std::string modelName) { - ModelData<T> *ptr = nullptr; - for(auto &m: modelsRd) - { - if (m.luma == luma && m.inter == inter) - { - ptr = &m; - break; - } - } - if (ptr == nullptr) - { - if (modelsRd.size() == modelsRd.capacity()) - { - std::cout << "[ERROR] RDO increase cache" << std::endl; - exit(-1); - } - modelsRd.resize(modelsRd.size()+1); - ptr = &modelsRd.back(); - ModelData<T> &m = *ptr; - ifstream file(modelName, ios::binary); - m.model.load(file); - m.luma = luma; - m.inter = inter; - m.ver = 0; - m.hor = 0; - } - ModelData<T> &m = *ptr; - if (m.ver != ver || m.hor != hor) - { - m.inputs = m.model.getInputsTemplate(); - if (luma) - { - for(auto &t: m.inputs) - { - sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); - t.resize(dims); - } - } - else if (inter) // inter chroma - { - int inputId = 0; - for(auto &t: m.inputs) - { - if (t.dims()[3] == 1 && inputId != 3) // luma - { - sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); - t.resize(dims); - } - else if (t.dims()[3] == 1 && inputId == 3) // qp - { - sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 1})); - t.resize(dims); - } - else - { - sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 2})); - t.resize(dims); - } - inputId ++; - } - } - else // intra chroma - { - int inputId = 0; - for(auto &t: m.inputs) - { - if (t.dims()[3] == 1 && inputId != 4 ) // luma - { - sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1})); - t.resize(dims); - } - else if (t.dims()[3] == 1 && inputId == 4) // QP - { - sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 1})); - t.resize(dims); - } - else - { - sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 2})); - t.resize(dims); - } - inputId ++; - } - } - if (!m.model.init(m.inputs)) - { - cerr << "[ERROR] RDO issue during initialization" << endl; - exit(-1); - } - m.ver = ver; - m.hor = hor; - } - return m; -} - -template<typename T> -void EncNNFilterSet1::prepareInputsLumaRd (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int qp, bool inter) -{ - double inputScale = 1024; -#if NN_FIXED_POINT_IMPLEMENTATION - int shiftInput = NN_INPUT_PRECISION; -#else - int shiftInput = 0; -#endif - - CompArea tempLumaArea = inferArea.Y(); - tempLumaArea.repositionTo(Position(0, 0)); - - PelBuf bufRec = m_tempRecBuf[0].getBuf(tempLumaArea); - PelBuf bufPred = m_tempPredBuf[0].getBuf(tempLumaArea); - PelBuf bufPartition = m_tempSplitBuf[0].getBuf(tempLumaArea); - - sadl::Tensor<T>* inputRec, *inputPred, *inputPartition, *inputQp; - if (inter) - { - inputRec = &inputs[0]; - inputPred = &inputs[1]; - inputQp = &inputs[2]; - } - else - { - inputRec = &inputs[0]; - inputPred = &inputs[1]; - inputPartition = &inputs[2]; - inputQp = &inputs[3]; - } - - int hor = inferArea.lwidth(); - int ver = inferArea.lheight(); - - for (int yy = 0; yy < ver; yy++) - { - for (int xx = 0; xx < hor; xx++) - { - (*inputRec)(0, yy, xx, 0) = (T)(bufRec.at(xx, yy) / inputScale * (1 << shiftInput)); - (*inputPred)(0, yy, xx, 0) = (T)(bufPred.at(xx, yy) / inputScale * (1 << shiftInput)); - (*inputQp)(0, yy, xx, 0) = (T)(qp / 64.0 * (1 << shiftInput)); - if (!inter) - { - (*inputPartition)(0, yy, xx, 0) = (T)(bufPartition.at(xx, yy) / inputScale * (1 << shiftInput)); - } - } - } -} - -template<typename T> -void EncNNFilterSet1::extractOutputsLumaRd (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom) -{ -#if NN_FIXED_POINT_IMPLEMENTATION - int log2InputScale = 10; - int log2OutputScale = 10; - int shiftInput = NN_OUTPUT_PRECISION - log2InputScale; - int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale; - int offset = (1 << shiftOutput) / 2; -#else - double inputScale = 1024; - double outputScale = 1024; -#endif - auto output = m.result(0); - PelBuf bufDst = tempBuf.getBuf(inferArea).get(COMPONENT_Y); - - int hor = inferArea.lwidth(); - int ver = inferArea.lheight(); - -#if 1 - CompArea tempLumaArea = inferArea.Y(); - tempLumaArea.repositionTo(Position(0, 0)); - - PelBuf bufRec = m_tempRecBuf[0].getBuf(tempLumaArea); -#else -#if FUSE_NN_AND_LF - PelBuf bufRec = pic->getUnfilteredRecBuf(inferArea).get(COMPONENT_Y); -#else - PelBuf bufRec = pic->getRecoBuf(inferArea).get(COMPONENT_Y); -#endif -#endif - - for (int c = 0; c < 4; c++) // output includes 4 sub images - { - for (int y = 0; y < ver >> 1; y++) - { - for (int x = 0; x < hor >> 1; x++) - { - int yy = (y << 1) + c / 2; - int xx = (x << 1) + c % 2; - if (xx < extLeft || yy < extTop || xx >= hor - extRight || yy >= ver - extBottom) - { - continue; - } -#if NN_FIXED_POINT_IMPLEMENTATION - int out = ( output(0, y, x, c) + (bufRec.at(xx, yy) << shiftInput) + offset) >> shiftOutput; -#else - int out = ( output(0, y, x, c) + bufRec.at(xx, yy) / inputScale ) * outputScale + 0.5; -#endif - bufDst.at(xx, yy) = Pel(Clip3<int>( 0, 1023, out)); - } - } - } -} - -template<typename T> void EncNNFilterSet1::cnnFilterLumaBlockRd(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter) -{ - //at::init_num_threads(); // use all available threads - - const int border_to_skip = 0; - if (border_to_skip>0) sadl::Tensor<float>::skip_border = true; - - // get model - ModelData<T> &m = getModelRd<T>(inferArea.lheight(), inferArea.lwidth(), true, inter, inter ? m_interLumaRd : m_intraLumaRd); - sadl::Model<T> &model = m.model; - - // get inputs - vector<sadl::Tensor<T>> &inputs = m.inputs; - - int seqQp = pic->slices[0]->getPPS()->getPicInitQPMinus26() + 26; - int sliceQp = pic->slices[0]->getSliceQp(); - int delta = inter ? paramIdx * 5 : paramIdx * 2; - if ( pic->slices[0]->getTLayer() >= 4 && paramIdx >= 2 ) - { - delta = 5 - delta; - } - int qp = inter ? seqQp - delta : sliceQp - delta; - - prepareInputsLumaRd<T>(pic, inferArea, inputs, qp, inter); - - // inference - if (!model.apply(inputs)) - { - cerr << "[ERROR] RDO issue during luma model inference" << endl; - exit(-1); - } - - // get outputs - extractOutputsLumaRd<T>(pic, model, m_tempBuf[paramIdx], inferArea, extLeft, extRight, extTop, extBottom); -} - -void EncNNFilterSet1::cnnFilterLumaBlockRd_ext(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter) -{ -#if NN_FIXED_POINT_IMPLEMENTATION - cnnFilterLumaBlockRd<int16_t>(pic, inferArea, extLeft, extRight, extTop, extBottom, paramIdx, inter); -#else - cnnFilterLumaBlockRd<float>(pic, inferArea, extLeft, extRight, extTop, extBottom, paramIdx, inter); -#endif -} -#endif - void EncNNFilterSet1::initCABACEstimator( CABACEncoder* cabacEncoder, CtxCache* ctxCache, Slice* pcSlice ) { m_CABACEstimator = cabacEncoder->getCABACEstimator( pcSlice->getSPS() ); diff --git a/source/Lib/EncoderLib/EncNNFilterSet1.h b/source/Lib/EncoderLib/EncNNFilterSet1.h index 68c5d56c7472e7a51d5a138e7db0e043264bfb93..da873d5ba0df628168550936c0abd7aedf0fdb80 100644 --- a/source/Lib/EncoderLib/EncNNFilterSet1.h +++ b/source/Lib/EncoderLib/EncNNFilterSet1.h @@ -78,34 +78,6 @@ public: #else void calcRDCost(Picture *pic,std::vector<PelStorage>& tempBuf, int numParams, double* minCost); #endif - -#if JVET_AB0068_RD - int m_NumModelsValid; - - std::string m_interLumaRd; - std::string m_interChromaRd; - std::string m_intraLumaRd; - std::string m_intraChromaRd; - - PelStorage m_tempRecBuf[MAX_NUM_COMPONENT]; - PelStorage m_tempPredBuf[MAX_NUM_COMPONENT]; - PelStorage m_tempSplitBuf[MAX_NUM_COMPONENT]; - void createEnc(const int maxCuWidth, const int maxCuHeight, const ChromaFormat format); - void destroyEnc(); - void initEnc(int numModels, std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma); - -private: - template<typename T> - void prepareInputsLumaRd(Picture* pic, UnitArea inferArea, std::vector<sadl::Tensor<T>>& inputs, int qp, bool inter); - template<typename T> - void extractOutputsLumaRd(Picture* pic, sadl::Model<T>& m, PelStorage& tempBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom); - template<typename T> - void cnnFilterLumaBlockRd(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter); - -public: - void cnnFilterLumaBlockRd_ext(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter); -#endif - }; //! \}