From b031efb109e9f657acea5e1d43531b88ae8968f2 Mon Sep 17 00:00:00 2001 From: Liqiang Wang <liqiangwang@tencent.com> Date: Mon, 15 Aug 2022 17:06:54 +0800 Subject: [PATCH] For JVET-AA0088, apply the common inference API, refactor the code and add a SPS flag --- .../NnlfSet0_model_float.sadl | Bin .../NnlfSet0_model_int16.sadl | Bin source/App/DecoderApp/DecApp.cpp | 6 +- source/App/DecoderApp/DecAppCfg.cpp | 8 +- source/App/DecoderApp/DecAppCfg.h | 5 +- source/App/EncoderApp/EncApp.cpp | 8 +- source/App/EncoderApp/EncAppCfg.cpp | 12 +- source/App/EncoderApp/EncAppCfg.h | 9 +- source/Lib/CommonLib/CMakeLists.txt | 7 +- source/Lib/CommonLib/CodingStatistics.h | 2 +- source/Lib/CommonLib/CodingStructure.h | 4 - source/Lib/CommonLib/CommonDef.h | 2 +- source/Lib/CommonLib/Contexts.cpp | 2 +- source/Lib/CommonLib/Contexts.h | 2 +- .../{CnnLoopFilter.cpp => NNFilterSet0.cpp} | 308 +++++++++--------- .../{CnnLoopFilter.h => NNFilterSet0.h} | 24 +- source/Lib/CommonLib/Picture.h | 14 +- source/Lib/CommonLib/Rom.cpp | 5 + source/Lib/CommonLib/Rom.h | 4 + source/Lib/CommonLib/Slice.h | 12 +- source/Lib/CommonLib/TypeDef.h | 45 ++- source/Lib/DecoderLib/CABACReader.cpp | 6 +- source/Lib/DecoderLib/CABACReader.h | 2 +- source/Lib/DecoderLib/DecLib.cpp | 41 +-- source/Lib/DecoderLib/DecLib.h | 16 +- source/Lib/DecoderLib/DecSlice.cpp | 3 +- source/Lib/DecoderLib/VLCReader.cpp | 45 +-- source/Lib/DecoderLib/VLCReader.h | 3 +- source/Lib/EncoderLib/CABACWriter.cpp | 64 ++-- source/Lib/EncoderLib/CABACWriter.h | 4 +- source/Lib/EncoderLib/EncCfg.h | 18 +- source/Lib/EncoderLib/EncGOP.cpp | 32 +- source/Lib/EncoderLib/EncGOP.h | 13 +- source/Lib/EncoderLib/EncLib.cpp | 17 +- source/Lib/EncoderLib/EncLib.h | 12 +- ...cCnnLoopFilter.cpp => EncNNFilterSet0.cpp} | 58 ++-- .../{EncCnnLoopFilter.h => EncNNFilterSet0.h} | 15 +- source/Lib/EncoderLib/VLCWriter.cpp | 54 +-- source/Lib/EncoderLib/VLCWriter.h | 3 +- 39 files changed, 438 insertions(+), 447 deletions(-) rename 
CNNLF/IB_model.sadl => models/NnlfSet0_model_float.sadl (100%) rename CNNLF/IB_model_int16.sadl => models/NnlfSet0_model_int16.sadl (100%) rename source/Lib/CommonLib/{CnnLoopFilter.cpp => NNFilterSet0.cpp} (57%) rename source/Lib/CommonLib/{CnnLoopFilter.h => NNFilterSet0.h} (86%) rename source/Lib/EncoderLib/{EncCnnLoopFilter.cpp => EncNNFilterSet0.cpp} (84%) rename source/Lib/EncoderLib/{EncCnnLoopFilter.h => EncNNFilterSet0.h} (93%) diff --git a/CNNLF/IB_model.sadl b/models/NnlfSet0_model_float.sadl similarity index 100% rename from CNNLF/IB_model.sadl rename to models/NnlfSet0_model_float.sadl diff --git a/CNNLF/IB_model_int16.sadl b/models/NnlfSet0_model_int16.sadl similarity index 100% rename from CNNLF/IB_model_int16.sadl rename to models/NnlfSet0_model_int16.sadl diff --git a/source/App/DecoderApp/DecApp.cpp b/source/App/DecoderApp/DecApp.cpp index 905836ee27..96de4a463b 100644 --- a/source/App/DecoderApp/DecApp.cpp +++ b/source/App/DecoderApp/DecApp.cpp @@ -575,7 +575,11 @@ void DecApp::xCreateDecLib() #endif ); m_cDecLib.setDecodedPictureHashSEIEnabled(m_decodedPictureHashSEIEnabled); - + +#if NN_FILTERING_SET_0 + m_cDecLib.setModelPath(m_ModelPath); +#endif + #if NN_FILTERING_SET_1 m_cDecLib.setNnlfSet1InterLumaModelName (m_nnlfSet1InterLumaModelName); m_cDecLib.setNnlfSet1InterChromaModelName (m_nnlfSet1InterChromaModelName); diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp index d1c4c1192f..58d4f9a6e4 100644 --- a/source/App/DecoderApp/DecAppCfg.cpp +++ b/source/App/DecoderApp/DecAppCfg.cpp @@ -80,14 +80,18 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) #if NNVC_DUMP_DATA ("DumpBasename", m_dumpBasename, string(""), "basename for data dumping\n") #endif - + +#if NN_FILTERING_SET_0 + ("ModelPath,-mp", m_ModelPath, default_model_path, "model path\n") +#endif + #if NN_FILTERING_SET_1 ( "NnlfSet1InterLumaModel", m_nnlfSet1InterLumaModelName, string("models/NnlfSet1_LumaCNNFilter_InterSlice_int16.sadl"), 
"NnlfSet1 inter luma model name") ( "NnlfSet1InterChromaModel", m_nnlfSet1InterChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter chroma model name") ( "NnlfSet1IntraLumaModel", m_nnlfSet1IntraLumaModelName, string("models/NnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra luma model name") ( "NnlfSet1IntraChromaModel", m_nnlfSet1IntraChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra chroma model name") #endif - + ("OplFile,-opl", m_oplFilename , string(""), "opl-file name without extension for conformance testing\n") #if ENABLE_SIMD_OPT diff --git a/source/App/DecoderApp/DecAppCfg.h b/source/App/DecoderApp/DecAppCfg.h index 0058ade751..2315e946ec 100644 --- a/source/App/DecoderApp/DecAppCfg.h +++ b/source/App/DecoderApp/DecAppCfg.h @@ -58,7 +58,10 @@ class DecAppCfg protected: std::string m_bitstreamFileName; ///< input bitstream file name std::string m_reconFileName; ///< output reconstruction file name - +#if NN_FILTERING_SET_0 + std::string m_ModelPath; ///< model path +#endif + #if NNVC_DUMP_DATA std::string m_dumpBasename; ///< output basename for data #endif diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp index c4c7f58870..2ddd26bf53 100644 --- a/source/App/EncoderApp/EncApp.cpp +++ b/source/App/EncoderApp/EncApp.cpp @@ -530,6 +530,10 @@ void EncApp::xInitLibCfg() m_cEncLib.setNoChromaQpOffsetConstraintFlag(false); } +#if NN_FILTERING_SET_0 + m_cEncLib.setModelPath(m_ModelPath); +#endif + //====== Coding Structure ======== m_cEncLib.setIntraPeriod ( m_iIntraPeriod ); m_cEncLib.setDecodingRefreshType ( m_iDecodingRefreshType ); @@ -1054,8 +1058,8 @@ void EncApp::xInitLibCfg() m_cEncLib.setForceSingleSplitThread ( m_forceSplitSequential ); #endif -#if NN_FILTER - m_cEncLib.setUseCNNLF (m_cnnlf); +#if NN_FILTERING_SET_0 + m_cEncLib.setUseNnlfSet0 (m_nnlfSet0); #endif m_cEncLib.setUseALF ( m_alf ); diff --git 
a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp index b4ff413bb0..e046a21094 100644 --- a/source/App/EncoderApp/EncAppCfg.cpp +++ b/source/App/EncoderApp/EncAppCfg.cpp @@ -715,6 +715,9 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) opts.addOptions() ("help", do_help, false, "this help text") ("c", po::parseConfigFile, "configuration file name") +#if NN_FILTERING_SET_0 + ("ModelPath,-mp", m_ModelPath, default_model_path, "model path\n") +#endif ("WarnUnknowParameter,w", warnUnknowParameter, 0, "warn for unknown configuration parameters instead of failing") ("isSDR", sdr, false, "compatibility") #if ENABLE_SIMD_OPT @@ -1462,10 +1465,9 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) opts.addOptions() -#if NN_FILTER - ("CNNLF", m_cnnlf, true, "Cnn Loop Filter\n") +#if NN_FILTERING_SET_0 + ("NnlfSet0", m_nnlfSet0, true, "NN-based loop filter set 0\n") #endif - ("TemporalFilter", m_gopBasedTemporalFilterEnabled, false, "Enable GOP based temporal filter. Disabled per default") ("TemporalFilterFutureReference", m_gopBasedTemporalFilterFutureReference, true, "Enable referencing of future frames in the GOP based temporal filter. This is typically disabled for Low Delay configurations.") ("TemporalFilterStrengthFrame*", m_gopBasedTemporalFilterStrengths, std::map<int, double>(), "Strength for every * frame in GOP based temporal filter, where * is an integer." @@ -3992,8 +3994,8 @@ void EncAppCfg::xPrintParameter() msg( VERBOSE, "Slices: %d ", m_numSlicesInPic); msg( VERBOSE, "MCTS:%d ", m_MCTSEncConstraint ); -#if NN_FILTER - msg(VERBOSE, "CNNLF:%d", m_cnnlf ? 1 : 0); +#if NN_FILTERING_SET_0 + msg(VERBOSE, "NnlfSet0:%d ", m_nnlfSet0 ? 
1 : 0); #endif msg( VERBOSE, "NNLFSET1:%d ", (m_nnlfSet1)?(1):(0)); diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h index 55a47a0461..2b445c7cd3 100644 --- a/source/App/EncoderApp/EncAppCfg.h +++ b/source/App/EncoderApp/EncAppCfg.h @@ -85,6 +85,9 @@ protected: std::string m_inputFileName; ///< source file name std::string m_bitstreamFileName; ///< output bitstream file std::string m_reconFileName; ///< output reconstruction file +#if NN_FILTERING_SET_0 + std::string m_ModelPath; ///< model path +#endif #if NN_FILTERING_SET_1 std::string m_nnlfSet1InterLumaModelName; ///<inter luma nnlf set1 model @@ -92,7 +95,7 @@ protected: std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlf set1 model std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlf set1 model #endif - + // Lambda modifiers double m_adLambdaModifier[ MAX_TLAYER ]; ///< Lambda modifier array for each temporal layer std::vector<double> m_adIntraLambdaModifier; ///< Lambda modifier for Intra pictures, one for each temporal layer. If size>temporalLayer, then use [temporalLayer], else if size>0, use [size()-1], else use m_adLambdaModifier. 
@@ -700,8 +703,8 @@ protected: bool m_bs2ModPOCAndType; bool m_forceDecodeBitstream1; -#if NN_FILTER - bool m_cnnlf; ///< CNN Loop Filter +#if NN_FILTERING_SET_0 + bool m_nnlfSet0; ///< CNN Loop Filter #endif bool m_alf; ///< Adaptive Loop Filter diff --git a/source/Lib/CommonLib/CMakeLists.txt b/source/Lib/CommonLib/CMakeLists.txt index 8e11521739..f736b6e5e2 100644 --- a/source/Lib/CommonLib/CMakeLists.txt +++ b/source/Lib/CommonLib/CMakeLists.txt @@ -106,15 +106,14 @@ elseif( UNIX OR MINGW ) set_property( SOURCE ${AVX2_SRC_FILES} APPEND PROPERTY COMPILE_FLAGS "-mavx2" ) endif() -if( MSVC ) - set_property( SOURCE CnnLoopFilter.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ") +if( MSVC ) + set_property( SOURCE NNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ") set_property( SOURCE NNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ") elseif( UNIX OR MINGW ) - set_property( SOURCE CnnLoopFilter.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw") + set_property( SOURCE NNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw") set_property( SOURCE NNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw") endif() - # example: place header files in different folders source_group( "Natvis Files" FILES ${NATVIS_FILES} ) diff --git a/source/Lib/CommonLib/CodingStatistics.h b/source/Lib/CommonLib/CodingStatistics.h index ddef1b8764..1b926b8828 100644 --- a/source/Lib/CommonLib/CodingStatistics.h +++ b/source/Lib/CommonLib/CodingStatistics.h @@ -97,7 +97,7 @@ enum CodingStatisticsType #endif STATS__CABAC_BITS__SAO, -#if NN_FILTER +#if NN_FILTERING_SET_0 STATS__CABAC_BITS__CNNLF, #endif diff --git a/source/Lib/CommonLib/CodingStructure.h b/source/Lib/CommonLib/CodingStructure.h index 1a2ddc9e66..8114fb225c 100644 --- a/source/Lib/CommonLib/CodingStructure.h +++ b/source/Lib/CommonLib/CodingStructure.h @@ -293,10 +293,6 @@ public: const CPelUnitBuf 
getPredBufCustom(const UnitArea &unit) const; #endif -#if DATA_PREDICTION - PelUnitBuf getPredBufCustom() { return m_predCustom; } -#endif - PelBuf getResiBuf(const CompArea &blk); const CPelBuf getResiBuf(const CompArea &blk) const; PelUnitBuf getResiBuf(const UnitArea &unit); diff --git a/source/Lib/CommonLib/CommonDef.h b/source/Lib/CommonLib/CommonDef.h index bdffde89b6..ad7713270f 100644 --- a/source/Lib/CommonLib/CommonDef.h +++ b/source/Lib/CommonLib/CommonDef.h @@ -180,7 +180,7 @@ static const int MAXIMUM_INTRA_FILTERED_HEIGHT = 16; static const int MIP_MAX_WIDTH = MAX_TB_SIZEY; static const int MIP_MAX_HEIGHT = MAX_TB_SIZEY; -#if NN_FILTER +#if NN_FILTERING_SET_0 static const int MAX_NUM_CNN = 4; #endif diff --git a/source/Lib/CommonLib/Contexts.cpp b/source/Lib/CommonLib/Contexts.cpp index 72a6d6d1bf..35b5e67815 100644 --- a/source/Lib/CommonLib/Contexts.cpp +++ b/source/Lib/CommonLib/Contexts.cpp @@ -784,7 +784,7 @@ const CtxSet ContextSetCfg::ctbAlfFlag = ContextSetCfg::addCtxSet { 0, 0, 0, 4, 0, 0, 1, 0, 0, }, }); -#if NN_FILTER +#if NN_FILTERING_SET_0 const CtxSet ContextSetCfg::ctbCnnlfFlag = { ContextSetCfg::addCtxSet diff --git a/source/Lib/CommonLib/Contexts.h b/source/Lib/CommonLib/Contexts.h index 1a02107c6b..40a67e0002 100644 --- a/source/Lib/CommonLib/Contexts.h +++ b/source/Lib/CommonLib/Contexts.h @@ -257,7 +257,7 @@ public: static const CtxSet ChromaQpAdjIdc; static const CtxSet ImvFlag; static const CtxSet BcwIdx; -#if NN_FILTER +#if NN_FILTERING_SET_0 static const CtxSet ctbCnnlfFlag; #endif static const CtxSet ctbAlfFlag; diff --git a/source/Lib/CommonLib/CnnLoopFilter.cpp b/source/Lib/CommonLib/NNFilterSet0.cpp similarity index 57% rename from source/Lib/CommonLib/CnnLoopFilter.cpp rename to source/Lib/CommonLib/NNFilterSet0.cpp index 828e39164d..6b70ae2da5 100644 --- a/source/Lib/CommonLib/CnnLoopFilter.cpp +++ b/source/Lib/CommonLib/NNFilterSet0.cpp @@ -31,19 +31,18 @@ * THE POSSIBILITY OF SUCH DAMAGE. 
*/ -/** \file CnnLoopFilter.cpp +/** \file NNFilterSet0.cpp \brief cnn loop filter class */ -#include "CnnLoopFilter.h" +#include "NNFilterSet0.h" -#if NN_FILTER +#if NN_FILTERING_SET_0 #include "CodingStructure.h" #include "Picture.h" -CnnLoopFilter::CnnLoopFilter() +NNFilterSet0::NNFilterSet0() { - for( int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++ ) { m_ctuEnableFlag[compIdx] = nullptr; @@ -51,10 +50,9 @@ CnnLoopFilter::CnnLoopFilter() m_initFlag = false; } -void CnnLoopFilter::PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam) +void NNFilterSet0::PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam) { - int baseQP = cs.slice->getPPS()->getPicInitQPMinus26() + 26; - initCnnModel(baseQP, cs.slice->getSliceQp()); + initCnnModel(); if (!cnnlfSliceParam.enabledFlag[COMPONENT_Y] && !cnnlfSliceParam.enabledFlag[COMPONENT_Cb] && !cnnlfSliceParam.enabledFlag[COMPONENT_Cr]) { @@ -71,11 +69,8 @@ void CnnLoopFilter::PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSlic m_tempCnnBuf[0].copyFrom(recYuv); PelUnitBuf cnnYuv = m_tempCnnBuf[0].getBuf(cs.area); - PelUnitBuf predYuv = cs.getPredBufCustom(); - PelUnitBuf partitionYuv = cs.picture->getCuAverageBuf(); - // run DRNLF - runCNNLF(pic, recYuv, predYuv, partitionYuv, cnnYuv, cs.slice->getSliceType(), baseQP, cs.slice->getSliceQp(), true); + runCNNLF(pic, cnnYuv, 0, true); for (int i = 1; i < MAX_NUM_CNN; i++) { @@ -83,7 +78,7 @@ void CnnLoopFilter::PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSlic } } -void CnnLoopFilter::CNNLFProcess( CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam ) +void NNFilterSet0::CNNLFProcess( CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam ) { if( !cnnlfSliceParam.enabledFlag[COMPONENT_Y] && !cnnlfSliceParam.enabledFlag[COMPONENT_Cb] && !cnnlfSliceParam.enabledFlag[COMPONENT_Cr] ) { @@ -139,7 +134,7 @@ void CnnLoopFilter::CNNLFProcess( CodingStructure& cs, CnnlfSliceParam& cnnlfSli 
filterPic(cs, cnnlfSliceParam); } -void CnnLoopFilter::create( const int picWidth, const int picHeight, const ChromaFormat format, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE]) +void NNFilterSet0::create( const int picWidth, const int picHeight, const ChromaFormat format, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE], std::string path) { std::memcpy( m_inputBitDepth, inputBitDepth, sizeof( m_inputBitDepth ) ); m_picWidth = picWidth; @@ -148,6 +143,7 @@ void CnnLoopFilter::create( const int picWidth, const int picHeight, const Chrom m_maxCUHeight = maxCUHeight; m_maxCUDepth = maxCUDepth; m_chromaFormat = format; + m_ModelPath = path; m_numCTUsInWidth = ( m_picWidth / m_maxCUWidth ) + ( ( m_picWidth % m_maxCUWidth ) ? 1 : 0 ); m_numCTUsInHeight = ( m_picHeight / m_maxCUHeight ) + ( ( m_picHeight % m_maxCUHeight ) ? 1 : 0 ); @@ -160,7 +156,7 @@ void CnnLoopFilter::create( const int picWidth, const int picHeight, const Chrom } } -void CnnLoopFilter::destroy() +void NNFilterSet0::destroy() { for (int i = 0; i < MAX_NUM_CNN; i++) { @@ -169,7 +165,7 @@ void CnnLoopFilter::destroy() } -void CnnLoopFilter::filterPic(CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam) +void NNFilterSet0::filterPic(CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam) { PelUnitBuf recYuv = cs.getRecoBuf(); PelUnitBuf cnnYuv[MAX_NUM_CHANNEL_TYPE]; @@ -219,42 +215,39 @@ void CnnLoopFilter::filterPic(CodingStructure& cs, CnnlfSliceParam& cnnlfSlicePa } } -void CnnLoopFilter::filterBlk( PelUnitBuf &recUnitBuf, const CPelUnitBuf& cnnUnitBuf, const Area& blk, const ComponentID compId, const ClpRng& clpRng ) +void NNFilterSet0::filterBlk( PelUnitBuf &recUnitBuf, const CPelUnitBuf& cnnUnitBuf, const Area& blk, const ComponentID compId, const ClpRng& clpRng ) { const CPelBuf cnnBlk = cnnUnitBuf.get(compId).subBuf(blk.pos(), blk.size()); PelBuf recBlk = 
recUnitBuf.get(compId).subBuf(blk.pos(), blk.size()); recBlk.copyFrom(cnnBlk); } -void CnnLoopFilter::runCNNLF(Picture* pic, const PelUnitBuf& recUnitBuf, const PelUnitBuf& predUnitBuf, const PelUnitBuf& partitionUnitBuf, PelUnitBuf& cnnUnitBuf, const SliceType slice_type, const int baseQP, const int iQP, bool is_dec) +bool NNFilterSet0::skipPatch(int x, int y, int ctuRsAddr, bool is_dec) +{ + if (is_dec) + { + if (!m_ctuEnableFlag[COMPONENT_Y][ctuRsAddr] && !m_ctuEnableFlag[COMPONENT_Cb][ctuRsAddr] && !m_ctuEnableFlag[COMPONENT_Cr][ctuRsAddr]) + { + return true; + } + } + return false; +} + +void NNFilterSet0::runCNNLF(Picture *pic, PelUnitBuf &cnnUnitBuf, const int baseQPoffset, bool is_dec) { + CodingStructure &cs = *pic->cs; + const int slice_type = cs.slice->getSliceType() != I_SLICE ? 1023 : 0; + const int baseQP = cs.slice->getPPS()->getPicInitQPMinus26() + 26; + const int sliceQp = cs.slice->getSliceQp(); + const int baseQPFinal = baseQP + baseQPoffset; int patchSize = 128; int padSize = 8; -#if NN_SCALE - double out_maxValue = (1023 << NN_SCALE_EXT_SHIFT); -#if SADL_INT16 - constexpr int org_quantizers_in = 10; - constexpr int sadl_quantizers_in = 11; - constexpr int in_shift = sadl_quantizers_in - org_quantizers_in; - constexpr int org_quantizers_out = 10 + NN_SCALE_EXT_SHIFT; - constexpr int sadl_quantizers_out = 11; - constexpr int out_shift = org_quantizers_out - sadl_quantizers_out; -#else - double maxValue = 1023; -#endif -#else - double out_maxValue = 1023; -#if SADL_INT16 - constexpr int org_quantizers_in = 10; - constexpr int sadl_quantizers_in = 11; - constexpr int in_shift = sadl_quantizers_in - org_quantizers_in; - constexpr int org_quantizers_out = 10; - constexpr int sadl_quantizers_out = 11; - constexpr int out_shift = org_quantizers_out - sadl_quantizers_out; -#else - double maxValue = 1023; -#endif -#endif + + double in_maxValue = 1023; + const int org_quantizers_in = 10; + const int sadl_quantizers_in = 11; + const int in_left_shift = 
sadl_quantizers_in - org_quantizers_in; int picWidth = pic->getPicWidthInLumaSamples(); int picHeight = pic->getPicHeightInLumaSamples(); @@ -262,154 +255,138 @@ void CnnLoopFilter::runCNNLF(Picture* pic, const PelUnitBuf& recUnitBuf, const P int picWidthInPatchs = ceil((double)picWidth / patchSize); int picHeightInPatchs = ceil((double)picHeight / patchSize); - int ctuRsAddr = 0; for (int y = 0; y < picHeightInPatchs; y++) { for (int x = 0; x < picWidthInPatchs; x++) { - if (is_dec) + if (skipPatch(x, y, ctuRsAddr, is_dec)) { - if (!m_ctuEnableFlag[COMPONENT_Y][ctuRsAddr] && !m_ctuEnableFlag[COMPONENT_Cb][ctuRsAddr] && !m_ctuEnableFlag[COMPONENT_Cr][ctuRsAddr]) - { - ctuRsAddr++; - continue; - } + ctuRsAddr++; + continue; } + // patch area int pix_y = y * patchSize; int pix_x = x * patchSize; - int pix_y_end = (y+1) * patchSize > picHeight ? picHeight - 1 : (y + 1) * patchSize - 1; int pix_x_end = (x+1) * patchSize > picWidth ? picWidth - 1 : (x + 1) * patchSize - 1; + // patch area with the extension int st_h = pix_y - padSize < 0 ? 0 : pix_y - padSize; int ed_h = pix_y_end + padSize + 1 > picHeight ? picHeight - 1 : pix_y_end + padSize; int st_w = pix_x - padSize < 0 ? 0 : pix_x - padSize; int ed_w = pix_x_end + padSize + 1 > picWidth ? 
picWidth - 1 : pix_x_end + padSize; + // patch size int actualPatchSizeH = ed_h - st_h + 1; int actualPatchSizeW = ed_w - st_w + 1; - m_Input = m_Module.getInputsTemplate(); - unsigned int m_Input_id = 0; - for (auto& t : m_Input) - { - if (m_Input_id < 2) - { - t.resize(std::initializer_list<int>({ 1, actualPatchSizeH, actualPatchSizeW, 3 })); - } - else - { - t.resize(std::initializer_list<int>({ 1, actualPatchSizeH, actualPatchSizeW, 1 })); - } - m_Input_id++; - } - if (!m_Module.init(m_Input)) { - cerr << "[ERROR] issue during initialization" << endl; - exit(-1); - } + // preparation and inference + std::vector<InputData> listInputData; + InputData inputRec = { NN_INPUT_REC, 0, in_maxValue, in_left_shift, true, true }; + InputData inputPred = { NN_INPUT_PRED, 1, in_maxValue, in_left_shift, true, true }; + InputData inputSliceQp = { NN_INPUT_LOCAL_QP, 2, in_maxValue, in_left_shift, true, true }; + InputData inputBaseQp = { NN_INPUT_GLOBAL_QP, 3, in_maxValue, in_left_shift, true, true }; + InputData inputSliceType = { NN_INPUT_SLICE_TYPE, 4, in_maxValue, in_left_shift, true, true }; + listInputData.push_back(inputRec); + listInputData.push_back(inputPred); + listInputData.push_back(inputSliceQp); + listInputData.push_back(inputBaseQp); + listInputData.push_back(inputSliceType); + const UnitArea inferBlock(cs.area.chromaFormat, Area(st_w, st_h, actualPatchSizeW, actualPatchSizeH)); + initPatch(actualPatchSizeW, actualPatchSizeH); + NNInference::prepareInputs<TypeSadl>(pic, inferBlock, m_Input, baseQPFinal, sliceQp, slice_type, listInputData); + NNInference::infer<TypeSadl>(m_Module, m_Input); + + // extract the results + extractOutputs(pic, pix_x, pix_y, pix_x_end, pix_y_end, st_w, st_h, cnnUnitBuf); - for (int yy = 0; yy < actualPatchSizeH; yy++) - { - for (int xx = 0; xx < actualPatchSizeW; xx++) - { - int id_x[2], id_y[2]; - id_x[0] = st_w + xx; - id_y[0] = st_h + yy; - id_x[1] = (st_w >> 1) + (xx >> 1); - id_y[1] = (st_h >> 1) + (yy >> 1); - -#if SADL_INT16 - 
m_Input[0](0, yy, xx, COMPONENT_Y) = recUnitBuf.get(COMPONENT_Y).at(id_x[0], id_y[0]) << in_shift; - m_Input[0](0, yy, xx, COMPONENT_Cb) = recUnitBuf.get(COMPONENT_Cb).at(id_x[1], id_y[1]) << in_shift; - m_Input[0](0, yy, xx, COMPONENT_Cr) = recUnitBuf.get(COMPONENT_Cr).at(id_x[1], id_y[1]) << in_shift; - - m_Input[1](0, yy, xx, COMPONENT_Y) = predUnitBuf.get(COMPONENT_Y).at(id_x[0], id_y[0]) << in_shift; - m_Input[1](0, yy, xx, COMPONENT_Cb) = predUnitBuf.get(COMPONENT_Cb).at(id_x[1], id_y[1]) << in_shift; - m_Input[1](0, yy, xx, COMPONENT_Cr) = predUnitBuf.get(COMPONENT_Cr).at(id_x[1], id_y[1]) << in_shift; - - m_Input[2](0, yy, xx, 0) = iQP << in_shift; - m_Input[3](0, yy, xx, 0) = baseQP << in_shift; - m_Input[4](0, yy, xx, 0) = (slice_type == I_SLICE ? 0 : 1023) << in_shift; -#else - m_Input[0](0, yy, xx, COMPONENT_Y) = recUnitBuf.get(COMPONENT_Y).at(id_x[0], id_y[0]) / maxValue; - m_Input[0](0, yy, xx, COMPONENT_Cb) = recUnitBuf.get(COMPONENT_Cb).at(id_x[1], id_y[1]) / maxValue; - m_Input[0](0, yy, xx, COMPONENT_Cr) = recUnitBuf.get(COMPONENT_Cr).at(id_x[1], id_y[1]) / maxValue; - - m_Input[1](0, yy, xx, COMPONENT_Y) = predUnitBuf.get(COMPONENT_Y).at(id_x[0], id_y[0]) / maxValue; - m_Input[1](0, yy, xx, COMPONENT_Cb) = predUnitBuf.get(COMPONENT_Cb).at(id_x[1], id_y[1]) / maxValue; - m_Input[1](0, yy, xx, COMPONENT_Cr) = predUnitBuf.get(COMPONENT_Cr).at(id_x[1], id_y[1]) / maxValue; + ctuRsAddr++; + } + } +} - m_Input[2](0, yy, xx, 0) = iQP / maxValue; - m_Input[3](0, yy, xx, 0) = baseQP / maxValue; - m_Input[4](0, yy, xx, 0) = (slice_type == I_SLICE ? 
0 : maxValue) / maxValue; +void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_end, int pix_y_end, int st_w, int st_h, PelUnitBuf &cnnUnitBuf) +{ + CodingStructure &cs = *pic->cs; + PelUnitBuf recUnitBuf = cs.getRecoBuf(); +#if NN_SCALE + int nn_scale_shift = NN_SCALE_EXT_SHIFT; +#else + int nn_scale_shift = 0; #endif - } - } - if (!m_Module.apply(m_Input)) { - cerr << "[ERROR] issue during inference" << endl; - exit(-1); - } +#if !NN_FIXED_POINT_IMPLEMENTATION + double in_maxValue = 1023; +#endif + int real_maxValue = 1023; + int out_maxValue = real_maxValue << nn_scale_shift; + +#if NN_FIXED_POINT_IMPLEMENTATION + const int org_quantizers_in = 10; + const int org_quantizers_out = 10 + nn_scale_shift; + const int sadl_quantizers_in = 11; + const int sadl_quantizers_out = 11; + const int in_left_shift = sadl_quantizers_in - org_quantizers_in; + const int out_left_shift = org_quantizers_out - sadl_quantizers_out; +#endif - int centerH = pix_y_end - pix_y + 1; - int centerW = pix_x_end - pix_x + 1; - for (int kk = 0; kk < 4; kk++) + //extract luma + int centerH = pix_y_end - pix_y + 1; + int centerW = pix_x_end - pix_x + 1; + for (int kk = 0; kk < 4; kk++) + { + for (int yy = 0; yy < (centerH >> 1); yy++) + { + for (int xx = 0; xx < (centerW >> 1); xx++) { - for (int yy = 0; yy < (centerH >> 1); yy++) - { - for (int xx = 0; xx < (centerW >> 1); xx++) - { - int id_x = pix_x + (xx << 1) + kk % 2; - int id_y = pix_y + (yy << 1) + kk / 2; + int id_x = pix_x + (xx << 1) + kk % 2; + int id_y = pix_y + (yy << 1) + kk / 2; -#if SADL_INT16 - cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w) >> 1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) << in_shift)) << out_shift)); +#if NN_FIXED_POINT_IMPLEMENTATION + cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w) >> 1, kk) + 
(recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) << in_left_shift)) << out_left_shift)); #else - cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w)>>1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) / maxValue)) * out_maxValue + 0.5))); + cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w)>>1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) / in_maxValue)) * out_maxValue + 0.5))); #endif - } - } } + } + } - int centerH_chroma = centerH >> 1; - int centerW_chroma = centerW >> 1; - for (int compID = COMPONENT_Cb; compID < MAX_NUM_COMPONENT; compID++) + // extract chroma + int centerH_chroma = centerH >> 1; + int centerW_chroma = centerW >> 1; + for (int compID = COMPONENT_Cb; compID < MAX_NUM_COMPONENT; compID++) + { + const ComponentID comp = ComponentID(compID); + for (int yy = 0; yy < centerH_chroma; yy++) + { + for (int xx = 0; xx < centerW_chroma; xx++) { - const ComponentID comp = ComponentID(compID); - for (int yy = 0; yy < centerH_chroma; yy++) + int id_x = (pix_x >> 1) + xx; + int id_y = (pix_y >> 1) + yy; + + float sample = 0; + for (int kk = 0; kk < 4; kk++) { - for (int xx = 0; xx < centerW_chroma; xx++) - { - int id_x = (pix_x >> 1) + xx; - int id_y = (pix_y >> 1) + yy; - - float sample = 0; - for (int kk = 0; kk < 4; kk++) - { - int pos_x = ((pix_x - st_w) >> 1) + xx; - int pos_y = ((pix_y - st_h) >> 1) + yy; - sample = sample + m_Module.result(0)(0, pos_y, pos_x, compID*4 + kk); - } - sample = sample / 4; - -#if SADL_INT16 - cnnUnitBuf.get(comp).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(sample + (recUnitBuf.get(comp).at(id_x, id_y) << in_shift)) << out_shift)); + int pos_x = ((pix_x - st_w) >> 1) + xx; + int pos_y = ((pix_y - st_h) >> 1) + yy; + sample = sample + m_Module.result(0)(0, pos_y, pos_x, compID*4 + kk); + } + sample = sample / 4; + +#if 
NN_FIXED_POINT_IMPLEMENTATION + cnnUnitBuf.get(comp).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(sample + (recUnitBuf.get(comp).at(id_x, id_y) << in_left_shift)) << out_left_shift)); #else - cnnUnitBuf.get(comp).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((sample + (recUnitBuf.get(comp).at(id_x, id_y) / maxValue)) * out_maxValue + 0.5))); + cnnUnitBuf.get(comp).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((sample + (recUnitBuf.get(comp).at(id_x, id_y) / in_maxValue)) * out_maxValue + 0.5))); #endif - } - } } - - ctuRsAddr++; } } } #if NN_SCALE -void CnnLoopFilter::scaleResidue(CodingStructure& cs, PelUnitBuf recUnitBuf, PelUnitBuf cnnYuv, int *scale_list, bool is_dec) +void NNFilterSet0::scaleResidue(CodingStructure& cs, PelUnitBuf recUnitBuf, PelUnitBuf cnnYuv, int *scale_list, bool is_dec) { const PreCalcValues& pcv = *cs.pcv; for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) @@ -459,19 +436,42 @@ void CnnLoopFilter::scaleResidue(CodingStructure& cs, PelUnitBuf recUnitBuf, Pel } #endif -void CnnLoopFilter::initCnnModel(const int baseQP, const int iQP) +void NNFilterSet0::initPatch(const int PatchWidth, const int PatchHeight) +{ + m_Input = m_Module.getInputsTemplate(); + unsigned int m_Input_id = 0; + for (auto &t: m_Input) + { + if (m_Input_id < 2) + { + t.resize(std::initializer_list<int>({ 1, PatchHeight, PatchWidth, 3 })); + } + else + { + t.resize(std::initializer_list<int>({ 1, PatchHeight, PatchWidth, 1 })); + } + m_Input_id++; + } + if (!m_Module.init(m_Input)) + { + cerr << "[ERROR] issue during initialization" << endl; + exit(-1); + } +} + +void NNFilterSet0::initCnnModel() { if (m_initFlag) { return; } - std::string rootPath = "./CNNLF/"; + std::string rootPath = m_ModelPath; std::string ModelPath; -#if SADL_INT16 - ModelPath = rootPath + "IB_model_int16.sadl"; +#if NN_FIXED_POINT_IMPLEMENTATION + ModelPath = rootPath + "NnlfSet0_model_int16.sadl"; #else - ModelPath = rootPath + "IB_model.sadl"; + ModelPath = 
rootPath + "NnlfSet0_model_float.sadl"; #endif ifstream file(ModelPath, ios::binary); diff --git a/source/Lib/CommonLib/CnnLoopFilter.h b/source/Lib/CommonLib/NNFilterSet0.h similarity index 86% rename from source/Lib/CommonLib/CnnLoopFilter.h rename to source/Lib/CommonLib/NNFilterSet0.h index b9569d0791..b39d1450fc 100644 --- a/source/Lib/CommonLib/CnnLoopFilter.h +++ b/source/Lib/CommonLib/NNFilterSet0.h @@ -31,7 +31,7 @@ * THE POSSIBILITY OF SUCH DAMAGE. */ -/** \file CnnLoopFilter.h +/** \file NNFilterSet0.h \brief cnn loop filter class (header) */ @@ -40,22 +40,23 @@ #include "CommonDef.h" -#if NN_FILTER +#if NN_FILTERING_SET_0 #include "Unit.h" #include "Picture.h" #define HAVE_INTTYPES_H 1 #define __STDC_FORMAT_MACROS #include <sadl/model.h> +#include "NNInference.h" #include <fstream> using namespace std; -class CnnLoopFilter +class NNFilterSet0 { public: - CnnLoopFilter(); - virtual ~CnnLoopFilter() {} - void create(const int picWidth, const int picHeight, const ChromaFormat format, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE]); + NNFilterSet0(); + virtual ~NNFilterSet0() {} + void create(const int picWidth, const int picHeight, const ChromaFormat format, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE], std::string path); void destroy(); void PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam); void CNNLFProcess(CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam); @@ -63,6 +64,7 @@ public: protected: uint8_t* m_ctuEnableFlag[MAX_NUM_COMPONENT]; PelStorage m_tempCnnBuf[MAX_NUM_CNN]; + int m_inputBitDepth[MAX_NUM_CHANNEL_TYPE]; int m_picWidth; int m_picHeight; @@ -76,14 +78,20 @@ protected: ClpRngs m_clpRngs; bool m_initFlag; + std::string m_ModelPath; sadl::Model<TypeSadl> m_Module; vector<sadl::Tensor<TypeSadl>> m_Input; void filterPic(CodingStructure& cs, CnnlfSliceParam& 
cnnlfSliceParam); void filterBlk(PelUnitBuf &recDst, const CPelUnitBuf& recSrc, const Area& blk, const ComponentID compI, const ClpRng& clpRng); - void initCnnModel(const int baseQP, const int iQP); - void runCNNLF(Picture* pic, const PelUnitBuf& recUnitBuf, const PelUnitBuf& predUnitBuf, const PelUnitBuf& partitionUnitBuf, PelUnitBuf& cnnUnitBuf, const SliceType slice_type, const int baseQP, const int iQP, bool is_dec); + void initCnnModel(); + void initPatch(const int PatchWidth, const int PatchHeight); + + bool skipPatch(int x, int y, int ctuRsAddr, bool is_dec); + void runCNNLF(Picture* pic, PelUnitBuf& cnnUnitBuf, const int baseQPoffset, bool is_dec); + void extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_end, int pix_y_end, int st_w, int st_h, PelUnitBuf &cnnUnitBuf); + #if NN_SCALE void scaleResidue(CodingStructure& cs, PelUnitBuf recUnitBuf, PelUnitBuf cnnYuv, int *scale_list, bool is_dec); #endif diff --git a/source/Lib/CommonLib/Picture.h b/source/Lib/CommonLib/Picture.h index db07c59d02..eb934e6f8b 100644 --- a/source/Lib/CommonLib/Picture.h +++ b/source/Lib/CommonLib/Picture.h @@ -371,7 +371,7 @@ public: } } #endif - + std::vector<uint8_t> m_alfCtuEnableFlag[MAX_NUM_COMPONENT]; uint8_t* getAlfCtuEnableFlag( int compIdx ) { return m_alfCtuEnableFlag[compIdx].data(); } std::vector<uint8_t>* getAlfCtuEnableFlag() { return m_alfCtuEnableFlag; } @@ -406,17 +406,17 @@ public: } } -#if NN_FILTER +#if NN_FILTERING_SET_0 std::vector<uint8_t> m_cnnlfCtuEnableFlag[MAX_NUM_COMPONENT]; uint8_t* getCnnlfCtuEnableFlag(int compIdx) { return m_cnnlfCtuEnableFlag[compIdx].data(); } std::vector<uint8_t>* getCnnlfCtuEnableFlag() { return m_cnnlfCtuEnableFlag; } void resizeCnnlfCtuEnableFlag(int numEntries) { - for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) - { - m_cnnlfCtuEnableFlag[compIdx].resize(numEntries); - std::fill(m_cnnlfCtuEnableFlag[compIdx].begin(), m_cnnlfCtuEnableFlag[compIdx].end(), 0); - } + for (int compIdx = 0; compIdx < 
MAX_NUM_COMPONENT; compIdx++) + { + m_cnnlfCtuEnableFlag[compIdx].resize(numEntries); + std::fill(m_cnnlfCtuEnableFlag[compIdx].begin(), m_cnnlfCtuEnableFlag[compIdx].end(), 0); + } } #endif diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp index dc1c29aede..43b2a4a6fb 100644 --- a/source/Lib/CommonLib/Rom.cpp +++ b/source/Lib/CommonLib/Rom.cpp @@ -44,6 +44,11 @@ #include <math.h> #include <iomanip> + +#if NN_FILTERING_SET_0 +std::string default_model_path = "./models/"; +#endif + // ==================================================================================================================== // Initialize / destroy functions // ==================================================================================================================== diff --git a/source/Lib/CommonLib/Rom.h b/source/Lib/CommonLib/Rom.h index e7352e3c10..d56ab592d1 100644 --- a/source/Lib/CommonLib/Rom.h +++ b/source/Lib/CommonLib/Rom.h @@ -48,6 +48,10 @@ //! \ingroup CommonLib //! \{ +#if NN_FILTERING_SET_0 +extern std::string default_model_path; +#endif + // ==================================================================================================================== // Initialize / destroy functions // ==================================================================================================================== diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h index a4727aa55e..49e7efa514 100644 --- a/source/Lib/CommonLib/Slice.h +++ b/source/Lib/CommonLib/Slice.h @@ -1476,7 +1476,7 @@ private: static const int m_winUnitY[NUM_CHROMA_FORMAT]; ProfileTierLevel m_profileTierLevel; -#if NN_FILTER +#if NN_FILTERING_SET_0 bool m_cnnlfEnabledFlag; #endif @@ -1741,9 +1741,9 @@ public: void setSAOEnabledFlag(bool bVal) { m_saoEnabledFlag = bVal; } bool getSAOEnabledFlag() const { return m_saoEnabledFlag; } -#if NN_FILTER - bool getCNNLFEnabledFlag() const { return m_cnnlfEnabledFlag; } - void setCNNLFEnabledFlag(bool b) { 
m_cnnlfEnabledFlag = b; } +#if NN_FILTERING_SET_0 + bool getNnlfSet0EnabledFlag() const { return m_cnnlfEnabledFlag; } + void setNnlfSet0EnabledFlag(bool b) { m_cnnlfEnabledFlag = b; } #endif bool getALFEnabledFlag() const { return m_alfEnabledFlag; } @@ -2746,7 +2746,7 @@ private: bool m_disableSATDForRd; bool m_isLossless; -#if NN_FILTER +#if NN_FILTERING_SET_0 CnnlfSliceParam m_cnnlfSliceParam; #if NN_SCALE int nn_scale[MAX_NUM_COMPONENT]; @@ -3032,7 +3032,7 @@ public: void resetProcessingTime() { m_dProcessingTime = m_iProcessingStartTime = 0; } double getProcessingTime() const { return m_dProcessingTime; } -#if NN_FILTER +#if NN_FILTERING_SET_0 void setCnnlfSliceParam(CnnlfSliceParam& cnnlfSliceParam) { m_cnnlfSliceParam = cnnlfSliceParam; } CnnlfSliceParam& getCnnlfSliceParam() { return m_cnnlfSliceParam; } #if NN_SCALE diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 05b3967926..b478ce7ed5 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -67,20 +67,17 @@ #define NNVC_USE_PARTITION_AS_CU_AVERAGE 1 // average on the CU #define NNVC_USE_QP 1 // QP slice #define NNVC_USE_SLICETYPE 1 // slice type + +// nn filter set 0 +#define NN_FILTERING_SET_0 1 +#if NN_FILTERING_SET_0 -#define NN_FILTER 1 -#if NN_FILTER -#define SADL_INT16 1 - -#if SADL_INT16 +#if NN_FIXED_POINT_IMPLEMENTATION using TypeSadl = int16_t; #else using TypeSadl = float; #endif -#define DATA_PREDICTION 1 -#define DATA_PARTITION 1 - #define NN_SCALE 1 #if NN_SCALE #define NN_SCALE_SHIFT 8 @@ -89,7 +86,7 @@ using TypeSadl = float; #define NN_SCALE_UP_FACTOR 1.25 #define NN_SCALE_BOT_FACTOR 0.0625 #endif -#endif +#endif //end of filter set 0 // nn filter set 1 #define NN_FILTERING_SET_1 1 @@ -1416,28 +1413,26 @@ struct XUCache #define SIGN(x) ( (x) >= 0 ? 
1 : -1 ) -#if NN_FILTER +#if NN_FILTERING_SET_0 struct CnnlfSliceParam { - bool enabledFlag[MAX_NUM_COMPONENT]; - char frameCtrlFlag[2]; + bool enabledFlag[MAX_NUM_COMPONENT]; + char frameCtrlFlag[2]; - void reset() - { - std::memset(enabledFlag, false, sizeof(enabledFlag)); - std::memset(frameCtrlFlag, 0, sizeof(frameCtrlFlag)); - - } + void reset() + { + std::memset(enabledFlag, false, sizeof(enabledFlag)); + std::memset(frameCtrlFlag, 0, sizeof(frameCtrlFlag)); + } - const CnnlfSliceParam& operator = (const CnnlfSliceParam& src) - { - std::memcpy(enabledFlag, src.enabledFlag, sizeof(enabledFlag)); - std::memcpy(frameCtrlFlag, src.frameCtrlFlag, sizeof(frameCtrlFlag)); - return *this; - } + const CnnlfSliceParam& operator = (const CnnlfSliceParam& src) + { + std::memcpy(enabledFlag, src.enabledFlag, sizeof(enabledFlag)); + std::memcpy(frameCtrlFlag, src.frameCtrlFlag, sizeof(frameCtrlFlag)); + return *this; + } }; #endif - //! \} #endif diff --git a/source/Lib/DecoderLib/CABACReader.cpp b/source/Lib/DecoderLib/CABACReader.cpp index 64cd75b792..826813df83 100644 --- a/source/Lib/DecoderLib/CABACReader.cpp +++ b/source/Lib/DecoderLib/CABACReader.cpp @@ -143,7 +143,7 @@ void CABACReader::coding_tree_unit( CodingStructure& cs, const UnitArea& area, i cs.modeType = partitioner.modeType = MODE_TYPE_ALL; -#if NN_FILTER +#if NN_FILTERING_SET_0 for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) { readCnnlfCtuEnableFlag(cs, ctuRsAddr, compIdx); @@ -296,11 +296,11 @@ void CABACReader::coding_tree_unit( CodingStructure& cs, const UnitArea& area, i } -#if NN_FILTER +#if NN_FILTERING_SET_0 void CABACReader::readCnnlfCtuEnableFlag(CodingStructure& cs, uint32_t ctuRsAddr, const int compIdx) { CnnlfSliceParam& cnnlfSliceParam = cs.slice->getCnnlfSliceParam(); - if (cs.sps->getCNNLFEnabledFlag() && cnnlfSliceParam.enabledFlag[compIdx]) + if (cs.sps->getNnlfSet0EnabledFlag() && cnnlfSliceParam.enabledFlag[compIdx]) { const PreCalcValues& pcv = *cs.pcv; int 
frame_width_in_ctus = pcv.widthInCtus; diff --git a/source/Lib/DecoderLib/CABACReader.h b/source/Lib/DecoderLib/CABACReader.h index 20292d359e..ab27eba411 100644 --- a/source/Lib/DecoderLib/CABACReader.h +++ b/source/Lib/DecoderLib/CABACReader.h @@ -65,7 +65,7 @@ public: // coding tree unit (clause 7.3.8.2) void coding_tree_unit ( CodingStructure& cs, const UnitArea& area, int (&qps)[2], unsigned ctuRsAddr ); -#if NN_FILTER +#if NN_FILTERING_SET_0 void readCnnlfCtuEnableFlag(CodingStructure& cs, uint32_t ctuRsAddr, const int compIdx); #endif diff --git a/source/Lib/DecoderLib/DecLib.cpp b/source/Lib/DecoderLib/DecLib.cpp index dd8f566bb6..04923354e2 100644 --- a/source/Lib/DecoderLib/DecLib.cpp +++ b/source/Lib/DecoderLib/DecLib.cpp @@ -192,18 +192,18 @@ bool tryDecodePicture( Picture* pcEncPic, const int expectedPoc, const std::stri } else { -#if NN_FILTER - if (pic->cs->sps->getCNNLFEnabledFlag()) +#if NN_FILTERING_SET_0 + if (pic->cs->sps->getNnlfSet0EnabledFlag()) + { + for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) + { + std::copy(pic->getCnnlfCtuEnableFlag()[compIdx].begin(), pic->getCnnlfCtuEnableFlag()[compIdx].end(), pcEncPic->getCnnlfCtuEnableFlag()[compIdx].begin()); + } + for (int i = 0; i < pic->slices.size(); i++) { - for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) - { - std::copy(pic->getCnnlfCtuEnableFlag()[compIdx].begin(), pic->getCnnlfCtuEnableFlag()[compIdx].end(), pcEncPic->getCnnlfCtuEnableFlag()[compIdx].begin()); - } - for (int i = 0; i < pic->slices.size(); i++) - { - pcEncPic->slices[i]->getCnnlfSliceParam() = pic->slices[i]->getCnnlfSliceParam(); - } + pcEncPic->slices[i]->getCnnlfSliceParam() = pic->slices[i]->getCnnlfSliceParam(); } + } #endif if ( pic->cs->sps->getSAOEnabledFlag() ) { @@ -559,7 +559,7 @@ void DecLib::deletePicBuffer ( ) } m_cALF.destroy(); m_cSAO.destroy(); -#if NN_FILTER +#if NN_FILTERING_SET_0 m_cCNNLF.destroy(); #endif m_cLoopFilter.destroy(); @@ -691,6 +691,7 @@ void 
DecLib::executeLoopFilters() m_cReshaper.setRecReshaped(false); m_cSAO.setReshaper(&m_cReshaper); } + #if NNVC_USE_BS m_pcPic->getBsMapBuf().fill(0); #endif @@ -698,8 +699,8 @@ void DecLib::executeLoopFilters() m_pcPic->getRecBeforeDbfBuf().copyFrom(m_pcPic->getRecoBuf()); #endif -#if NN_FILTER - if (cs.sps->getCNNLFEnabledFlag()) +#if NN_FILTERING_SET_0 + if (cs.sps->getNnlfSet0EnabledFlag()) { m_cCNNLF.PreCNNLFProcess(m_pcPic, cs, cs.slice->getCnnlfSliceParam()); } @@ -731,9 +732,9 @@ void DecLib::executeLoopFilters() { m_cSAO.SAOProcess( cs, cs.picture->getSAO() ); } - -#if NN_FILTER - if (cs.sps->getCNNLFEnabledFlag()) + +#if NN_FILTERING_SET_0 + if (cs.sps->getNnlfSet0EnabledFlag()) { m_cCNNLF.CNNLFProcess(cs, cs.slice->getCnnlfSliceParam()); } @@ -1719,11 +1720,11 @@ void DecLib::xActivateParameterSets( const InputNALUnit nalu ) pSlice->m_ccAlfFilterControl[0] = m_cALF.getCcAlfControlIdc(COMPONENT_Cb); pSlice->m_ccAlfFilterControl[1] = m_cALF.getCcAlfControlIdc(COMPONENT_Cr); -#if NN_FILTER - if (sps->getCNNLFEnabledFlag()) +#if NN_FILTERING_SET_0 + if (sps->getNnlfSet0EnabledFlag()) { - const int maxDepth = floorLog2(sps->getMaxCUWidth()) - sps->getLog2MinCodingBlockSize(); - m_cCNNLF.create(pps->getPicWidthInLumaSamples(), pps->getPicHeightInLumaSamples(), sps->getChromaFormatIdc(), sps->getMaxCUWidth(), sps->getMaxCUHeight(), maxDepth, sps->getBitDepths().recon); + const int maxDepth = floorLog2(sps->getMaxCUWidth()) - sps->getLog2MinCodingBlockSize(); + m_cCNNLF.create(pps->getPicWidthInLumaSamples(), pps->getPicHeightInLumaSamples(), sps->getChromaFormatIdc(), sps->getMaxCUWidth(), sps->getMaxCUHeight(), maxDepth, sps->getBitDepths().recon, getModelPath()); } #endif diff --git a/source/Lib/DecoderLib/DecLib.h b/source/Lib/DecoderLib/DecLib.h index a56593120f..64a897a863 100644 --- a/source/Lib/DecoderLib/DecLib.h +++ b/source/Lib/DecoderLib/DecLib.h @@ -54,9 +54,11 @@ #include "CommonLib/SEI.h" #include "CommonLib/Unit.h" #include "CommonLib/Reshape.h" 
-#if NN_FILTER -#include "CommonLib/CnnLoopFilter.h" + +#if NN_FILTERING_SET_0 +#include "CommonLib/NNFilterSet0.h" #endif + #if NN_FILTERING_SET_1 #include "CommonLib/NNFilterSet1.h" #endif @@ -125,8 +127,9 @@ private: SampleAdaptiveOffset m_cSAO; AdaptiveLoopFilter m_cALF; -#if NN_FILTER - CnnLoopFilter m_cCNNLF; +#if NN_FILTERING_SET_0 + NNFilterSet0 m_cCNNLF; + std::string m_ModelPath; ///< model path #endif Reshape m_cReshaper; ///< reshaper class @@ -261,6 +264,11 @@ public: void updatePrevGDRInSameLayer(); void updatePrevIRAPAndGDRSubpic(); +#if NN_FILTERING_SET_0 + void setModelPath(std::string path) { m_ModelPath = path; } + std::string getModelPath() { return m_ModelPath; } +#endif + #if JVET_S0078_NOOUTPUTPRIORPICFLAG bool getAudIrapOrGdrAuFlag() const { return m_audIrapOrGdrAuFlag; } #endif diff --git a/source/Lib/DecoderLib/DecSlice.cpp b/source/Lib/DecoderLib/DecSlice.cpp index 5c27f43d73..2eab4cbdf6 100644 --- a/source/Lib/DecoderLib/DecSlice.cpp +++ b/source/Lib/DecoderLib/DecSlice.cpp @@ -92,7 +92,7 @@ void DecSlice::decompressSlice( Slice* slice, InputBitstream* bitstream, int deb cs.pcv = slice->getPPS()->pcv; cs.chromaQpAdj = 0; -#if NN_FILTER +#if NN_FILTERING_SET_0 cs.picture->resizeCnnlfCtuEnableFlag(cs.pcv->sizeInCtus); #endif @@ -110,7 +110,6 @@ void DecSlice::decompressSlice( Slice* slice, InputBitstream* bitstream, int deb #if NN_FILTERING_SET_1 cs.picture->resizeNnlfSet1ParamIdx( cs.pcv->sizeInNnlfSet1InferSize ); #endif - const unsigned numSubstreams = slice->getNumberOfSubstreamSizes() + 1; // init each couple {EntropyDecoder, Substream} diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp index 7bbdf03cea..78d1aa1d73 100644 --- a/source/Lib/DecoderLib/VLCReader.cpp +++ b/source/Lib/DecoderLib/VLCReader.cpp @@ -48,8 +48,8 @@ #include "CommonLib/AdaptiveLoopFilter.h" #include "CommonLib/ProfileLevelTier.h" -#if NN_FILTER -#include "CommonLib/CnnLoopFilter.h" +#if NN_FILTERING_SET_0 +#include 
"CommonLib/NNFilterSet0.h" #endif #if ENABLE_TRACING @@ -1746,8 +1746,8 @@ void HLSyntaxReader::parseSPS(SPS* pcSPS) pcSPS->setCCALFEnabledFlag(false); } -#if NN_FILTER - pcSPS->setCNNLFEnabledFlag(true); // always enable CNNLF +#if NN_FILTERING_SET_0 + READ_FLAG(uiCode, "sps_nnlf_set0_enable_flag"); pcSPS->setNnlfSet0EnabledFlag(uiCode == 1); #endif READ_FLAG(uiCode, "sps_lmcs_enable_flag"); pcSPS->setUseLmcs(uiCode == 1); @@ -4183,8 +4183,8 @@ void HLSyntaxReader::parseSliceHeader (Slice* pcSlice, PicHeader* picHeader, Par pcSlice->setUseChromaQpAdj(false); } -#if NN_FILTER - if (sps->getCNNLFEnabledFlag()) +#if NN_FILTERING_SET_0 + if (sps->getNnlfSet0EnabledFlag()) { cnnlf(pcSlice->getCnnlfSliceParam()); #if NN_SCALE @@ -4192,28 +4192,27 @@ void HLSyntaxReader::parseSliceHeader (Slice* pcSlice, PicHeader* picHeader, Par && (pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_LUMA] == 1 || pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_LUMA] == 5) ) { - READ_SCODE(NN_SCALE_SHIFT + 1, iCode, "nnScale_Y"); + READ_SCODE(NN_SCALE_SHIFT + 1, iCode, "nn scale Y"); pcSlice->setNnScale(iCode + (1 << NN_SCALE_SHIFT), COMPONENT_Y); } if (pcSlice->getCnnlfSliceParam().enabledFlag[COMPONENT_Cb] && (pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 1 || pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 5)) { - READ_SCODE(NN_SCALE_SHIFT + 1, iCode, "nnScale_Cb"); + READ_SCODE(NN_SCALE_SHIFT + 1, iCode, "nn scale Cb"); pcSlice->setNnScale(iCode + (1 << NN_SCALE_SHIFT), COMPONENT_Cb); } if (pcSlice->getCnnlfSliceParam().enabledFlag[COMPONENT_Cr] && (pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 1 || pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 5)) { - READ_SCODE(NN_SCALE_SHIFT + 1, iCode, "nnScale_Cr"); + READ_SCODE(NN_SCALE_SHIFT + 1, iCode, "nn scale Cr"); pcSlice->setNnScale(iCode + (1 << NN_SCALE_SHIFT), COMPONENT_Cr); } #endif } #endif - #if NN_FILTERING_SET_1 if 
(sps->getNnlfSet1EnabledFlag()) { @@ -5244,34 +5243,15 @@ void HLSyntaxReader::alfFilter( AlfParam& alfParam, const bool isChroma, const i } } -#if NN_FILTER -int HLSyntaxReader::truncatedUnaryEqProb(const int maxSymbol) -{ - for (int k = 0; k < maxSymbol; k++) - { - uint32_t symbol; -#if RExt__DECODER_DEBUG_BIT_STATISTICS - xReadFlag(symbol, ""); -#else - xReadFlag(symbol); -#endif - - if (!symbol) - { - return k; - } - } - return maxSymbol; -} - +#if NN_FILTERING_SET_0 void HLSyntaxReader::cnnlf(CnnlfSliceParam& cnnlfSliceParam) { uint32_t code; const int inv_map_list[9] = { 1, 5, 0, 2, 6, 3, 7, 4, 8 }; for (int chId = 0; chId < MAX_NUM_CHANNEL_TYPE; chId++) { - READ_UVLC(code, ""); - CHECK(code < 0 || code > 8, ""); + READ_UVLC(code, "nn-filter mode"); + CHECK(code < 0 || code > 8, "Invalid nn-filter mode"); cnnlfSliceParam.frameCtrlFlag[chId] = inv_map_list[code]; if (chId == CHANNEL_TYPE_LUMA) @@ -5286,6 +5266,5 @@ void HLSyntaxReader::cnnlf(CnnlfSliceParam& cnnlfSliceParam) } } #endif - //! 
\} diff --git a/source/Lib/DecoderLib/VLCReader.h b/source/Lib/DecoderLib/VLCReader.h index eddab99d05..9e107e3380 100644 --- a/source/Lib/DecoderLib/VLCReader.h +++ b/source/Lib/DecoderLib/VLCReader.h @@ -188,8 +188,7 @@ public: void ccAlfFilter( Slice *pcSlice ); void dpb_parameters(int maxSubLayersMinus1, bool subLayerInfoFlag, SPS *pcSPS); private: -#if NN_FILTER - int truncatedUnaryEqProb(const int maxSymbol); +#if NN_FILTERING_SET_0 void cnnlf(CnnlfSliceParam& cnnlfSliceParam); #endif diff --git a/source/Lib/EncoderLib/CABACWriter.cpp b/source/Lib/EncoderLib/CABACWriter.cpp index 857af01561..29d923a874 100644 --- a/source/Lib/EncoderLib/CABACWriter.cpp +++ b/source/Lib/EncoderLib/CABACWriter.cpp @@ -162,7 +162,7 @@ void CABACWriter::coding_tree_unit( CodingStructure& cs, const UnitArea& area, i partitioner.initCtu(area, CH_L, *cs.slice); -#if NN_FILTER +#if NN_FILTERING_SET_0 if (!skipSao) { for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) @@ -3525,42 +3525,42 @@ void CABACWriter::codeAlfCtuAlternative( CodingStructure& cs, uint32_t ctuRsAddr } } -#if NN_FILTER +#if NN_FILTERING_SET_0 void CABACWriter::codeCnnlfCtuEnableFlag(CodingStructure& cs, uint32_t ctuRsAddr, const int compIdx, CnnlfSliceParam* cnnlfParam) { - CnnlfSliceParam& cnnlfSliceParam = cnnlfParam ? (*cnnlfParam) : cs.slice->getCnnlfSliceParam(); - if (cs.sps->getCNNLFEnabledFlag() && cnnlfSliceParam.enabledFlag[compIdx]) - { - const PreCalcValues& pcv = *cs.pcv; - int frame_width_in_ctus = pcv.widthInCtus; - int ry = ctuRsAddr / frame_width_in_ctus; - int rx = ctuRsAddr - ry * frame_width_in_ctus; - const Position pos(rx * cs.pcv->maxCUWidth, ry * cs.pcv->maxCUHeight); - const uint32_t curSliceIdx = cs.slice->getIndependentSliceIdx(); + CnnlfSliceParam& cnnlfSliceParam = cnnlfParam ? 
(*cnnlfParam) : cs.slice->getCnnlfSliceParam(); + if (cs.sps->getNnlfSet0EnabledFlag() && cnnlfSliceParam.enabledFlag[compIdx]) + { + const PreCalcValues& pcv = *cs.pcv; + int frame_width_in_ctus = pcv.widthInCtus; + int ry = ctuRsAddr / frame_width_in_ctus; + int rx = ctuRsAddr - ry * frame_width_in_ctus; + const Position pos(rx * cs.pcv->maxCUWidth, ry * cs.pcv->maxCUHeight); + const uint32_t curSliceIdx = cs.slice->getIndependentSliceIdx(); - const uint32_t curTileIdx = cs.pps->getTileIdx(pos); - bool leftAvail = cs.getCURestricted(pos.offset(-(int)pcv.maxCUWidth, 0), pos, curSliceIdx, curTileIdx, CH_L) ? true : false; - bool aboveAvail = cs.getCURestricted(pos.offset(0, -(int)pcv.maxCUHeight), pos, curSliceIdx, curTileIdx, CH_L) ? true : false; + const uint32_t curTileIdx = cs.pps->getTileIdx(pos); + bool leftAvail = cs.getCURestricted(pos.offset(-(int)pcv.maxCUWidth, 0), pos, curSliceIdx, curTileIdx, CH_L) ? true : false; + bool aboveAvail = cs.getCURestricted(pos.offset(0, -(int)pcv.maxCUHeight), pos, curSliceIdx, curTileIdx, CH_L) ? true : false; - int leftCTUAddr = leftAvail ? ctuRsAddr - 1 : -1; - int aboveCTUAddr = aboveAvail ? ctuRsAddr - frame_width_in_ctus : -1; - if (cnnlfSliceParam.enabledFlag[compIdx]) - { - uint8_t* ctbCnnlfFlag = cs.slice->getPic()->getCnnlfCtuEnableFlag(compIdx); - if (cnnlfSliceParam.frameCtrlFlag[toChannelType((ComponentID)compIdx)] >= 1 - && cnnlfSliceParam.frameCtrlFlag[toChannelType((ComponentID)compIdx)] <= 4) - { - CHECK(!ctbCnnlfFlag[ctuRsAddr], "CNNLF CTB enable flag must be 1 with frameCtrlFlag = 1"); - } - else - { - int ctx = 0; - ctx += leftCTUAddr > -1 ? (ctbCnnlfFlag[leftCTUAddr] ? 1 : 0) : 0; - ctx += aboveCTUAddr > -1 ? (ctbCnnlfFlag[aboveCTUAddr] ? 1 : 0) : 0; - m_BinEncoder.encodeBin(ctbCnnlfFlag[ctuRsAddr], Ctx::ctbCnnlfFlag(compIdx * 3 + ctx)); - } - } + int leftCTUAddr = leftAvail ? ctuRsAddr - 1 : -1; + int aboveCTUAddr = aboveAvail ? 
ctuRsAddr - frame_width_in_ctus : -1; + if (cnnlfSliceParam.enabledFlag[compIdx]) + { + uint8_t* ctbCnnlfFlag = cs.slice->getPic()->getCnnlfCtuEnableFlag(compIdx); + if (cnnlfSliceParam.frameCtrlFlag[toChannelType((ComponentID)compIdx)] >= 1 + && cnnlfSliceParam.frameCtrlFlag[toChannelType((ComponentID)compIdx)] <= 4) + { + CHECK(!ctbCnnlfFlag[ctuRsAddr], "CNNLF CTB enable flag must be 1 with frameCtrlFlag = 1"); + } + else + { + int ctx = 0; + ctx += leftCTUAddr > -1 ? (ctbCnnlfFlag[leftCTUAddr] ? 1 : 0) : 0; + ctx += aboveCTUAddr > -1 ? (ctbCnnlfFlag[aboveCTUAddr] ? 1 : 0) : 0; + m_BinEncoder.encodeBin(ctbCnnlfFlag[ctuRsAddr], Ctx::ctbCnnlfFlag(compIdx * 3 + ctx)); + } } + } } #endif diff --git a/source/Lib/EncoderLib/CABACWriter.h b/source/Lib/EncoderLib/CABACWriter.h index 1fe7569245..e4182ed12a 100644 --- a/source/Lib/EncoderLib/CABACWriter.h +++ b/source/Lib/EncoderLib/CABACWriter.h @@ -164,7 +164,7 @@ public: void codeCcAlfFilterControlIdc(uint8_t idcVal, CodingStructure &cs, const ComponentID compID, const int curIdx, const uint8_t *filterControlIdc, Position lumaPos, const int filterCount); -#if NN_FILTER +#if NN_FILTERING_SET_0 void codeCnnlfCtuEnableFlag(CodingStructure& cs, uint32_t ctuRsAddr, const int compIdx, CnnlfSliceParam* cnnlfParam = NULL); #endif @@ -172,7 +172,7 @@ public: void codeNnlfSet1ParamIdx ( CodingStructure& cs, ChannelType chType); void codeNnlfSet1ParamIdx ( CodingStructure& cs, uint32_t ctuRsAddr, const int chal ); #endif - + private: void unary_max_symbol ( unsigned symbol, unsigned ctxId0, unsigned ctxIdN, unsigned maxSymbol ); void unary_max_eqprob ( unsigned symbol, unsigned maxSymbol ); diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h index f17224f012..9a213b40f6 100644 --- a/source/Lib/EncoderLib/EncCfg.h +++ b/source/Lib/EncoderLib/EncCfg.h @@ -171,6 +171,9 @@ protected: int m_iSourceHeight; Window m_conformanceWindow; int m_framesToBeEncoded; +#if NN_FILTERING_SET_0 + std::string m_ModelPath; 
///< model path +#endif double m_adLambdaModifier[ MAX_TLAYER ]; std::vector<double> m_adIntraLambdaModifier; double m_dIntraQpFactor; ///< Intra Q Factor. If negative, use a default equation: 0.57*(1.0 - Clip3( 0.0, 0.5, 0.05*(double)(isField ? (GopSize-1)/2 : GopSize-1) )) @@ -752,8 +755,8 @@ protected: bool m_ccalf; int m_ccalfQpThreshold; -#if NN_FILTER - bool m_cnnlf; ///< Cnn Lopp Filter +#if NN_FILTERING_SET_0 + bool m_nnlfSet0; ///< Cnn Loop Filter #endif #if NN_FILTERING_SET_1 @@ -1317,6 +1320,10 @@ public: int getSourceWidth () const { return m_iSourceWidth; } int getSourceHeight () const { return m_iSourceHeight; } int getFramesToBeEncoded () const { return m_framesToBeEncoded; } +#if NN_FILTERING_SET_0 + void setModelPath(std::string path) { m_ModelPath = path; } + std::string getModelPath() { return m_ModelPath; } +#endif //====== Lambda Modifiers ======== void setLambdaModifier ( uint32_t uiIndex, double dValue ) { m_adLambdaModifier[ uiIndex ] = dValue; } @@ -1968,10 +1975,9 @@ public: bool getUseCCALF() const { return m_ccalf; } void setCCALFQpThreshold( int b ) { m_ccalfQpThreshold = b; } int getCCALFQpThreshold() const { return m_ccalfQpThreshold; } - -#if NN_FILTER - void setUseCNNLF(bool b) { m_cnnlf = b; } - bool getUseCNNLF(bool b) const { return m_cnnlf; } + +#if NN_FILTERING_SET_0 + void setUseNnlfSet0(bool b) { m_nnlfSet0 = b; } #endif #if NN_FILTERING_SET_1 diff --git a/source/Lib/EncoderLib/EncGOP.cpp b/source/Lib/EncoderLib/EncGOP.cpp index 0773b931bc..9013e928b2 100644 --- a/source/Lib/EncoderLib/EncGOP.cpp +++ b/source/Lib/EncoderLib/EncGOP.cpp @@ -227,7 +227,7 @@ void EncGOP::init ( EncLib* pcEncLib ) m_pcSAO = pcEncLib->getSAO(); m_pcALF = pcEncLib->getALF(); -#if NN_FILTER +#if NN_FILTERING_SET_0 m_pcCNNLF = pcEncLib->getCNNLF(); #endif @@ -2751,11 +2751,11 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, #endif #endif -#if NN_FILTER - if (pcSlice->getSPS()->getCNNLFEnabledFlag()) +#if 
NN_FILTERING_SET_0 + if (pcSlice->getSPS()->getNnlfSet0EnabledFlag()) { - pcPic->resizeCnnlfCtuEnableFlag(numberOfCtusInFrame); - std::memset(pcSlice->getCnnlfSliceParam().enabledFlag, false, sizeof(pcSlice->getCnnlfSliceParam().enabledFlag)); + pcPic->resizeCnnlfCtuEnableFlag(numberOfCtusInFrame); + std::memset(pcSlice->getCnnlfSliceParam().enabledFlag, false, sizeof(pcSlice->getCnnlfSliceParam().enabledFlag)); } #endif @@ -2983,13 +2983,6 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, } } -#if NN_FILTER - if (pcSlice->getSPS()->getCNNLFEnabledFlag()) - { - m_pcCNNLF->PreCNNLFProcess(pcPic, cs); - } -#endif - // create SAO object based on the picture size if( pcSlice->getSPS()->getSAOEnabledFlag() ) { @@ -3053,7 +3046,14 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, #if NNVC_USE_REC_BEFORE_DBF pcPic->getRecBeforeDbfBuf().copyFrom(pcPic->getRecoBuf()); #endif - + +#if NN_FILTERING_SET_0 + if (pcSlice->getSPS()->getNnlfSet0EnabledFlag()) + { + m_pcCNNLF->PreCNNLFProcess(pcPic, cs); + } +#endif + m_pcLoopFilter->loopFilterPic( cs ); #if NNVC_USE_REC_AFTER_DBF @@ -3107,9 +3107,9 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, } } - -#if NN_FILTER - if (pcSlice->getSPS()->getCNNLFEnabledFlag()) + +#if NN_FILTERING_SET_0 + if (pcSlice->getSPS()->getNnlfSet0EnabledFlag()) { CnnlfSliceParam cnnlfSliceParam; m_pcCNNLF->initCABACEstimator(m_pcEncLib->getCABACEncoder(), m_pcEncLib->getCtxCache(), pcSlice); diff --git a/source/Lib/EncoderLib/EncGOP.h b/source/Lib/EncoderLib/EncGOP.h index d001690dcb..73ba9fef49 100644 --- a/source/Lib/EncoderLib/EncGOP.h +++ b/source/Lib/EncoderLib/EncGOP.h @@ -70,13 +70,14 @@ #include "HDRLib/inc/DistortionMetricDeltaE.H" #include <chrono> #endif -#if NN_FILTER -#include "EncCnnLoopFilter.h" + +#if NN_FILTERING_SET_0 +#include "EncNNFilterSet0.h" #endif + #if NN_FILTERING_SET_1 #include "EncoderLib/EncNNFilterSet1.h" #endif - //! 
\ingroup EncoderLib //! \{ @@ -141,7 +142,7 @@ private: EncCfg* m_pcCfg; EncSlice* m_pcSliceEncoder; PicList* m_pcListPic; - + #if NN_FILTERING_SET_1 EncNNFilterSet1* m_pcNNFilterSet1; #endif @@ -162,8 +163,8 @@ private: EncSampleAdaptiveOffset* m_pcSAO; EncAdaptiveLoopFilter* m_pcALF; -#if NN_FILTER - EncCnnLoopFilter* m_pcCNNLF; +#if NN_FILTERING_SET_0 + EncNNFilterSet0* m_pcCNNLF; #endif EncReshape* m_pcReshaper; diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp index 430cc02e37..fbe41de786 100644 --- a/source/Lib/EncoderLib/EncLib.cpp +++ b/source/Lib/EncoderLib/EncLib.cpp @@ -135,10 +135,10 @@ void EncLib::create( const int layerId ) m_cLoopFilter.initEncPicYuvBuffer(m_chromaFormatIDC, Size(getSourceWidth(), getSourceHeight()), getMaxCUWidth()); } -#if NN_FILTER - if (m_cnnlf) +#if NN_FILTERING_SET_0 + if (m_nnlfSet0) { - m_cEncCNNLF.create(getSourceWidth(), getSourceHeight(), m_chromaFormatIDC, m_maxCUWidth, m_maxCUHeight, floorLog2(m_maxCUWidth) - MIN_CU_LOG2, m_bitDepth, m_inputBitDepth); + m_cEncCNNLF.create(getSourceWidth(), getSourceHeight(), m_chromaFormatIDC, m_maxCUWidth, m_maxCUHeight, floorLog2(m_maxCUWidth) - MIN_CU_LOG2, m_bitDepth, m_inputBitDepth, getModelPath()); } #endif @@ -178,10 +178,10 @@ void EncLib::destroy () m_cCuEncoder. 
destroy(); #endif -#if NN_FILTER - if (m_cnnlf) +#if NN_FILTERING_SET_0 + if (m_nnlfSet0) { - m_cEncCNNLF.destroy(); + m_cEncCNNLF.destroy(); } #endif @@ -1426,10 +1426,9 @@ void EncLib::xInitSPS( SPS& sps ) sps.setALFEnabledFlag( m_alf ); sps.setCCALFEnabledFlag( m_ccalf ); -#if NN_FILTER - sps.setCNNLFEnabledFlag(m_cnnlf); +#if NN_FILTERING_SET_0 + sps.setNnlfSet0EnabledFlag(m_nnlfSet0); #endif - sps.setFieldSeqFlag(false); sps.setVuiParametersPresentFlag(getVuiParametersPresentFlag()); diff --git a/source/Lib/EncoderLib/EncLib.h b/source/Lib/EncoderLib/EncLib.h index f2e9a924fa..180fa9f2bd 100644 --- a/source/Lib/EncoderLib/EncLib.h +++ b/source/Lib/EncoderLib/EncLib.h @@ -58,8 +58,8 @@ #include "EncAdaptiveLoopFilter.h" #include "RateCtrl.h" -#if NN_FILTER -#include "EncCnnLoopFilter.h" +#if NN_FILTERING_SET_0 +#include "EncNNFilterSet0.h" #endif class EncLibCommon; @@ -100,8 +100,8 @@ private: EncSampleAdaptiveOffset m_cEncSAO; ///< sample adaptive offset class EncAdaptiveLoopFilter m_cEncALF; -#if NN_FILTER - EncCnnLoopFilter m_cEncCNNLF; +#if NN_FILTERING_SET_0 + EncNNFilterSet0 m_cEncCNNLF; #endif HLSWriter m_HLSWriter; ///< CAVLC encoder @@ -221,8 +221,8 @@ public: EncSampleAdaptiveOffset* getSAO () { return &m_cEncSAO; } EncAdaptiveLoopFilter* getALF () { return &m_cEncALF; } -#if NN_FILTER - EncCnnLoopFilter* getCNNLF () { return &m_cEncCNNLF; } +#if NN_FILTERING_SET_0 + EncNNFilterSet0* getCNNLF () { return &m_cEncCNNLF; } #endif EncGOP* getGOPEncoder () { return &m_cGOPEncoder; } diff --git a/source/Lib/EncoderLib/EncCnnLoopFilter.cpp b/source/Lib/EncoderLib/EncNNFilterSet0.cpp similarity index 84% rename from source/Lib/EncoderLib/EncCnnLoopFilter.cpp rename to source/Lib/EncoderLib/EncNNFilterSet0.cpp index 70e9943199..b7bc3b6008 100644 --- a/source/Lib/EncoderLib/EncCnnLoopFilter.cpp +++ b/source/Lib/EncoderLib/EncNNFilterSet0.cpp @@ -31,32 +31,32 @@ * THE POSSIBILITY OF SUCH DAMAGE. 
*/ - /** \file EncCnnLoopFilter.cpp + /** \file EncNNFilterSet0.cpp \brief estimation part of cnn loop filter class */ -#include "EncCnnLoopFilter.h" +#include "EncNNFilterSet0.h" -#if NN_FILTER +#if NN_FILTERING_SET_0 #include "CommonLib/Picture.h" #include "CommonLib/CodingStructure.h" #define CnnlfCtx(c) SubCtx( Ctx::ctbCnnlfFlag, c ) -EncCnnLoopFilter::EncCnnLoopFilter() +EncNNFilterSet0::EncNNFilterSet0() : m_CABACEstimator(nullptr) { } -void EncCnnLoopFilter::create(const int picWidth, const int picHeight, const ChromaFormat chromaFormatIDC, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE], const int internalBitDepth[MAX_NUM_CHANNEL_TYPE]) +void EncNNFilterSet0::create(const int picWidth, const int picHeight, const ChromaFormat chromaFormatIDC, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE], const int internalBitDepth[MAX_NUM_CHANNEL_TYPE], std::string path) { - CnnLoopFilter::create(picWidth, picHeight, chromaFormatIDC, maxCUWidth, maxCUHeight, maxCUDepth, inputBitDepth); + NNFilterSet0::create(picWidth, picHeight, chromaFormatIDC, maxCUWidth, maxCUHeight, maxCUDepth, inputBitDepth, path); for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) { m_ctuEnableFlagTmp[compIdx] = new uint8_t[m_numCTUsInPic]; } } -void EncCnnLoopFilter::destroy() +void EncNNFilterSet0::destroy() { for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) { @@ -66,10 +66,10 @@ void EncCnnLoopFilter::destroy() m_ctuEnableFlagTmp[compIdx] = nullptr; } } - CnnLoopFilter::destroy(); + NNFilterSet0::destroy(); } -void EncCnnLoopFilter::initCABACEstimator(CABACEncoder* cabacEncoder, CtxCache* ctxCache, Slice* pcSlice) +void EncNNFilterSet0::initCABACEstimator(CABACEncoder* cabacEncoder, CtxCache* ctxCache, Slice* pcSlice) { m_CABACEstimator = cabacEncoder->getCABACEstimator(pcSlice->getSPS()); m_CtxCache = ctxCache; @@ -77,10 +77,9 @@ void 
EncCnnLoopFilter::initCABACEstimator(CABACEncoder* cabacEncoder, CtxCache* m_CABACEstimator->resetBits(); } -void EncCnnLoopFilter::PreCNNLFProcess(Picture* pic, CodingStructure& cs) +void EncNNFilterSet0::PreCNNLFProcess(Picture* pic, CodingStructure& cs) { - int baseQP = cs.slice->getPPS()->getPicInitQPMinus26() + 26; - initCnnModel(baseQP, cs.slice->getSliceQp()); + initCnnModel(); PelUnitBuf orgUnitBuf = cs.getOrgBuf(); PelUnitBuf recUnitBuf = cs.getRecoBuf(); @@ -89,11 +88,8 @@ void EncCnnLoopFilter::PreCNNLFProcess(Picture* pic, CodingStructure& cs) CHECK(orgUnitBuf.bufs.size() != cnnUnitBuf.bufs.size(), "Error buf size."); - PelUnitBuf predYuv = cs.getPredBufCustom(); - PelUnitBuf partitionYuv = cs.picture->getCuAverageBuf(); - // run CNNLF - runCNNLF(pic, recUnitBuf, predYuv, partitionYuv, cnnUnitBuf, cs.slice->getSliceType(), baseQP, cs.slice->getSliceQp(), false); + runCNNLF(pic, cnnUnitBuf, 0, false); for (int i = 1; i < MAX_NUM_CNN; i++) { @@ -101,7 +97,7 @@ void EncCnnLoopFilter::PreCNNLFProcess(Picture* pic, CodingStructure& cs) } } -void EncCnnLoopFilter::CNNLFProcess(CodingStructure& cs, const double *lambdas, CnnlfSliceParam& cnnlfSliceParam) +void EncNNFilterSet0::CNNLFProcess(CodingStructure& cs, const double *lambdas, CnnlfSliceParam& cnnlfSliceParam) { // set clipping range m_clpRngs = cs.slice->getClpRngs(); @@ -158,7 +154,7 @@ void EncCnnLoopFilter::CNNLFProcess(CodingStructure& cs, const double *lambdas, filterPic(cs, cnnlfSliceParam); } -double EncCnnLoopFilter::deriveCtbCnnlfEnableFlags(CodingStructure& cs, const PelUnitBuf& orgUnitBuf, const PelUnitBuf& cnnUnitBuf, const PelUnitBuf& recUnitBuf, ChannelType channel) +double EncNNFilterSet0::deriveCtbCnnlfEnableFlags(CodingStructure& cs, const PelUnitBuf& orgUnitBuf, const PelUnitBuf& cnnUnitBuf, const PelUnitBuf& recUnitBuf, ChannelType channel) { const ComponentID compIDFirst = isLuma(channel) ? COMPONENT_Y : COMPONENT_Cb; const ComponentID compIDLast = isLuma(channel) ? 
COMPONENT_Y : COMPONENT_Cr; @@ -231,7 +227,7 @@ double EncCnnLoopFilter::deriveCtbCnnlfEnableFlags(CodingStructure& cs, const Pe } #if NN_SCALE -void EncCnnLoopFilter::deriveScale(PelUnitBuf orgUnitBuf, PelUnitBuf recUnitBuf, PelUnitBuf cnnUnitBuf, Slice *slice) +void EncNNFilterSet0::deriveScale(PelUnitBuf orgUnitBuf, PelUnitBuf recUnitBuf, PelUnitBuf cnnUnitBuf, Slice *slice) { int area = recUnitBuf.get(COMPONENT_Y).width * recUnitBuf.get(COMPONENT_Y).height; for (int compID = 0; compID < MAX_NUM_COMPONENT; compID++) @@ -279,7 +275,7 @@ void EncCnnLoopFilter::deriveScale(PelUnitBuf orgUnitBuf, PelUnitBuf recUnitBuf, } #endif -void EncCnnLoopFilter::cnnlfEncoder(CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam, const PelUnitBuf& orgUnitBuf, PelUnitBuf& recUnitBuf, const ChannelType channel) +void EncNNFilterSet0::cnnlfEncoder(CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam, const PelUnitBuf& orgUnitBuf, PelUnitBuf& recUnitBuf, const ChannelType channel) { const TempCtx ctxStart(m_CtxCache, CnnlfCtx(m_CABACEstimator->getCtx())); TempCtx ctxBest(m_CtxCache); @@ -351,7 +347,7 @@ void EncCnnLoopFilter::cnnlfEncoder(CodingStructure& cs, CnnlfSliceParam& cnnlfS copyCtuEnableFlag(m_ctuEnableFlag, m_ctuEnableFlagTmp, channel); } -void EncCnnLoopFilter::copyCnnlfSliceParam(CnnlfSliceParam& cnnlfSliceParamDst, CnnlfSliceParam& cnnlfSliceParamSrc, ChannelType channel) +void EncNNFilterSet0::copyCnnlfSliceParam(CnnlfSliceParam& cnnlfSliceParamDst, CnnlfSliceParam& cnnlfSliceParamSrc, ChannelType channel) { cnnlfSliceParamDst.frameCtrlFlag[channel] = cnnlfSliceParamSrc.frameCtrlFlag[channel]; if (isLuma(channel)) @@ -364,7 +360,7 @@ void EncCnnLoopFilter::copyCnnlfSliceParam(CnnlfSliceParam& cnnlfSliceParamDst, cnnlfSliceParamDst.enabledFlag[COMPONENT_Cr] = cnnlfSliceParamSrc.enabledFlag[COMPONENT_Cr]; } } -double EncCnnLoopFilter::getCnnCost(const CPelUnitBuf& orgUnitBuf, const CPelUnitBuf& cnnUnitBuf, ChannelType channel) +double EncNNFilterSet0::getCnnCost(const 
CPelUnitBuf& orgUnitBuf, const CPelUnitBuf& cnnUnitBuf, ChannelType channel) { double cost = 0; if (isLuma(channel)) @@ -379,7 +375,7 @@ double EncCnnLoopFilter::getCnnCost(const CPelUnitBuf& orgUnitBuf, const CPelUni return cost; } -double EncCnnLoopFilter::getOrgCost(const CPelUnitBuf& orgUnitBuf, const CPelUnitBuf& recUnitBuf, ChannelType channel) +double EncNNFilterSet0::getOrgCost(const CPelUnitBuf& orgUnitBuf, const CPelUnitBuf& recUnitBuf, ChannelType channel) { double dist = 0; if (isLuma(channel)) @@ -394,7 +390,7 @@ double EncCnnLoopFilter::getOrgCost(const CPelUnitBuf& orgUnitBuf, const CPelUni return dist; } -double EncCnnLoopFilter::xCalcSSD(const CPelBuf& refBuf, const CPelBuf& cmpBuf, ChannelType ch) +double EncNNFilterSet0::xCalcSSD(const CPelBuf& refBuf, const CPelBuf& cmpBuf, ChannelType ch) { int iWidth = refBuf.width; int iHeight = refBuf.height; @@ -419,7 +415,7 @@ double EncCnnLoopFilter::xCalcSSD(const CPelBuf& refBuf, const CPelBuf& cmpBuf, return uiSSD; } -int EncCnnLoopFilter::lengthTruncatedUnary(int symbol, int maxSymbol) +int EncNNFilterSet0::lengthTruncatedUnary(int symbol, int maxSymbol) { if (maxSymbol == 0) { @@ -444,12 +440,12 @@ int EncCnnLoopFilter::lengthTruncatedUnary(int symbol, int maxSymbol) return numBins; } -void EncCnnLoopFilter::setFrameCtrlFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelType ch, char val) +void EncNNFilterSet0::setFrameCtrlFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelType ch, char val) { cnnlfSlicePara.frameCtrlFlag[ch] = val; } -void EncCnnLoopFilter::setEnableFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelType channel, bool val) +void EncNNFilterSet0::setEnableFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelType channel, bool val) { if (channel == CHANNEL_TYPE_LUMA) { @@ -461,7 +457,7 @@ void EncCnnLoopFilter::setEnableFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelTyp } } -void EncCnnLoopFilter::setEnableFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelType channel, uint8_t** ctuFlags) +void 
EncNNFilterSet0::setEnableFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelType channel, uint8_t** ctuFlags) { const ComponentID compIDFirst = isLuma(channel) ? COMPONENT_Y : COMPONENT_Cb; const ComponentID compIDLast = isLuma(channel) ? COMPONENT_Y : COMPONENT_Cr; @@ -479,7 +475,7 @@ void EncCnnLoopFilter::setEnableFlag(CnnlfSliceParam& cnnlfSlicePara, ChannelTyp } } -void EncCnnLoopFilter::copyCtuEnableFlag(uint8_t** ctuFlagsDst, uint8_t** ctuFlagsSrc, ChannelType channel) +void EncNNFilterSet0::copyCtuEnableFlag(uint8_t** ctuFlagsDst, uint8_t** ctuFlagsSrc, ChannelType channel) { if (isLuma(channel)) { @@ -492,7 +488,7 @@ void EncCnnLoopFilter::copyCtuEnableFlag(uint8_t** ctuFlagsDst, uint8_t** ctuFla } } -void EncCnnLoopFilter::setCtuEnableFlag(uint8_t** ctuFlags, ChannelType channel, uint8_t val) +void EncNNFilterSet0::setCtuEnableFlag(uint8_t** ctuFlags, ChannelType channel, uint8_t val) { if (isLuma(channel)) { @@ -504,6 +500,4 @@ void EncCnnLoopFilter::setCtuEnableFlag(uint8_t** ctuFlags, ChannelType channel, memset(ctuFlags[COMPONENT_Cr], val, sizeof(uint8_t) * m_numCTUsInPic); } } - #endif - diff --git a/source/Lib/EncoderLib/EncCnnLoopFilter.h b/source/Lib/EncoderLib/EncNNFilterSet0.h similarity index 93% rename from source/Lib/EncoderLib/EncCnnLoopFilter.h rename to source/Lib/EncoderLib/EncNNFilterSet0.h index b525f9bd40..ae8f96c92c 100644 --- a/source/Lib/EncoderLib/EncCnnLoopFilter.h +++ b/source/Lib/EncoderLib/EncNNFilterSet0.h @@ -31,19 +31,19 @@ * THE POSSIBILITY OF SUCH DAMAGE. 
*/ -/** \file EncCnnLoopFilter.h +/** \file EncNNFilterSet0.h \brief estimation part of cnn loop filter class (header) */ #ifndef __ENCCNNLOOPFILTER__ #define __ENCCNNLOOPFILTER__ -#include "CommonLib/CnnLoopFilter.h" +#include "CommonLib/NNFilterSet0.h" -#if NN_FILTER +#if NN_FILTERING_SET_0 #include "CABACWriter.h" -class EncCnnLoopFilter : public CnnLoopFilter +class EncNNFilterSet0 : public NNFilterSet0 { private: @@ -55,12 +55,12 @@ private: const double FracBitsScale = 1.0 / double( 1 << SCALE_BITS ); public: - EncCnnLoopFilter(); - virtual ~EncCnnLoopFilter() {} + EncNNFilterSet0(); + virtual ~EncNNFilterSet0() {} void PreCNNLFProcess(Picture* pic, CodingStructure& cs); void CNNLFProcess( CodingStructure& cs, const double *lambdas, CnnlfSliceParam& cnnlfSliceParam ); void initCABACEstimator( CABACEncoder* cabacEncoder, CtxCache* ctxCache, Slice* pcSlice ); - void create( const int picWidth, const int picHeight, const ChromaFormat chromaFormatIDC, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE], const int internalBitDepth[MAX_NUM_CHANNEL_TYPE]); + void create( const int picWidth, const int picHeight, const ChromaFormat chromaFormatIDC, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE], const int internalBitDepth[MAX_NUM_CHANNEL_TYPE], std::string path); void destroy(); private: @@ -81,7 +81,6 @@ private: void copyCtuEnableFlag( uint8_t** ctuFlagsDst, uint8_t** ctuFlagsSrc, ChannelType channel ); double xCalcSSD(const CPelBuf& refBuf, const CPelBuf& cmpBuf, ChannelType ch); int lengthTruncatedUnary(int symbol, int maxSymbol); - }; #endif #endif diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp index d85f371de3..756a756b14 100644 --- a/source/Lib/EncoderLib/VLCWriter.cpp +++ b/source/Lib/EncoderLib/VLCWriter.cpp @@ -46,9 +46,9 @@ #include "CommonLib/AdaptiveLoopFilter.h" #include 
"CommonLib/ProfileLevelTier.h" -#if NN_FILTER -#include "EncCnnLoopFilter.h" -#include "CommonLib/CnnLoopFilter.h" +#if NN_FILTERING_SET_0 +#include "EncNNFilterSet0.h" +#include "CommonLib/NNFilterSet0.h" #endif //! \ingroup EncoderLib @@ -1005,6 +1005,7 @@ void HLSWriter::codeSPS( const SPS* pcSPS ) } } + #if NN_FILTERING_SET_1 WRITE_FLAG( pcSPS->getNnlfSet1EnabledFlag(), "sps_nnlf_set1_enabled_flag" ); if (pcSPS->getNnlfSet1EnabledFlag()) @@ -1021,6 +1022,11 @@ void HLSWriter::codeSPS( const SPS* pcSPS ) { WRITE_FLAG( pcSPS->getCCALFEnabledFlag(), "sps_ccalf_enabled_flag" ); } + +#if NN_FILTERING_SET_0 + WRITE_FLAG(pcSPS->getNnlfSet0EnabledFlag() ? 1 : 0, "sps_nnlf_set0_enable_flag"); +#endif + WRITE_FLAG(pcSPS->getUseLmcs() ? 1 : 0, "sps_lmcs_enable_flag"); WRITE_FLAG(pcSPS->getUseWP() ? 1 : 0, "sps_weighted_pred_flag"); // Use of Weighting Prediction (P_SLICE) WRITE_FLAG(pcSPS->getUseWPBiPred() ? 1 : 0, "sps_weighted_bipred_flag"); // Use of Weighting Bi-Prediction (B_SLICE) @@ -2488,8 +2494,8 @@ void HLSWriter::codeSliceHeader ( Slice* pcSlice, PicHeader *picHeader ) WRITE_FLAG(pcSlice->getUseChromaQpAdj(), "sh_cu_chroma_qp_offset_enabled_flag"); } -#if NN_FILTER - if (pcSlice->getSPS()->getCNNLFEnabledFlag()) +#if NN_FILTERING_SET_0 + if (pcSlice->getSPS()->getNnlfSet0EnabledFlag()) { cnnlf(pcSlice->getCnnlfSliceParam()); #if NN_SCALE @@ -2497,19 +2503,19 @@ void HLSWriter::codeSliceHeader ( Slice* pcSlice, PicHeader *picHeader ) && (pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_LUMA] == 1 || pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_LUMA] == 5)) { - WRITE_SCODE(pcSlice->getNnScale(COMPONENT_Y) - (1 << NN_SCALE_SHIFT), NN_SCALE_SHIFT + 1, "nnScale_Y"); + WRITE_SCODE(pcSlice->getNnScale(COMPONENT_Y) - (1 << NN_SCALE_SHIFT), NN_SCALE_SHIFT + 1, "nn scale Y"); } if (pcSlice->getCnnlfSliceParam().enabledFlag[COMPONENT_Cb] && (pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 1 || 
pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 5)) { - WRITE_SCODE(pcSlice->getNnScale(COMPONENT_Cb) - (1 << NN_SCALE_SHIFT), NN_SCALE_SHIFT + 1, "nnScale_Cb"); + WRITE_SCODE(pcSlice->getNnScale(COMPONENT_Cb) - (1 << NN_SCALE_SHIFT), NN_SCALE_SHIFT + 1, "nn scale Cb"); } if (pcSlice->getCnnlfSliceParam().enabledFlag[COMPONENT_Cr] && (pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 1 || pcSlice->getCnnlfSliceParam().frameCtrlFlag[CHANNEL_TYPE_CHROMA] == 5)) { - WRITE_SCODE(pcSlice->getNnScale(COMPONENT_Cr) - (1 << NN_SCALE_SHIFT), NN_SCALE_SHIFT + 1, "nnScale_Cr"); + WRITE_SCODE(pcSlice->getNnScale(COMPONENT_Cr) - (1 << NN_SCALE_SHIFT), NN_SCALE_SHIFT + 1, "nn scale Cr"); } #endif } @@ -3156,41 +3162,15 @@ void HLSWriter::alfFilter( const AlfParam& alfParam, const bool isChroma, const } } -#if NN_FILTER -void HLSWriter::truncatedUnaryEqProb(int symbol, const int maxSymbol) -{ - if (maxSymbol == 0) - { - return; - } - - bool codeLast = (maxSymbol > symbol); - int bins = 0; - int numBins = 0; - - while (symbol--) - { - bins <<= 1; - bins++; - numBins++; - } - if (codeLast) - { - bins <<= 1; - numBins++; - } - CHECK(!(numBins <= 32), "Unspecified error"); - xWriteCode(bins, numBins); -} - +#if NN_FILTERING_SET_0 void HLSWriter::cnnlf(const CnnlfSliceParam& cnnlfSliceParam) { const int map_list[9] = { 2, 0, 3, 5, 7, 1, 4, 6, 8 }; for (int chId = 0; chId < MAX_NUM_CHANNEL_TYPE; chId++) { int code = cnnlfSliceParam.frameCtrlFlag[chId]; - CHECK(code < 0 || code > 8, ""); - WRITE_UVLC(map_list[code], ""); + CHECK(code < 0 || code > 8, "Invalid nn-filter mode"); + WRITE_UVLC(map_list[code], "nn-filter mode"); } } #endif diff --git a/source/Lib/EncoderLib/VLCWriter.h b/source/Lib/EncoderLib/VLCWriter.h index d9d81488cc..fb5ab35870 100644 --- a/source/Lib/EncoderLib/VLCWriter.h +++ b/source/Lib/EncoderLib/VLCWriter.h @@ -154,8 +154,7 @@ public: void alfFilter( const AlfParam& alfParam, const bool isChroma, const int altIdx ); void
dpb_parameters(int maxSubLayersMinus1, bool subLayerInfoFlag, const SPS *pcSPS); private: -#if NN_FILTER - void truncatedUnaryEqProb(int symbol, int maxSymbol); +#if NN_FILTERING_SET_0 void cnnlf(const CnnlfSliceParam& cnnlfSliceParam); #endif }; -- GitLab