diff --git a/CMakeLists.txt b/CMakeLists.txt index 10a86aa00ee6472c0763dd90f33cb5f5944b8629..1670d535404731fba2e77d8886fb1fb72683f689 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,9 +36,10 @@ if( CMAKE_COMPILER_IS_GNUCC ) set( BUILD_STATIC OFF CACHE BOOL "Build static executables" ) endif() -# set c++11 -set( CMAKE_CXX_STANDARD 11 ) +# set c++14 +set( CMAKE_CXX_STANDARD 14 ) set( CMAKE_CXX_STANDARD_REQUIRED ON ) +set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -Wall -fstrict-aliasing" ) # compile everything position independent (even static libraries) set( CMAKE_POSITION_INDEPENDENT_CODE TRUE ) diff --git a/models/.DS_Store b/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/models/.DS_Store differ diff --git a/models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_float.sadl b/models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_float.sadl new file mode 100644 index 0000000000000000000000000000000000000000..148792b59f04c95ef435cf47ab83e296f6094d9e Binary files /dev/null and b/models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_float.sadl differ diff --git a/models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_int16.sadl b/models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..d5e38f094daf74092090cfc5a9eb6e4fee0f84ac Binary files /dev/null and b/models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_int16.sadl differ diff --git a/models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_float.sadl b/models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_float.sadl new file mode 100644 index 0000000000000000000000000000000000000000..5e3fca3eafe7648fbe1c44ddc7ed82ecbe2eac13 Binary files /dev/null and b/models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_float.sadl differ diff --git a/models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_int16.sadl b/models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_int16.sadl new file mode 100644 
index 0000000000000000000000000000000000000000..3358225b97f4a6ee80815c61c086a7ff855751dc Binary files /dev/null and b/models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_int16.sadl differ diff --git a/models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_float.sadl b/models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_float.sadl new file mode 100644 index 0000000000000000000000000000000000000000..5a806008d5149bce88275b1a82ca3a0eb8b3072a Binary files /dev/null and b/models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_float.sadl differ diff --git a/models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_int16.sadl b/models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..bdf05a3bdd6621ac3b1b04db904ed29e88e07356 Binary files /dev/null and b/models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_int16.sadl differ diff --git a/models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_float.sadl b/models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_float.sadl new file mode 100644 index 0000000000000000000000000000000000000000..d8ccf59465cea4b6ae1596ebe2aecb0dd50c73a8 Binary files /dev/null and b/models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_float.sadl differ diff --git a/models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_int16.sadl b/models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..825b5f4c751e1f44bf322cf0616686205a521666 Binary files /dev/null and b/models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_int16.sadl differ diff --git a/source/App/DecoderAnalyserApp/CMakeLists.txt b/source/App/DecoderAnalyserApp/CMakeLists.txt index ad272ca1f34efd94444f2a9c1a750c7d4c0fecc3..a1da18bf763f44d6500abc5967258e74dd13f80c 100644 --- a/source/App/DecoderAnalyserApp/CMakeLists.txt +++ b/source/App/DecoderAnalyserApp/CMakeLists.txt @@ -7,6 +7,8 @@ file( GLOB SRC_FILES "../DecoderApp/*.cpp" ) # get include files file( GLOB INC_FILES "../DecoderApp/*.h" ) +include_directories(../../../sadl) + # get additional libs 
for gcc on Ubuntu systems if( CMAKE_SYSTEM_NAME STREQUAL "Linux" ) if( CMAKE_CXX_COMPILER_ID STREQUAL "GNU" ) diff --git a/source/App/DecoderApp/CMakeLists.txt b/source/App/DecoderApp/CMakeLists.txt index 4e71c5c1e139ad10e15b9a973624f2fa2ea70274..440691ed2838111c2420857683e4ac466e188ec5 100644 --- a/source/App/DecoderApp/CMakeLists.txt +++ b/source/App/DecoderApp/CMakeLists.txt @@ -7,6 +7,8 @@ file( GLOB SRC_FILES "*.cpp" ) # get include files file( GLOB INC_FILES "*.h" ) +include_directories(../../../sadl) + # get additional libs for gcc on Ubuntu systems if( CMAKE_SYSTEM_NAME STREQUAL "Linux" ) if( CMAKE_CXX_COMPILER_ID STREQUAL "GNU" ) diff --git a/source/App/DecoderApp/DecApp.cpp b/source/App/DecoderApp/DecApp.cpp index d732c270e208b193e4999fe01ebbc3b6105cef1f..64958e2cb8ee34744802cbd7a3f02e977c0f9207 100644 --- a/source/App/DecoderApp/DecApp.cpp +++ b/source/App/DecoderApp/DecApp.cpp @@ -576,7 +576,13 @@ void DecApp::xCreateDecLib() ); m_cDecLib.setDecodedPictureHashSEIEnabled(m_decodedPictureHashSEIEnabled); - +#if CNN_FILTERING + m_cDecLib.setCnnlfInterLumaModelName (m_cnnlfInterLumaModelName); + m_cDecLib.setCnnlfInterChromaModelName (m_cnnlfInterChromaModelName); + m_cDecLib.setCnnlfIntraLumaModelName (m_cnnlfIntraLumaModelName); + m_cDecLib.setCnnlfIntraChromaModelName (m_cnnlfIntraChromaModelName); +#endif + if (!m_outputDecodedSEIMessagesFilename.empty()) { std::ostream &os=m_seiMessageFileStream.is_open() ? 
m_seiMessageFileStream : std::cout; diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp index 801c24335b78022cdeee6c51cd1bfb6d0e4d9f48..3f1fbf617c8d566368b401f68391d6cac38aa700 100644 --- a/source/App/DecoderApp/DecAppCfg.cpp +++ b/source/App/DecoderApp/DecAppCfg.cpp @@ -80,7 +80,14 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) #if NNVC_DUMP_DATA ("DumpBasename", m_dumpBasename, string(""), "basename for data dumping\n") #endif - + +#if CNN_FILTERING + ( "CnnlfInterLumaModel", m_cnnlfInterLumaModelName, string("models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_int16.sadl"), "Cnnlf inter luma model name") + ( "CnnlfInterChromaModel", m_cnnlfInterChromaModelName, string("models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_int16.sadl"), "Cnnlf inter chroma model name") + ( "CnnlfIntraLumaModel", m_cnnlfIntraLumaModelName, string("models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_int16.sadl"), "Cnnlf intra luma model name") + ( "CnnlfIntraChromaModel", m_cnnlfIntraChromaModelName, string("models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_int16.sadl"), "Cnnlf intra chroma model name") +#endif + ("OplFile,-opl", m_oplFilename , string(""), "opl-file name without extension for conformance testing\n") #if ENABLE_SIMD_OPT @@ -254,6 +261,7 @@ DecAppCfg::DecAppCfg() : m_bitstreamFileName() , m_reconFileName() , m_oplFilename() + , m_iSkipFrame(0) // m_outputBitDepth array initialised below , m_outputColourSpaceConvert(IPCOLOURSPACE_UNCHANGED) diff --git a/source/App/DecoderApp/DecAppCfg.h b/source/App/DecoderApp/DecAppCfg.h index 35fef707ccb3231e9a3541e0e92ce2953cd5331a..dd6e708fa163c440f88cd1493f9053de273758fa 100644 --- a/source/App/DecoderApp/DecAppCfg.h +++ b/source/App/DecoderApp/DecAppCfg.h @@ -61,9 +61,16 @@ protected: #if NNVC_DUMP_DATA std::string m_dumpBasename; ///< output basename for data -#endif +#endif std::string m_oplFilename; ///< filename to output conformance log. 
+#if CNN_FILTERING + std::string m_cnnlfInterLumaModelName; ///<inter luma cnnlf model + std::string m_cnnlfInterChromaModelName; ///<inter chroma cnnlf model + std::string m_cnnlfIntraLumaModelName; ///<intra luma cnnlf model + std::string m_cnnlfIntraChromaModelName; ///<intra chroma cnnlf model +#endif + int m_iSkipFrame; ///< counter for frames prior to the random access point to skip int m_outputBitDepth[MAX_NUM_CHANNEL_TYPE]; ///< bit depth used for writing output InputColourSpaceConversion m_outputColourSpaceConvert; diff --git a/source/App/DecoderApp/decmain.cpp b/source/App/DecoderApp/decmain.cpp index c8a6e3bd7070cdcbd867a95d3f39c27cca6739e1..f1590c75ae824401ea19063fba7e91c0a8670dc4 100644 --- a/source/App/DecoderApp/decmain.cpp +++ b/source/App/DecoderApp/decmain.cpp @@ -40,6 +40,9 @@ #include <time.h> #include "DecApp.h" #include "program_options_lite.h" +#if CNN_FILTERING +#include <chrono> +#endif //! \ingroup DecoderApp //! \{ @@ -81,7 +84,11 @@ int main(int argc, char* argv[]) // starting time double dResult; +#if CNN_FILTERING + auto startTime = std::chrono::steady_clock::now(); +#else clock_t lBefore = clock(); +#endif // call decoding function #ifndef _DEBUG @@ -108,7 +115,13 @@ int main(int argc, char* argv[]) #endif // ending time +#if CNN_FILTERING + auto endTime = std::chrono::steady_clock::now(); + dResult = std::chrono::duration_cast<std::chrono::milliseconds>( endTime - startTime).count(); + dResult = dResult / 1000.0; +#else dResult = (double)(clock()-lBefore) / CLOCKS_PER_SEC; +#endif printf("\n Total Time: %12.3f sec.\n", dResult); delete pcDecApp; diff --git a/source/App/EncoderApp/CMakeLists.txt b/source/App/EncoderApp/CMakeLists.txt index dd87e52d1f8244c607c34d42359dd0595e63cd88..befdbe2e236f85ad16c240de85996dab829d84eb 100644 --- a/source/App/EncoderApp/CMakeLists.txt +++ b/source/App/EncoderApp/CMakeLists.txt @@ -7,6 +7,8 @@ file( GLOB SRC_FILES "*.cpp" ) # get include files file( GLOB INC_FILES "*.h" ) 
+include_directories(../../../sadl) + # get additional libs for gcc on Ubuntu systems if( CMAKE_SYSTEM_NAME STREQUAL "Linux" ) if( CMAKE_CXX_COMPILER_ID STREQUAL "GNU" ) diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp index 9d706eb0661d87139c53e981a944b82eb9c04ffe..2f833372e12319f0e8949905c848953710aeea32 100644 --- a/source/App/EncoderApp/EncApp.cpp +++ b/source/App/EncoderApp/EncApp.cpp @@ -243,6 +243,12 @@ void EncApp::xInitLibCfg() } vps.setProfileTierLevel(ptls); vps.setVPSExtensionFlag ( false ); +#if CNN_FILTERING + m_cEncLib.setCnnlfInterLumaModelName (m_cnnlfInterLumaModelName); + m_cEncLib.setCnnlfInterChromaModelName (m_cnnlfInterChromaModelName); + m_cEncLib.setCnnlfIntraLumaModelName (m_cnnlfIntraLumaModelName); + m_cEncLib.setCnnlfIntraChromaModelName (m_cnnlfIntraChromaModelName); +#endif m_cEncLib.setProfile ( m_profile); m_cEncLib.setLevel ( m_levelTier, m_level); m_cEncLib.setFrameOnlyConstraintFlag ( m_frameOnlyConstraintFlag); @@ -1055,6 +1061,12 @@ void EncApp::xInitLibCfg() #endif m_cEncLib.setUseCCALF ( m_ccalf ); m_cEncLib.setCCALFQpThreshold ( m_ccalfQpThreshold ); +#if CNN_FILTERING + m_cEncLib.setUseCnnlf (m_cnnlf); + m_cEncLib.setCnnlfInferSizeBase (m_cnnlfInferSizeBase); + m_cEncLib.setCnnlfInferSizeExtension (m_cnnlfInferSizeExtension); + m_cEncLib.setCnnlfMaxNumParams (m_cnnlfMaxNumParams); +#endif m_cEncLib.setLmcs ( m_lmcsEnabled ); m_cEncLib.setReshapeSignalType ( m_reshapeSignalType ); m_cEncLib.setReshapeIntraCMD ( m_intraCMD ); diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp index 4ceb924d8074fc1235729ebe98edd29929c2c7e2..0c4d5a88f269c644b99c1bfd6eaabdc4fc8bd6f6 100644 --- a/source/App/EncoderApp/EncAppCfg.cpp +++ b/source/App/EncoderApp/EncAppCfg.cpp @@ -977,7 +977,11 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) ("NumHorVirtualBoundaries", m_numHorVirtualBoundaries, 0u, "Number of horizontal virtual boundaries (0-3, inclusive)") 
("VirtualBoundariesPosX", cfg_virtualBoundariesPosX, cfg_virtualBoundariesPosX, "Locations of the vertical virtual boundaries in units of luma samples") ("VirtualBoundariesPosY", cfg_virtualBoundariesPosY, cfg_virtualBoundariesPosY, "Locations of the horizontal virtual boundaries in units of luma samples") +#if ENC_DB_OPT + ("EncDbOpt", m_encDbOpt, true, "Encoder optimization with deblocking filter") +#else ("EncDbOpt", m_encDbOpt, false, "Encoder optimization with deblocking filter") +#endif ("LMCSEnable", m_lmcsEnabled, false, "Enable LMCS (luma mapping with chroma scaling") ("LMCSSignalType", m_reshapeSignalType, 0u, "Input signal type: 0:SDR, 1:HDR-PQ, 2:HDR-HLG") ("LMCSUpdateCtrl", m_updateCtrl, 0, "LMCS model update control: 0:RA, 1:AI, 2:LDB/LDP") @@ -1410,6 +1414,16 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) #endif ( "CCALF", m_ccalf, true, "Cross-component Adaptive Loop Filter" ) ( "CCALFQpTh", m_ccalfQpThreshold, 37, "QP threshold above which encoder reduces CCALF usage") +#if CNN_FILTERING + ( "Cnnlf", m_cnnlf, true, "CNN-based loop filter" ) + ( "CnnlfInferSizeBase", m_cnnlfInferSizeBase, 128u, "Base inference size of CNN-based loop filter" ) + ( "CnnlfInferSizeExtension", m_cnnlfInferSizeExtension, 8u, "Extension of inference size of CNN-based loop filter" ) + ( "CnnlfMaxNumParams", m_cnnlfMaxNumParams, 3u, "Number of conditional parameters of CNN-based loop filter" ) + ( "CnnlfInterLumaModel", m_cnnlfInterLumaModelName, string("models/JVET_Z_EE_1.6_LumaCNNFilter_InterSlice_int16.sadl"), "Cnnlf inter luma model name") + ( "CnnlfInterChromaModel", m_cnnlfInterChromaModelName, string("models/JVET_Z_EE_1.6_ChromaCNNFilter_InterSlice_int16.sadl"), "Cnnlf inter chroma model name") + ( "CnnlfIntraLumaModel", m_cnnlfIntraLumaModelName, string("models/JVET_Z_EE_1.6_LumaCNNFilter_IntraSlice_int16.sadl"), "Cnnlf intra luma model name") + ( "CnnlfIntraChromaModel", m_cnnlfIntraChromaModelName, 
string("models/JVET_Z_EE_1.6_ChromaCNNFilter_IntraSlice_int16.sadl"), "Cnnlf intra chroma model name") +#endif ( "RPR", m_rprEnabledFlag, true, "Reference Sample Resolution" ) ( "ScalingRatioHor", m_scalingRatioHor, 1.0, "Scaling ratio in hor direction" ) ( "ScalingRatioVer", m_scalingRatioVer, 1.0, "Scaling ratio in ver direction" ) diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h index 47b2ce6827ae091755cd3bd7a3a841af919de156..3b2417ca3c0cc48b24ab89d5730185a60e15fd1c 100644 --- a/source/App/EncoderApp/EncAppCfg.h +++ b/source/App/EncoderApp/EncAppCfg.h @@ -86,6 +86,13 @@ protected: std::string m_bitstreamFileName; ///< output bitstream file std::string m_reconFileName; ///< output reconstruction file +#if CNN_FILTERING + std::string m_cnnlfInterLumaModelName; ///<inter luma cnnlf model + std::string m_cnnlfInterChromaModelName; ///<inter chroma cnnlf model + std::string m_cnnlfIntraLumaModelName; ///<intra luma cnnlf model + std::string m_cnnlfIntraChromaModelName; ///<intra chroma cnnlf model +#endif + // Lambda modifiers double m_adLambdaModifier[ MAX_TLAYER ]; ///< Lambda modifier array for each temporal layer std::vector<double> m_adIntraLambdaModifier; ///< Lambda modifier for Intra pictures, one for each temporal layer. If size>temporalLayer, then use [temporalLayer], else if size>0, use [size()-1], else use m_adLambdaModifier. 
@@ -701,6 +708,13 @@ protected: #endif bool m_ccalf; int m_ccalfQpThreshold; + +#if CNN_FILTERING + bool m_cnnlf; + unsigned m_cnnlfInferSizeBase; + unsigned m_cnnlfInferSizeExtension; + unsigned m_cnnlfMaxNumParams; +#endif bool m_rprEnabledFlag; double m_scalingRatioHor; diff --git a/source/Lib/CommonAnalyserLib/CMakeLists.txt b/source/Lib/CommonAnalyserLib/CMakeLists.txt index e915720f26f5944a1787f05472417cfea5f83d7d..12e809461a451cd8a03c20fb3be1ab0776340dae 100644 --- a/source/Lib/CommonAnalyserLib/CMakeLists.txt +++ b/source/Lib/CommonAnalyserLib/CMakeLists.txt @@ -43,6 +43,8 @@ set( SRC_FILES ${BASE_SRC_FILES} ${X86_SRC_FILES} ${SSE41_SRC_FILES} ${SSE42_SRC # get all include files set( INC_FILES ${BASE_INC_FILES} ${X86_INC_FILES} ${MD5_INC_FILES} ) +include_directories(../../../sadl) + # library add_library( ${LIB_NAME} STATIC ${SRC_FILES} ${INC_FILES} ${NATVIS_FILES} ) @@ -106,6 +108,12 @@ elseif( UNIX OR MINGW ) set_property( SOURCE ${AVX2_SRC_FILES} APPEND PROPERTY COMPILE_FLAGS "-mavx2" ) endif() +if( MSVC ) + set_property( SOURCE CNNFilter.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ") +elseif( UNIX OR MINGW ) + set_property( SOURCE CNNFilter.cpp APPEND PROPERTY COMPILE_FLAGS " -DNDEBUG=1 -mavx512f -mavx512bw") +endif() + # example: place header files in different folders source_group( "Natvis Files" FILES ${NATVIS_FILES} ) diff --git a/source/Lib/CommonLib/CMakeLists.txt b/source/Lib/CommonLib/CMakeLists.txt index 6ae75c82bc979827f168cb64cef7252649220cfc..1ec6b73465faf31c1a037a46fd558648087268ae 100644 --- a/source/Lib/CommonLib/CMakeLists.txt +++ b/source/Lib/CommonLib/CMakeLists.txt @@ -43,6 +43,8 @@ set( SRC_FILES ${BASE_SRC_FILES} ${X86_SRC_FILES} ${SSE41_SRC_FILES} ${SSE42_SRC # get all include files set( INC_FILES ${BASE_INC_FILES} ${X86_INC_FILES} ${MD5_INC_FILES} ) +include_directories(../../../sadl) + # library add_library( ${LIB_NAME} STATIC ${SRC_FILES} ${INC_FILES} ${NATVIS_FILES} ) @@ -104,6 +106,11 @@ elseif( UNIX OR MINGW ) 
set_property( SOURCE ${AVX2_SRC_FILES} APPEND PROPERTY COMPILE_FLAGS "-mavx2" ) endif() +if( MSVC ) + set_property( SOURCE CNNFilter.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ") +elseif( UNIX OR MINGW ) + set_property( SOURCE CNNFilter.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw") +endif() # example: place header files in different folders source_group( "Natvis Files" FILES ${NATVIS_FILES} ) diff --git a/source/Lib/CommonLib/CNNFilter.cpp b/source/Lib/CommonLib/CNNFilter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c9f096fee69afeebf4dc065583a2e17c7403899 --- /dev/null +++ b/source/Lib/CommonLib/CNNFilter.cpp @@ -0,0 +1,748 @@ +/** \file CNNFilter.cpp + \brief convolutional neural network-based filter class +*/ +#include "CNNFilter.h" +#include "NNInference.h" +#if CNN_FILTERING +#include "UnitTools.h" +#include "UnitPartitioner.h" +#include "CodingStructure.h" +#include "CommonLib/dtrace_codingstruct.h" +#include "CommonLib/dtrace_buffer.h" +#include <fstream> +#include <cmath> +#include <chrono> +#include <string.h> +#include <stdlib.h> +#include <stdio.h> +#include <math.h> +#include <ctime> +#include <algorithm> +#include <numeric> +#include <sadl/model.h> +using namespace std; + + +//! \ingroup CommonLib +//! 
\{
+
+/** Default constructor; the working buffers are allocated later by create(). */
+CNNFilter::CNNFilter()
+{
+}
+
+/**
+ * Allocate one picture-sized working buffer per conditional parameter
+ * (and, when SCALE_NN_RESIDUE is enabled, one scaled buffer per parameter).
+ * Calling create() while buffers already exist is a no-op.
+ */
+void CNNFilter::create( const int picWidth, const int picHeight, const ChromaFormat format, const int cnnlfNumParams)
+{
+  if (!m_tempBuf.empty())
+  {
+    return;
+  }
+  m_tempBuf.resize(cnnlfNumParams);
+  for (auto &buf : m_tempBuf)
+  {
+    buf.create(format, Area(0, 0, picWidth, picHeight));
+  }
+
+#if SCALE_NN_RESIDUE
+  m_tempScaledBuf.resize(cnnlfNumParams);
+  for (auto &buf : m_tempScaledBuf)
+  {
+    buf.create(format, Area(0, 0, picWidth, picHeight));
+  }
+#endif
+}
+
+/** Remember the four model file names; the models themselves are loaded on first use by getModel(). */
+void CNNFilter::init(std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma)
+{
+  m_interLuma = interLuma;
+  m_interChroma = interChroma;
+  m_intraLuma = intraLuma;
+  m_intraChroma = intraChroma;
+}
+
+/**
+ * Release all working buffers.
+ * The containers are emptied as well: create() skips allocation whenever m_tempBuf is
+ * non-empty, so leaving the released entries in place would make a later create() a no-op
+ * and hand out released storage.
+ */
+void CNNFilter::destroy()
+{
+  for (auto &buf : m_tempBuf)
+  {
+    buf.destroy();
+  }
+  m_tempBuf.clear();
+
+#if SCALE_NN_RESIDUE
+  for (auto &buf : m_tempScaledBuf)
+  {
+    buf.destroy();
+  }
+  m_tempScaledBuf.clear();
+#endif
+}
+
+/** One cached network together with its input tensors and the block size it was last initialised for. */
+template<typename T>
+struct ModelData {
+  sadl::Model<T> model;
+  vector<sadl::Tensor<T>> inputs;
+  int hor,ver;      // block size (luma samples) used for the last model.init(); 0 right after loading
+  bool luma,inter;  // cache key: component type and slice type of the model
+};
+
+/** Build the (empty) model cache with its capacity reserved up front. */
+template<typename T>
+static std::vector<ModelData<T>> initSpace() {
+  std::vector<ModelData<T>> v;
+  v.reserve(5);
+  return v;
+}
+
+#if NN_FIXED_POINT_IMPLEMENTATION
+static std::vector<ModelData<int16_t>> models=initSpace<int16_t>();
+#else
+static std::vector<ModelData<float>> models=initSpace<float>();
+#endif
+
+/**
+ * Return the cached model for the given (luma/chroma, inter/intra) combination.
+ * The model is loaded from modelName on first use, and its input tensors are resized and the
+ * model re-initialised whenever the requested block size (ver x hor) differs from the size
+ * it was last initialised for.
+ * Every failure (cache full, model file cannot be opened or parsed, initialisation error)
+ * terminates the process with exit(-1).
+ */
+template<typename T>
+static ModelData<T> &getModel(int ver, int hor, bool luma, bool inter, const string modelName) {
+  const char *compName = luma ? "luma" : "chroma";
+  const char *modeName = inter ? "inter" : "intra";
+
+  ModelData<T> *ptr = nullptr;
+  for(auto &m: models)
+  {
+    if (m.luma == luma && m.inter == inter)
+    {
+      ptr = &m;
+      break;
+    }
+  }
+  if (ptr == nullptr)
+  {
+    // Never grow past the reserved capacity (presumably so that already-loaded entries are not relocated).
+    if (models.size() == models.capacity())
+    {
+      std::cout << "[ERROR] increase cache" << std::endl;
+      exit(-1);
+    }
+    models.resize(models.size()+1);
+    ptr = &models.back();
+    ModelData<T> &m = *ptr;
+    cout << "[INFO] " << compName << " " << modeName << " model loading" << endl;
+    ifstream file(modelName, ios::binary);
+    if (!file)
+    {
+      // Fail here with the offending path rather than later with an unrelated initialisation error.
+      cerr << "[ERROR] unable to open model file " << modelName << endl;
+      exit(-1);
+    }
+    if (!m.model.load(file))
+    {
+      cerr << "[ERROR] unable to read model " << modelName << endl;
+      exit(-1);
+    }
+    m.luma = luma;
+    m.inter = inter;
+    m.ver = 0;
+    m.hor = 0;
+  }
+  ModelData<T> &m = *ptr;
+  if (m.ver != ver || m.hor != hor)
+  {
+    m.inputs = m.model.getInputsTemplate();
+    // Position of the QP plane among the chroma model inputs (inter: 4, intra: 5).
+    const int qpInputId = inter ? 4 : 5;
+    int inputId = 0;
+    for(auto &t: m.inputs)
+    {
+      if (luma || (t.dims()[3] == 1 && inputId != qpInputId)) // full-size single-channel plane
+      {
+        sadl::Dimensions dims(std::initializer_list<int>({1, ver, hor, 1}));
+        t.resize(dims);
+      }
+      else if (t.dims()[3] == 1) // QP plane, half size in each direction
+      {
+        sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 1}));
+        t.resize(dims);
+      }
+      else // two-channel plane, half size in each direction
+      {
+        sadl::Dimensions dims(std::initializer_list<int>({1, ver/2, hor/2, 2}));
+        t.resize(dims);
+      }
+      inputId ++;
+    }
+    cout << "[INFO] " << compName << " " << modeName << " model initilization" << endl;
+    if (!m.model.init(m.inputs))
+    {
+      cerr << "[ERROR] issue during initialization" << endl;
+      exit(-1);
+    }
+    m.ver = ver;
+    m.hor = hor;
+  }
+  return m;
+}
+
+/**
+ * Fill the luma model inputs for one inference area.
+ * Sample planes (reconstruction before deblocking, prediction, boundary-strength map and,
+ * for intra models only, the CU-average map) are divided by 1024; the QP plane holds qp/64.
+ * Every value is then multiplied by 2^shiftInput (NN_INPUT_PRECISION in the fixed-point build, 0 otherwise).
+ * Input order: inter = {rec, pred, bs, qp}; intra = {rec, pred, partition, bs, qp}.
+ */
+template<typename T>
+void prepareInputsLuma (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int qp, bool inter)
+{
+  double inputScale = 1024;
+#if NN_FIXED_POINT_IMPLEMENTATION
+  int shiftInput = NN_INPUT_PRECISION;
+#else
+  int shiftInput = 0;
+#endif
+  PelBuf bufRec = pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Y);
+  PelBuf bufPred = pic->getPredBufCustom(inferArea).get(COMPONENT_Y);
+  PelBuf bufPartition = pic->getCuAverageBuf(inferArea).get(COMPONENT_Y);
+  PelBuf bufBsMap = pic->getBsMapBuf(inferArea).get(COMPONENT_Y);
+
+  sadl::Tensor<T> *inputRec = &inputs[0];
+  sadl::Tensor<T> *inputPred = &inputs[1];
+  sadl::Tensor<T> *inputPartition = nullptr; // only present for intra models
+  sadl::Tensor<T> *inputBs = nullptr;
+  sadl::Tensor<T> *inputQp = nullptr;
+  if (inter)
+  {
+    inputBs = &inputs[2];
+    inputQp = &inputs[3];
+  }
+  else
+  {
+    inputPartition = &inputs[2];
+    inputBs = &inputs[3];
+    inputQp = &inputs[4];
+  }
+
+  int hor = inferArea.lwidth();
+  int ver = inferArea.lheight();
+
+  // The QP plane is constant over the block; compute its value once.
+  const double qpValue = qp / 64.0 * (1 << shiftInput);
+
+  for (int yy = 0; yy < ver; yy++)
+  {
+    for (int xx = 0; xx < hor; xx++)
+    {
+      (*inputRec)(0, yy, xx, 0) = bufRec.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputPred)(0, yy, xx, 0) = bufPred.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputBs)(0, yy, xx, 0) = bufBsMap.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputQp)(0, yy, xx, 0) = qpValue;
+      if (!inter)
+      {
+        (*inputPartition)(0, yy, xx, 0) = bufPartition.at(xx, yy) / inputScale * (1 << shiftInput);
+      }
+    }
+  }
+}
+
+template<typename T>
+void prepareInputsChroma (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int qp, bool inter)
+{
+  double inputScale = 1024;
+#if NN_FIXED_POINT_IMPLEMENTATION
+  int shiftInput = NN_INPUT_PRECISION;
+#else
+  int shiftInput = 0;
+#endif
+  PelBuf bufRecY = 
pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Y);
+  PelBuf bufRecCb = pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Cb);
+  PelBuf bufRecCr = pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Cr);
+
+  PelBuf bufPredCb = pic->getPredBufCustom(inferArea).get(COMPONENT_Cb);
+  PelBuf bufPredCr = pic->getPredBufCustom(inferArea).get(COMPONENT_Cr);
+
+  PelBuf bufPartitionCb = pic->getCuAverageBuf(inferArea).get(COMPONENT_Cb);
+  PelBuf bufPartitionCr = pic->getCuAverageBuf(inferArea).get(COMPONENT_Cr);
+
+  PelBuf bufBsMapCb = pic->getBsMapBuf(inferArea).get(COMPONENT_Cb);
+  PelBuf bufBsMapCr = pic->getBsMapBuf(inferArea).get(COMPONENT_Cr);
+
+  // Input order: inter = {recY, recCbCr, predCbCr, bsCbCr, qp}; intra inserts the partition map before bs.
+  // inputPartition stays null for inter models and is only dereferenced when !inter.
+  sadl::Tensor<T>* inputRecCrossComponent, *inputRec, *inputPred, *inputPartition = nullptr, *inputBs, *inputQp;
+
+  if (inter)
+  {
+    inputRecCrossComponent = &inputs[0];
+    inputRec = &inputs[1];
+    inputPred = &inputs[2];
+    inputBs = &inputs[3];
+    inputQp = &inputs[4];
+  }
+  else
+  {
+    inputRecCrossComponent = &inputs[0];
+    inputRec = &inputs[1];
+    inputPred = &inputs[2];
+    inputPartition = &inputs[3];
+    inputBs = &inputs[4];
+    inputQp = &inputs[5];
+  }
+
+  int hor = inferArea.lwidth();
+  int ver = inferArea.lheight();
+
+  // chroma planes are addressed at half the luma width/height
+  int horC = inferArea.lwidth() >> 1;
+  int verC = inferArea.lheight() >> 1;
+
+  // co-located luma reconstruction, full resolution
+  for (int yy = 0; yy < ver; yy++)
+  {
+    for (int xx = 0; xx < hor; xx++)
+    {
+      (*inputRecCrossComponent)(0, yy, xx, 0) = bufRecY.at(xx, yy) / inputScale * (1 << shiftInput);
+    }
+  }
+
+  // chroma planes: channel 0 = Cb, channel 1 = Cr
+  for (int yy = 0; yy < verC; yy++)
+  {
+    for (int xx = 0; xx < horC; xx++)
+    {
+
+      (*inputRec)(0, yy, xx, 0) = bufRecCb.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputRec)(0, yy, xx, 1) = bufRecCr.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputPred)(0, yy, xx, 0) = bufPredCb.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputPred)(0, yy, xx, 1) = bufPredCr.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputBs)(0, yy, xx, 0) = bufBsMapCb.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputBs)(0, yy, xx, 1) = bufBsMapCr.at(xx, yy) / inputScale * (1 << shiftInput);
+      (*inputQp)(0, yy, xx, 0) = qp / 64.0 * (1 << shiftInput);
+      if (!inter)
+      {
+        (*inputPartition)(0, yy, xx, 0) = bufPartitionCb.at(xx, yy) / inputScale * (1 << shiftInput);
+        (*inputPartition)(0, yy, xx, 1) = bufPartitionCr.at(xx, yy) / inputScale * (1 << shiftInput);
+      }
+    }
+  }
+}
+
+/**
+ * Convert the luma network output of one inference area back to samples.
+ * The network output has 4 channels, each holding one of the four 2x2 sub-sampled phases of
+ * the block; output(0, y, x, c) belongs to sample (2x + c%2, 2y + c/2). The output is added to
+ * the reconstruction before deblocking, rounded, clipped to [0, 1023] and stored in tempBuf
+ * (and, with SCALE_NN_RESIDUE, stored up-scaled by 2^NN_RESIDUE_ADDITIONAL_SHIFT in tempScaledBuf).
+ * Samples inside the extLeft/extRight/extTop/extBottom margins of the area are left untouched.
+ * The inter flag is not used by this function.
+ */
+template<typename T>
+void extractOutputsLuma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, PelStorage& tempScaledBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, bool inter)
+{
+#if NN_FIXED_POINT_IMPLEMENTATION
+  int log2InputScale = 10;
+  int log2OutputScale = 10;
+  int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale;
+  int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale;
+  int offset = (1 << shiftOutput) / 2;
+#else
+  double inputScale = 1024;
+  double outputScale = 1024;
+#endif
+  // Bind to the model's result instead of copying the whole output tensor for every block.
+  auto &&output = m.result(0);
+  PelBuf bufDst = tempBuf.getBuf(inferArea).get(COMPONENT_Y);
+
+#if SCALE_NN_RESIDUE
+  PelBuf bufScaledDst = tempScaledBuf.getBuf(inferArea).get(COMPONENT_Y);
+#endif
+
+  int hor = inferArea.lwidth();
+  int ver = inferArea.lheight();
+
+  PelBuf bufRec = pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Y);
+
+  for (int c = 0; c < 4; c++) // output includes 4 sub images
+  {
+    for (int y = 0; y < ver >> 1; y++)
+    {
+      for (int x = 0; x < hor >> 1; x++)
+      {
+        int yy = (y << 1) + c / 2;
+        int xx = (x << 1) + c % 2;
+        if (xx < extLeft || yy < extTop || xx >= hor - extRight || yy >= ver - extBottom)
+        {
+          continue;
+        }
+#if NN_FIXED_POINT_IMPLEMENTATION
+        int out = ( output(0, y, x, c) + (bufRec.at(xx, yy) << shiftInput) + offset) >> shiftOutput;
+#else
+        int out = ( output(0, y, x, c) + bufRec.at(xx, yy) / inputScale ) * outputScale + 0.5;
+#endif
+        bufDst.at(xx, yy) = Pel(Clip3<int>( 0, 1023, out));
+
+        #if SCALE_NN_RESIDUE
+        bufScaledDst.at(xx, yy) = Pel(Clip3<int>(0, 1023 << NN_RESIDUE_ADDITIONAL_SHIFT, out * (1 << NN_RESIDUE_ADDITIONAL_SHIFT) ) );
+        #endif
+      }
+    }
+  }
+}
+
+template<typename T>
+void 
extractOutputsChroma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, PelStorage& tempScaledBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, bool inter)
+{
+  // Converts the chroma network output of one inference area back to Cb/Cr samples.
+  // The output has 8 channels: channels 0..3 are the four 2x2 phases of Cb, channels 4..7 those of Cr.
+  // All coordinates and the ext* margins below are compared in chroma-plane units
+  // (NOTE(review): callers presumably pass margins already halved - confirm).
+  // The inter flag is not used by this function.
+#if NN_FIXED_POINT_IMPLEMENTATION
+  int log2InputScale = 10;
+  int log2OutputScale = 10;
+  int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale;
+  int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale;
+  int offset = (1 << shiftOutput) / 2;
+#else
+  double inputScale = 1024;
+  double outputScale = 1024;
+#endif
+  // Bind to the model's result instead of copying the whole output tensor for every block.
+  auto &&output = m.result(0);
+  PelBuf bufDstCb = tempBuf.getBuf(inferArea).get(COMPONENT_Cb);
+  PelBuf bufDstCr = tempBuf.getBuf(inferArea).get(COMPONENT_Cr);
+
+#if SCALE_NN_RESIDUE
+  PelBuf bufScaledDstCb = tempScaledBuf.getBuf(inferArea).get(COMPONENT_Cb);
+  PelBuf bufScaledDstCr = tempScaledBuf.getBuf(inferArea).get(COMPONENT_Cr);
+#endif
+
+  int hor = inferArea.lwidth() >> 1;
+  int ver = inferArea.lheight() >> 1;
+
+  PelBuf bufRecCb = pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Cb);
+  PelBuf bufRecCr = pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Cr);
+
+  for (int c = 0; c < 4; c++) // output includes 4 sub images
+  {
+    for (int y = 0; y < ver >> 1; y++)
+    {
+      for (int x = 0; x < hor >> 1; x++)
+      {
+        int yy = (y << 1) + c / 2;
+        int xx = (x << 1) + c % 2;
+        if (xx < extLeft || yy < extTop || xx >= hor - extRight || yy >= ver - extBottom)
+        {
+          continue;
+        }
+
+#if NN_FIXED_POINT_IMPLEMENTATION
+        int outCb = ( output(0, y, x, c) + (bufRecCb.at(xx, yy) << shiftInput) + offset) >> shiftOutput;
+        int outCr = ( output(0, y, x, c+4) + (bufRecCr.at(xx, yy) << shiftInput) + offset) >> shiftOutput;
+#else
+        int outCb = ( output(0, y, x, c) + bufRecCb.at(xx, yy) / inputScale ) * outputScale + 0.5;
+        int outCr = ( output(0, y, x, c+4) + bufRecCr.at(xx, yy) / inputScale ) * outputScale + 0.5;
+#endif
+        bufDstCb.at(xx, yy) = Pel(Clip3<int>( 0, 1023, outCb ) );
+        bufDstCr.at(xx, yy) = Pel(Clip3<int>( 0, 1023, outCr ) );
+
+        #if SCALE_NN_RESIDUE
+        bufScaledDstCb.at(xx, yy) = Pel(Clip3<int>(0, 1023 << NN_RESIDUE_ADDITIONAL_SHIFT, outCb * (1 << NN_RESIDUE_ADDITIONAL_SHIFT ) ) );
+        bufScaledDstCr.at(xx, yy) = Pel(Clip3<int>(0, 1023 << NN_RESIDUE_ADDITIONAL_SHIFT, outCr * (1 << NN_RESIDUE_ADDITIONAL_SHIFT ) ) );
+        #endif
+      }
+    }
+  }
+}
+
+/**
+ * Run the luma CNN filter on one inference area and store the result for parameter paramIdx.
+ *
+ * The QP fed to the network is the picture-init QP (inter) or the slice QP (intra) of the first
+ * slice, lowered by paramIdx*5 (inter) or paramIdx*2 (intra); for temporal layers >= 4 and
+ * paramIdx >= 2 the offset becomes 5 - delta instead.
+ * With NN_COMMON_API the inputs are prepared and inferred through NNInference; otherwise the
+ * local prepareInputsLuma()/model.apply() path is used, which also prints the inference time.
+ * A failed inference terminates the process.
+ */
+template<typename T>
+void CNNFilter::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter)
+{
+  //at::init_num_threads(); // use all available threads
+
+  // Border skipping is currently disabled (border_to_skip == 0). The flag is set on the tensor
+  // type actually used by this instantiation, so it also takes effect for the int16 models.
+  const int border_to_skip = 0;
+  if (border_to_skip>0) sadl::Tensor<T>::skip_border = true;
+
+  // get model
+  ModelData<T> &m = getModel<T>(inferArea.lheight(), inferArea.lwidth(), true, inter, inter ? m_interLuma : m_intraLuma);
+  sadl::Model<T> &model = m.model;
+
+  // get inputs
+  vector<sadl::Tensor<T>> &inputs = m.inputs;
+
+  int seqQp = pic->slices[0]->getPPS()->getPicInitQPMinus26() + 26;
+  int sliceQp = pic->slices[0]->getSliceQp();
+  int delta = inter ? paramIdx * 5 : paramIdx * 2;
+  if ( pic->slices[0]->getTLayer() >= 4 && paramIdx >= 2 )
+  {
+    delta = 5 - delta;
+  }
+  int qp = inter ? seqQp - delta : sliceQp - delta;
+
+#if NN_COMMON_API
+  // Describe each model input in tensor order; the partition map only exists for intra models.
+  std::vector<InputData> listInputData;
+  if (inter)
+  {
+    InputData inputRec = {NN_INPUT_REC, 0, 1024, 3, true, false};
+    InputData inputPred = {NN_INPUT_PRED, 1, 1024, 3, true, false};
+    InputData inputBs = {NN_INPUT_BS, 2, 1024, 3, true, false};
+    InputData inputQp = {NN_INPUT_LOCAL_QP, 3, 64, 7, true, false};
+    listInputData.push_back(inputRec);
+    listInputData.push_back(inputPred);
+    listInputData.push_back(inputBs);
+    listInputData.push_back(inputQp);
+  }
+  else
+  {
+    InputData inputRec = {NN_INPUT_REC, 0, 1024, 3, true, false};
+    InputData inputPred = {NN_INPUT_PRED, 1, 1024, 3, true, false};
+    InputData inputPartition = {NN_INPUT_PARTITION, 2, 1024, 3, true, false};
+    InputData inputBs = {NN_INPUT_BS, 3, 1024, 3, true, false};
+    InputData inputQp = {NN_INPUT_LOCAL_QP, 4, 64, 7, true, false};
+    listInputData.push_back(inputRec);
+    listInputData.push_back(inputPred);
+    listInputData.push_back(inputPartition);
+    listInputData.push_back(inputBs);
+    listInputData.push_back(inputQp);
+  }
+  NNInference::prepareInputs<T>(pic, inferArea, inputs, -1, qp, inter, listInputData);
+  NNInference::infer<T>(model, inputs);
+#else
+  prepareInputsLuma<T>(pic, inferArea, inputs, qp, inter);
+
+  // inference
+  chrono::steady_clock::time_point t1 = chrono::steady_clock::now();
+  if (!model.apply(inputs))
+  {
+    cerr << "[ERROR] issue during luma model inference" << endl;
+    exit(-1);
+  }
+  chrono::steady_clock::time_point t2 = chrono::steady_clock::now();
+  chrono::duration<double> dt = chrono::duration_cast<chrono::duration<double>>(t2 - t1);
+  cout << "[INFO] luma model takes "<< dt.count() * 1000. << " ms" <<endl;
+
+  if (border_to_skip) cout<<"[INFO] discard border size="<<model.result().border_skip<<endl;
+#endif
+
+  // get outputs
+  extractOutputsLuma(pic, model, m_tempBuf[paramIdx], m_tempScaledBuf[paramIdx], inferArea, extLeft, extRight, extTop, extBottom, inter);
+
+}
+
+template<typename T>
+void CNNFilter::cnnFilterChromaBlock(Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter)
+{
+
+  //at::init_num_threads();
+
+  const int border_to_skip = 0;
+  if (border_to_skip>0) sadl::Tensor<float>::skip_border = true;
+
+  // get model
+  ModelData<T> &m = getModel<T>(inferArea.lheight(), inferArea.lwidth(), false, inter, inter ? m_interChroma : m_intraChroma);
+  sadl::Model<T> &model = m.model;
+
+  // get inputs
+  vector<sadl::Tensor<T>> &inputs = m.inputs;
+
+  int seqQp = pic->slices[0]->getPPS()->getPicInitQPMinus26() + 26;
+  int sliceQp = pic->slices[0]->getSliceQp();
+  int delta = inter ? paramIdx * 5 : paramIdx * 2;
+  if ( pic->slices[0]->getTLayer() >= 4 && paramIdx >= 2 )
+  {
+    delta = 5 - delta;
+  }
+  int qp = inter ? 
seqQp - delta : sliceQp - delta; + +#if NN_COMMON_API + std::vector<InputData> listInputData; + + if (inter) + { + InputData inputRecCrossComponent = {NN_INPUT_REC, 0, 1024, 3, true, false}; + InputData inputRec = {NN_INPUT_REC, 1, 1024, 3, false, true}; + InputData inputPred = {NN_INPUT_PRED, 2, 1024, 3, false, true}; + InputData inputBs = {NN_INPUT_BS, 3, 1024, 3, false, true}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 4, 64, 7, false, true}; + listInputData.push_back(inputRecCrossComponent); + listInputData.push_back(inputRec); + listInputData.push_back(inputPred); + listInputData.push_back(inputBs); + listInputData.push_back(inputQp); + } + else + { + InputData inputRecCrossComponent = {NN_INPUT_REC, 0, 1024, 3, true, false}; + InputData inputRec = {NN_INPUT_REC, 1, 1024, 3, false, true}; + InputData inputPred = {NN_INPUT_PRED, 2, 1024, 3, false, true}; + InputData inputPartition = {NN_INPUT_PARTITION, 3, 1024, 3, false, true}; + InputData inputBs = {NN_INPUT_BS, 4, 1024, 3, false, true}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 5, 64, 7, false, true}; + listInputData.push_back(inputRecCrossComponent); + listInputData.push_back(inputRec); + listInputData.push_back(inputPred); + listInputData.push_back(inputPartition); + listInputData.push_back(inputBs); + listInputData.push_back(inputQp); + } + NNInference::prepareInputs<T>(pic, inferArea, inputs, -1, qp, inter, listInputData); + NNInference::infer<T>(model, inputs); +#else + prepareInputsChroma<T>(pic, inferArea, inputs, qp, inter); + + // inference + chrono::steady_clock::time_point t1 = chrono::steady_clock::now(); + if (!model.apply(inputs)) + { + cerr << "[ERROR] issue during chroma model inference" << endl; + exit(-1); + } + chrono::steady_clock::time_point t2 = chrono::steady_clock::now(); + chrono::duration<double> dt = chrono::duration_cast<chrono::duration<double>>(t2 - t1); + cout << "[INFO] chroma model takes " << dt.count() * 1000. 
<< " ms"<<endl; + + if (border_to_skip) cout << "[INFO] discard border size=" << model.result().border_skip << endl; +#endif + + // get outputs + extractOutputsChroma(pic, model, m_tempBuf[paramIdx], m_tempScaledBuf[paramIdx], inferArea, extLeft, extRight, extTop, extBottom, inter); + +} + +void CNNFilter::cnnFilter(Picture* pic) +{ + CodingStructure& cs = *pic->cs; + Slice* pcSlice = cs.slice; + const PreCalcValues& pcv = *cs.pcv; + int extension = cs.sps->getCnnlfInferSizeExtension(); + for (int chal = 0; chal < MAX_NUM_CHANNEL_TYPE; chal ++) + { + const ChannelType chType = ChannelType( chal ); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(chType); + std::vector<uint8_t>* cnnlfParamIdx = pic->getCnnlfParamIdx(cnnlfInferGranularity); + int blockSize = cs.sps->getCnnlfInferSize (cnnlfInferGranularity); + for( int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++ ) + { + int xPosInBlocks = blockIdx % pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int yPosInBlocks = blockIdx / pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int xPos = xPosInBlocks * blockSize; + int yPos = yPosInBlocks * blockSize; + int width = (xPos + blockSize > pcv.lumaWidth) ? (pcv.lumaWidth - xPos) : blockSize; + int height = (yPos + blockSize > pcv.lumaHeight) ? (pcv.lumaHeight - yPos) : blockSize; + + const UnitArea block( cs.area.chromaFormat, Area( xPos, yPos, width, height ) ); + + int extLeft = xPos > 0 ? extension : 0; + int extRight = (xPos + width + extension > pcv.lumaWidth) ? (pcv.lumaWidth - xPos - width) : extension; + int extTop = yPos > 0 ? extension : 0; + int extBottom = (yPos + height + extension > pcv.lumaHeight) ? 
(pcv.lumaHeight - yPos - height) : extension; + + int extXPos = xPos - extLeft; + int extYPos = yPos - extTop; + int extWidth = width + extLeft + extRight; + int extHeight = height + extTop + extBottom; + const UnitArea extBlock( cs.area.chromaFormat, Area( extXPos, extYPos, extWidth, extHeight ) ); + if (cnnlfParamIdx[chal][blockIdx] && chal == 0) + { +#if NN_FIXED_POINT_IMPLEMENTATION + cnnFilterLumaBlock<int16_t>(pic, extBlock, extLeft, extRight, extTop, extBottom, cnnlfParamIdx[chal][blockIdx]-1, pcSlice->getSliceType() != I_SLICE); +#else + cnnFilterLumaBlock<float>(pic, extBlock, extLeft, extRight, extTop, extBottom, cnnlfParamIdx[chal][blockIdx]-1, pcSlice->getSliceType() != I_SLICE); +#endif + } + if (cnnlfParamIdx[chal][blockIdx] && chal > 0) + { +#if NN_FIXED_POINT_IMPLEMENTATION + cnnFilterChromaBlock<int16_t>(pic, extBlock, extLeft >> 1, extRight >> 1, extTop >> 1, extBottom >> 1, cnnlfParamIdx[chal][blockIdx]-1, pcSlice->getSliceType() != I_SLICE); +#else + cnnFilterChromaBlock<float>(pic, extBlock, extLeft >> 1, extRight >> 1, extTop >> 1, extBottom >> 1, cnnlfParamIdx[chal][blockIdx]-1, pcSlice->getSliceType() != I_SLICE); +#endif + } + +#if SCALE_NN_RESIDUE + if (cnnlfParamIdx[chal][blockIdx] && pcSlice->getNnScaleFlag(cnnlfParamIdx[chal][blockIdx] - 1, chType)) + { + if (chType == CHANNEL_TYPE_LUMA) + { + scaleResidualBlock(pic, block, cnnlfParamIdx[chal][blockIdx] - 1, COMPONENT_Y); + } + else + { + scaleResidualBlock(pic, block, cnnlfParamIdx[chal][blockIdx] - 1, COMPONENT_Cb); + scaleResidualBlock(pic, block, cnnlfParamIdx[chal][blockIdx] - 1, COMPONENT_Cr); + } + } +#endif + } + } + for (int comp = 0; comp < MAX_NUM_COMPONENT; comp ++) + { + const ComponentID compID = ComponentID( comp ); + int chal = toChannelType(compID); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(toChannelType(compID)); + std::vector<uint8_t>* cnnlfParamIdx = pic->getCnnlfParamIdx(cnnlfInferGranularity); + int blockSize = 
cs.sps->getCnnlfInferSize (cnnlfInferGranularity); + + for( int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++ ) + { + int xPosInBlocks = blockIdx % pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int yPosInBlocks = blockIdx / pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int xPos = xPosInBlocks * blockSize; + int yPos = yPosInBlocks * blockSize; + int width = (xPos + blockSize > pcv.lumaWidth) ? (pcv.lumaWidth - xPos) : blockSize; + int height = (yPos + blockSize > pcv.lumaHeight) ? (pcv.lumaHeight - yPos) : blockSize; + + const UnitArea block( cs.area.chromaFormat, Area( xPos, yPos, width, height ) ); + if (cnnlfParamIdx[chal][blockIdx]) + { +#if SCALE_NN_RESIDUE + if (pcSlice->getNnScaleFlag(cnnlfParamIdx[chal][blockIdx] - 1, toChannelType(compID))) + pic->getRecoBuf(block).get(compID).copyFrom(m_tempScaledBuf[cnnlfParamIdx[chal][blockIdx] - 1].getBuf(block).get(compID)); + else +#endif + pic->getRecoBuf(block).get(compID).copyFrom(m_tempBuf[cnnlfParamIdx[chal][blockIdx] - 1].getBuf(block).get(compID)); + } + } + } +} + +#if SCALE_NN_RESIDUE +void CNNFilter::scaleResidualBlock(Picture *pic, UnitArea inferArea, int paramIdx, ComponentID compID) +{ + + CodingStructure &cs = *pic->cs; + Slice * pcSlice = cs.slice; + + const int scale = pcSlice->getNnScale(paramIdx, toChannelType(compID)); + + const int shift = NN_RESIDUE_SCALE_SHIFT + NN_RESIDUE_ADDITIONAL_SHIFT; + + const int offset = (1 << shift) / 2; + + PelBuf bufRec = pic->getRecoBuf(inferArea).get(compID); + + int strideRec = bufRec.stride; + Pel * pRec = bufRec.buf; + + int blockSizeHor = inferArea.lwidth(); + int blockSizeVer = inferArea.lheight(); + + if (compID) + { + blockSizeHor = blockSizeHor / 2; + blockSizeVer = blockSizeVer / 2; + } + + int idxDst = 0; + int idxRec = 0; + int blockSize = blockSizeVer * blockSizeHor; + + PelBuf bufScaledDst = m_tempScaledBuf[paramIdx].getBuf(inferArea).get(compID); + int strideDst = bufScaledDst.stride; + Pel * 
pScaledDst = bufScaledDst.buf; + + for (int pixelIdx = 0, yy = 0, xx = 0; pixelIdx < blockSize; pixelIdx++) + { + xx = pixelIdx % blockSizeHor; + yy = pixelIdx / blockSizeHor; + + idxDst = yy * strideDst + xx; + idxRec = yy * strideRec + xx; + pScaledDst[idxDst] = Clip3(0, 1023, pRec[idxRec] + (((pScaledDst[idxDst] - (pRec[idxRec] << NN_RESIDUE_ADDITIONAL_SHIFT)) * scale + offset) >> shift)); + } +} +#endif +#endif + +//! \} diff --git a/source/Lib/CommonLib/CNNFilter.h b/source/Lib/CommonLib/CNNFilter.h new file mode 100644 index 0000000000000000000000000000000000000000..b17fd8d6e8a3f9c90ed9fa110376a4a9b3343fd2 --- /dev/null +++ b/source/Lib/CommonLib/CNNFilter.h @@ -0,0 +1,39 @@ +/** \file CNNFilter.h + \brief convolutional neural network-based filter class (header) +*/ + +#ifndef __CNNFILTER__ +#define __CNNFILTER__ + +#include "CommonDef.h" +#include "Unit.h" +#include "Picture.h" +#include "Reshape.h" +//! \ingroup CommonLib +//! \{ + + +class CNNFilter +{ +public: + CNNFilter(); + + std::vector<PelStorage> m_tempBuf; + std::string m_interLuma, m_interChroma, m_intraLuma, m_intraChroma; +#if SCALE_NN_RESIDUE + std::vector<PelStorage> m_tempScaledBuf; +#endif + template<typename T> void cnnFilterLumaBlock( Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter); + template<typename T> void cnnFilterChromaBlock( Picture* pic, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, int paramIdx, bool inter); + void cnnFilter( Picture* pic); +#if SCALE_NN_RESIDUE + void scaleResidualBlock(Picture *pic, UnitArea inferArea, int paramIdx, ComponentID compID); +#endif + void create(const int picWidth, const int picHeight, const ChromaFormat format, const int cnnlfNumParams); + void init(std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma); + void destroy(); +}; + +//!
\} +#endif + diff --git a/source/Lib/CommonLib/Contexts.cpp b/source/Lib/CommonLib/Contexts.cpp index f3acf16b6f1533f43ba17c42405b123d83bb21d2..e96338f9144f3628e8fe848883eb78991f13f33a 100644 --- a/source/Lib/CommonLib/Contexts.cpp +++ b/source/Lib/CommonLib/Contexts.cpp @@ -783,7 +783,15 @@ const CtxSet ContextSetCfg::ctbAlfFlag = ContextSetCfg::addCtxSet { 62, 39, 39, 54, 39, 39, 31, 39, 39, }, { 0, 0, 0, 4, 0, 0, 1, 0, 0, }, }); - +#if CNN_FILTERING +const CtxSet ContextSetCfg::cnnlfParamIdx = ContextSetCfg::addCtxSet +({ + { 33, 52, 25, 61, }, + { 13, 23, 4, 61, }, + { 62, 39, 54, 39, }, + { 0, 0, 4, 0, }, +}); +#endif const CtxSet ContextSetCfg::ctbAlfAlternative = ContextSetCfg::addCtxSet ({ { 11, 26, }, diff --git a/source/Lib/CommonLib/Contexts.h b/source/Lib/CommonLib/Contexts.h index 5a94a2d954f200f295edae5bc683c9e0dc5fbc9e..f6a4c5aaa202b9f112767e223cf9a079a651f5e4 100644 --- a/source/Lib/CommonLib/Contexts.h +++ b/source/Lib/CommonLib/Contexts.h @@ -261,6 +261,9 @@ public: static const CtxSet ctbAlfAlternative; static const CtxSet AlfUseTemporalFilt; static const CtxSet CcAlfFilterControlFlag; +#if CNN_FILTERING + static const CtxSet cnnlfParamIdx; +#endif static const CtxSet CiipFlag; static const CtxSet SmvdFlag; static const CtxSet IBCFlag; diff --git a/source/Lib/CommonLib/Picture.h b/source/Lib/CommonLib/Picture.h index b2c6b1236c5580911951cbee46aa448fdbc35b34..25417d70934ae395b0fd599b8e0ec419561071b0 100644 --- a/source/Lib/CommonLib/Picture.h +++ b/source/Lib/CommonLib/Picture.h @@ -343,6 +343,35 @@ public: std::vector<SAOBlkParam> m_sao[2]; +#if CNN_FILTERING + std::vector<uint8_t> m_cnnlfParamIdx[MAX_NUM_CNNLF_INFER_GRANULARITY][MAX_NUM_CHANNEL_TYPE]; + uint8_t* getCnnlfParamIdx( int gra, int chal ) { return m_cnnlfParamIdx[gra][chal].data(); } + std::vector<uint8_t>* getCnnlfParamIdx(int gra) { return m_cnnlfParamIdx[gra]; } +#if SCALE_NN_RESIDUE + std::vector<uint8_t> 
m_cnnlfBackupParamIdx[MAX_NUM_CNNLF_INFER_GRANULARITY][MAX_NUM_CHANNEL_TYPE]; + uint8_t * getCnnlfBackupParamIdx(int gra, int chal) { return m_cnnlfBackupParamIdx[gra][chal].data(); } + std::vector<uint8_t> *getCnnlfBackupParamIdx(int gra) + { + return m_cnnlfBackupParamIdx[gra]; + } +#endif + void resizeCnnlfParamIdx(const unsigned int *numEntries) + { + for (int gra = 0; gra < MAX_NUM_CNNLF_INFER_GRANULARITY; gra ++) + { + for( int chal = 0; chal < MAX_NUM_CHANNEL_TYPE; chal++ ) + { + m_cnnlfParamIdx[gra][chal].resize( numEntries[gra] ); + std::fill( m_cnnlfParamIdx[gra][chal].begin(), m_cnnlfParamIdx[gra][chal].end(), 0 ); +#if SCALE_NN_RESIDUE + m_cnnlfBackupParamIdx[gra][chal].resize(numEntries[gra]); + std::fill(m_cnnlfBackupParamIdx[gra][chal].begin(), m_cnnlfBackupParamIdx[gra][chal].end(), 0); +#endif + } + } + } +#endif + std::vector<uint8_t> m_alfCtuEnableFlag[MAX_NUM_COMPONENT]; uint8_t* getAlfCtuEnableFlag( int compIdx ) { return m_alfCtuEnableFlag[compIdx].data(); } std::vector<uint8_t>* getAlfCtuEnableFlag() { return m_alfCtuEnableFlag; } diff --git a/source/Lib/CommonLib/Slice.cpp b/source/Lib/CommonLib/Slice.cpp index 7531b345cff586c13a649251e511337e7b3d19bb..2783172816a39f207c067c0338485ac8f050f470 100644 --- a/source/Lib/CommonLib/Slice.cpp +++ b/source/Lib/CommonLib/Slice.cpp @@ -2171,11 +2171,18 @@ void Slice::startProcessingTimer() m_iProcessingStartTime = clock(); } +#if CNN_FILTERING +void Slice::stopProcessingTimer(double elapsedTime) +{ + m_dProcessingTime += elapsedTime; +} +#else void Slice::stopProcessingTimer() { m_dProcessingTime += (double)(clock()-m_iProcessingStartTime) / CLOCKS_PER_SEC; m_iProcessingStartTime = 0; } +#endif unsigned Slice::getMinPictureDistance() const { @@ -2927,6 +2934,13 @@ SPS::SPS() m_uiMaxDecPicBuffering[i] = 1; m_maxNumReorderPics[i] = 0; } + +#if CNN_FILTERING + for ( int i = 0; i < MAX_NUM_CNNLF_INFER_GRANULARITY; i++ ) + { + m_cnnlfInferSize[i] = 128; // non-zero default: PreCalcValues divides by getCnnlfInferSize() + } +#endif ::memset(m_ltRefPicPocLsbSps, 0,
sizeof(m_ltRefPicPocLsbSps)); ::memset(m_usedByCurrPicLtSPSFlag, 0, sizeof(m_usedByCurrPicLtSPSFlag)); diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h index fc7fbeeb2b76d84a6cc59788270fbf8a70e4a878..8207209245c1037e18b42bc3f8f5f95e1821f3f5 100644 --- a/source/Lib/CommonLib/Slice.h +++ b/source/Lib/CommonLib/Slice.h @@ -1397,7 +1397,7 @@ private: unsigned m_dualITree; uint32_t m_uiMaxCUWidth; uint32_t m_uiMaxCUHeight; - + RPLList m_RPLList0; RPLList m_RPLList1; uint32_t m_numRPL0; @@ -1478,6 +1478,12 @@ private: bool m_alfEnabledFlag; bool m_ccalfEnabledFlag; +#if CNN_FILTERING + bool m_cnnlfEnabledFlag; + unsigned m_cnnlfInferSize[MAX_NUM_CNNLF_INFER_GRANULARITY]; + unsigned m_cnnlfInferSizeExtension; + unsigned m_cnnlfMaxNumParams; +#endif bool m_wrapAroundEnabledFlag; unsigned m_IBCFlag; bool m_useColorTrans; @@ -1658,6 +1664,14 @@ public: uint32_t getMaxCUWidth() const { return m_uiMaxCUWidth; } void setMaxCUHeight( uint32_t u ) { m_uiMaxCUHeight = u; } uint32_t getMaxCUHeight() const { return m_uiMaxCUHeight; } +#if CNN_FILTERING + void setCnnlfInferSize( uint32_t* cnnlfInferSize ) { m_cnnlfInferSize[0] = cnnlfInferSize[0]; m_cnnlfInferSize[1] = cnnlfInferSize[1]; m_cnnlfInferSize[2] = cnnlfInferSize[2];} + uint32_t getCnnlfInferSize(CnnlfInferGranularity cnnlfInferGranularity) const { return m_cnnlfInferSize[cnnlfInferGranularity]; } + void setCnnlfInferSizeExtension( uint32_t cnnlfInferSizeExtension ) { m_cnnlfInferSizeExtension = cnnlfInferSizeExtension; } + uint32_t getCnnlfInferSizeExtension() const { return m_cnnlfInferSizeExtension; } + void setCnnlfMaxNumParams( uint32_t cnnlfMaxNumParams ) { m_cnnlfMaxNumParams = cnnlfMaxNumParams; } + uint32_t getCnnlfMaxNumParams() const { return m_cnnlfMaxNumParams; } +#endif bool getTransformSkipEnabledFlag() const { return m_transformSkipEnabledFlag; } void setTransformSkipEnabledFlag( bool b ) { m_transformSkipEnabledFlag = b; } uint32_t getLog2MaxTransformSkipBlockSize() const { return 
m_log2MaxTransformSkipBlockSize; } @@ -1725,8 +1739,12 @@ public: bool getALFEnabledFlag() const { return m_alfEnabledFlag; } void setALFEnabledFlag( bool b ) { m_alfEnabledFlag = b; } -bool getCCALFEnabledFlag() const { return m_ccalfEnabledFlag; } -void setCCALFEnabledFlag( bool b ) { m_ccalfEnabledFlag = b; } + bool getCCALFEnabledFlag() const { return m_ccalfEnabledFlag; } + void setCCALFEnabledFlag( bool b ) { m_ccalfEnabledFlag = b; } +#if CNN_FILTERING + bool getCnnlfEnabledFlag() const { return m_cnnlfEnabledFlag; } + void setCnnlfEnabledFlag( bool b ) { m_cnnlfEnabledFlag = b; } +#endif void setJointCbCrEnabledFlag(bool bVal) { m_JointCbCrEnabledFlag = bVal; } bool getJointCbCrEnabledFlag() const { return m_JointCbCrEnabledFlag; } @@ -2598,6 +2616,10 @@ class Slice private: // Bitstream writing bool m_saoEnabledFlag[MAX_NUM_CHANNEL_TYPE]; +#if CNN_FILTERING + uint8_t m_cnnlfMode[MAX_NUM_CHANNEL_TYPE]; // 0: slice off + CnnlfInferGranularity m_cnnlfInferGranularity[MAX_NUM_CHANNEL_TYPE]; +#endif int m_iPOC; int m_iLastIDR; int m_prevGDRInSameLayerPOC; //< the previous GDR in the same layer @@ -2710,6 +2732,12 @@ private: int m_tileGroupCcAlfCrApsId; bool m_disableSATDForRd; bool m_isLossless; + +#if SCALE_NN_RESIDUE + bool m_sliceNnScaleFlag[3][MAX_NUM_CHANNEL_TYPE]; + int m_nnScale[3][MAX_NUM_CHANNEL_TYPE]; +#endif + public: Slice(); virtual ~Slice(); @@ -2732,6 +2760,21 @@ public: APS** getAlfAPSs() { return m_alfApss; } void setSaoEnabledFlag(ChannelType chType, bool s) {m_saoEnabledFlag[chType] =s; } bool getSaoEnabledFlag(ChannelType chType) const { return m_saoEnabledFlag[chType]; } +#if CNN_FILTERING + void setCnnlfMode(ChannelType chType, uint8_t m) {m_cnnlfMode[chType] = m; } + uint8_t getCnnlfMode(ChannelType chType) const { return m_cnnlfMode[chType]; } + void setCnnlfInferGranularity(ChannelType chType, CnnlfInferGranularity cnnlfInferGranularity) {m_cnnlfInferGranularity[chType] = cnnlfInferGranularity; } + CnnlfInferGranularity 
getCnnlfInferGranularity(ChannelType chType) const { return m_cnnlfInferGranularity[chType]; } + +#if SCALE_NN_RESIDUE + void setNnScale(int sc, uint8_t paramIdx, ChannelType chType) { m_nnScale[paramIdx][chType] = sc; } + int getNnScale(uint8_t paramIdx, ChannelType chType) const { return m_nnScale[paramIdx][chType]; } + + void setNnScaleFlag(bool b, int paramIdx, ChannelType chType) { m_sliceNnScaleFlag[paramIdx][chType] = b; } + bool getNnScaleFlag(int paramIdx, ChannelType chType) const { return m_sliceNnScaleFlag[paramIdx][chType]; } +#endif + +#endif ReferencePictureList* getRPL0() { return &m_RPL0; } ReferencePictureList* getRPL1() { return &m_RPL1; } void setRPL0idx(int rplIdx) { m_rpl0Idx = rplIdx; } @@ -2970,7 +3013,11 @@ public: ClpRngs& getClpRngs() { return m_clpRngs;} unsigned getMinPictureDistance() const ; void startProcessingTimer(); +#if CNN_FILTERING + void stopProcessingTimer(double elapsedTime); +#else void stopProcessingTimer(); +#endif void resetProcessingTime() { m_dProcessingTime = m_iProcessingStartTime = 0; } double getProcessingTime() const { return m_dProcessingTime; } @@ -3050,6 +3097,11 @@ public: , widthInCtus ( (pps.getPicWidthInLumaSamples () + sps.getMaxCUWidth () - 1) / sps.getMaxCUWidth () ) , heightInCtus ( (pps.getPicHeightInLumaSamples() + sps.getMaxCUHeight() - 1) / sps.getMaxCUHeight() ) , sizeInCtus ( widthInCtus * heightInCtus ) +#if CNN_FILTERING + , widthInCnnlfInferSize {(pps.getPicWidthInLumaSamples () + sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_SMALL) - 1) / sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_SMALL), (pps.getPicWidthInLumaSamples () + sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_BASE) - 1) / sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_BASE), (pps.getPicWidthInLumaSamples () + sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_LARGE) - 1) / sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_LARGE)} + , heightInCnnlfInferSize {(pps.getPicHeightInLumaSamples () + 
sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_SMALL) - 1) / sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_SMALL), (pps.getPicHeightInLumaSamples () + sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_BASE) - 1) / sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_BASE), (pps.getPicHeightInLumaSamples () + sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_LARGE) - 1) / sps.getCnnlfInferSize(CNNLF_INFER_GRANULARITY_LARGE)} + , sizeInCnnlfInferSize {widthInCnnlfInferSize[CNNLF_INFER_GRANULARITY_SMALL] * heightInCnnlfInferSize[CNNLF_INFER_GRANULARITY_SMALL], widthInCnnlfInferSize[CNNLF_INFER_GRANULARITY_BASE] * heightInCnnlfInferSize[CNNLF_INFER_GRANULARITY_BASE], widthInCnnlfInferSize[CNNLF_INFER_GRANULARITY_LARGE] * heightInCnnlfInferSize[CNNLF_INFER_GRANULARITY_LARGE]} +#endif , lumaWidth ( pps.getPicWidthInLumaSamples() ) , lumaHeight ( pps.getPicHeightInLumaSamples() ) , fastDeltaQPCuMaxSize( Clip3(1u << sps.getLog2MinCodingBlockSize(), sps.getMaxCUHeight(), 32u) ) @@ -3083,6 +3135,11 @@ public: const unsigned widthInCtus; const unsigned heightInCtus; const unsigned sizeInCtus; +#if CNN_FILTERING + const uint32_t widthInCnnlfInferSize[MAX_NUM_CNNLF_INFER_GRANULARITY]; + const uint32_t heightInCnnlfInferSize[MAX_NUM_CNNLF_INFER_GRANULARITY]; + const uint32_t sizeInCnnlfInferSize[MAX_NUM_CNNLF_INFER_GRANULARITY]; +#endif const unsigned lumaWidth; const unsigned lumaHeight; const unsigned fastDeltaQPCuMaxSize; diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index aadb4a4281627dad0327599558e63874e16b9d59..068d66ab7be4c927fde4646408fe998dac12c052 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -50,7 +50,9 @@ #include <assert.h> #include <cassert> + // clang-format off + #define NN_COMMON_API 1 #define NNVC_INFO_ENCODER 1 // add some info in encoder logs necessary to extract data @@ -65,8 +67,34 @@ #define NNVC_USE_QP 1 // QP slice #define NNVC_USE_SLICETYPE 1 // slice type +#define CNN_FILTERING 1 + +#if CNN_FILTERING 
+#define SCALE_NN_RESIDUE 1 +#define COMBINE_NN_WITH_LF 1 +#define NN_FIXED_POINT_IMPLEMENTATION 1 +#endif + +#if SCALE_NN_RESIDUE +#define NN_RESIDUE_SCALE_SHIFT 8 +#define NN_RESIDUE_ADDITIONAL_SHIFT 4 // maximum 4 +#define NN_RESIDUE_SCALE_DEVIATION_UP_BOUND 1.25 +#define NN_RESIDUE_SCALE_DEVIATION_BOT_BOUND 0.0625 +#endif + +#if COMBINE_NN_WITH_LF +#define ENC_DB_OPT 1 +#define FUSE_NN_AND_LF 1 +#endif + +#if NN_FIXED_POINT_IMPLEMENTATION +#define NN_INPUT_PRECISION 13 +#define NN_OUPUTPUT_PRECISION 13 +#endif + //########### place macros to be removed in next cycle below this line ############### + #define JVET_V0056 1 // MCTF changes as presented in JVET-V0056 #define JVET_S0096_RPL_CONSTRAINT 1// JVET-S0096 aspect 1: When pps_rpl_info_in_ph_flag is equal to 1 and ph_inter_slice_allowed_flag is equal to 1, the value of num_ref_entries[ 0 ][ RplsIdx[ 0 ] ] shall be greater than 0. @@ -434,7 +462,15 @@ enum ComponentID JOINT_CbCr = MAX_NUM_COMPONENT, MAX_NUM_TBLOCKS = MAX_NUM_COMPONENT }; - +#if CNN_FILTERING +enum CnnlfInferGranularity +{ + CNNLF_INFER_GRANULARITY_SMALL = 0, // half base size + CNNLF_INFER_GRANULARITY_BASE = 1, // specified in SPS + CNNLF_INFER_GRANULARITY_LARGE = 2, // double base size + MAX_NUM_CNNLF_INFER_GRANULARITY = 3 +}; +#endif #if NN_COMMON_API enum NNInputType { @@ -448,7 +484,6 @@ enum NNInputType MAX_NUM_NN_INPUT = 7 }; #endif - #define MAP_CHROMA(c) (ComponentID(c)) enum InputColourSpaceConversion // defined in terms of conversion prior to input of encoder. 
diff --git a/source/Lib/DecoderLib/CABACReader.cpp b/source/Lib/DecoderLib/CABACReader.cpp index 2980b49c3e0fa65fc548fe24306f6edf97349f32..7aca4738b803c1ab237a0a872c4a455ecb7e2e26 100644 --- a/source/Lib/DecoderLib/CABACReader.cpp +++ b/source/Lib/DecoderLib/CABACReader.cpp @@ -144,6 +144,55 @@ void CABACReader::coding_tree_unit( CodingStructure& cs, const UnitArea& area, i sao( cs, ctuRsAddr ); +#if CNN_FILTERING + if ( cs.sps->getCnnlfEnabledFlag() && ctuRsAddr == 0) + { + for( int chal = 0; chal < MAX_NUM_CHANNEL_TYPE; chal++ ) + { + CnnlfInferGranularity cnnlfInferGranularity = cs.slice->getCnnlfInferGranularity(ChannelType(chal)); + uint8_t* cnnlfParamIdx = cs.slice->getPic()->getCnnlfParamIdx( cnnlfInferGranularity, chal ); + uint8_t sliceCnnlfMode = cs.slice->getCnnlfMode(ChannelType(chal)); + int numParams = cs.sps->getCnnlfMaxNumParams(); + for( int unitIdx = 0; unitIdx < cs.pcv->sizeInCnnlfInferSize[cnnlfInferGranularity]; unitIdx++ ) + { + if (sliceCnnlfMode < numParams + 1) + { + cnnlfParamIdx[unitIdx] = sliceCnnlfMode; + continue; + } + bool useCnnlf, useFirstParam = false; + useCnnlf = m_BinDecoder.decodeBin( Ctx::cnnlfParamIdx( chal * 2 + 0 ) ); + if (numParams == 1) + { + cnnlfParamIdx[unitIdx] = useCnnlf ? 1 : 0; + } + else if (!useCnnlf) + { + cnnlfParamIdx[unitIdx] = 0; + } + else + { + useFirstParam = m_BinDecoder.decodeBin( Ctx::cnnlfParamIdx( chal * 2 + 1 ) ); + if (numParams == 2) + { + cnnlfParamIdx[unitIdx] = useFirstParam ? 
1 : 2; + } + else if (useFirstParam) + { + cnnlfParamIdx[unitIdx] = 1; + } + else + { + uint32_t cnnlfParamIdxMinus2 = 0; + xReadTruncBinCode(cnnlfParamIdxMinus2, numParams - 1); + cnnlfParamIdx[unitIdx] = cnnlfParamIdxMinus2 + 2; + } + } + + } + } + } +#endif if (cs.sps->getALFEnabledFlag() && (cs.slice->getTileGroupAlfEnabledFlag(COMPONENT_Y))) { const PreCalcValues& pcv = *cs.pcv; diff --git a/source/Lib/DecoderLib/DecLib.cpp b/source/Lib/DecoderLib/DecLib.cpp index 5e05696bee27614f47928e3fd1a9ad124d05226f..cb8c4ba122144e77427a90efd3de1b34322ad75d 100644 --- a/source/Lib/DecoderLib/DecLib.cpp +++ b/source/Lib/DecoderLib/DecLib.cpp @@ -58,6 +58,10 @@ #include "CommonLib/CodingStatistics.h" #endif +#if CNN_FILTERING +#include <chrono> +#endif + bool tryDecodePicture( Picture* pcEncPic, const int expectedPoc, const std::string& bitstreamFileName, ParameterSetMap<APS> *apsMap, bool bDecodeUntilPocFound /* = false */, int debugCTU /* = -1*/, int debugPOC /* = -1*/ ) { int poc; @@ -490,6 +494,9 @@ void DecLib::create() { m_apcSlicePilot = new Slice; m_uiSliceSegmentIdx = 0; +#if CNN_FILTERING + m_pcCNNFilter = new CNNFilter; +#endif } void DecLib::destroy() @@ -502,7 +509,14 @@ void DecLib::destroy() delete m_dci; m_dci = NULL; } - +#if CNN_FILTERING + if (m_pcCNNFilter) + { + m_pcCNNFilter->destroy(); + delete m_pcCNNFilter; + m_pcCNNFilter = NULL; + } +#endif m_cSliceDecoder.destroy(); } @@ -617,8 +631,18 @@ void DecLib::executeLoopFilters() m_pcPic->cs->slice->startProcessingTimer(); +#if CNN_FILTERING + auto startTime = std::chrono::steady_clock::now(); +#endif + CodingStructure& cs = *m_pcPic->cs; - +#if CNN_FILTERING + if (cs.sps->getCnnlfEnabledFlag()) + { + m_pcCNNFilter->create(cs.pcv->lumaWidth, cs.pcv->lumaHeight, cs.pcv->chrFormat, cs.sps->getCnnlfMaxNumParams()); + m_pcCNNFilter->init(getCnnlfInterLumaModelName(), getCnnlfInterChromaModelName(), getCnnlfIntraLumaModelName(), getCnnlfIntraChromaModelName()); + } +#endif if (cs.sps->getUseLmcs() && 
cs.picHeader->getLmcsEnabledFlag()) { const PreCalcValues& pcv = *cs.pcv; @@ -653,7 +677,6 @@ void DecLib::executeLoopFilters() } } #endif - m_cReshaper.setRecReshaped(false); m_cSAO.setReshaper(&m_cReshaper); } @@ -663,17 +686,32 @@ void DecLib::executeLoopFilters() #if NNVC_USE_REC_BEFORE_DBF m_pcPic->getRecBeforeDbfBuf().copyFrom(m_pcPic->getRecoBuf()); #endif - // deblocking filter m_cLoopFilter.loopFilterPic( cs ); #if NNVC_USE_REC_AFTER_DBF m_pcPic->getRecAfterDbfBuf().copyFrom(m_pcPic->getRecoBuf()); #endif CS::setRefinedMotionField(cs); + if( cs.sps->getSAOEnabledFlag() ) { m_cSAO.SAOProcess( cs, cs.picture->getSAO() ); } + +#if COMBINE_NN_WITH_LF && !FUSE_NN_AND_LF + if (cs.sps->getCnnlfEnabledFlag()) + { + m_pcPic->getRecoBuf().copyFrom(m_pcPic->getUnfilteredRecBuf()); + } +#endif + + //CNN filter +#if CNN_FILTERING + if (cs.sps->getCnnlfEnabledFlag()) + { + m_pcCNNFilter->cnnFilter(m_pcPic); + } +#endif if( cs.sps->getALFEnabledFlag() ) { @@ -708,7 +746,13 @@ void DecLib::executeLoopFilters() } } +#if CNN_FILTERING + auto endTime = std::chrono::steady_clock::now(); + auto encTime = std::chrono::duration_cast<std::chrono::milliseconds>( endTime - startTime).count(); + m_pcPic->cs->slice->stopProcessingTimer(encTime/1000.0); +#else m_pcPic->cs->slice->stopProcessingTimer(); +#endif } void DecLib::finishPictureLight(int& poc, PicList*& rpcListPic ) diff --git a/source/Lib/DecoderLib/DecLib.h b/source/Lib/DecoderLib/DecLib.h index bb15b81eb4c2ef735348d6532529936d1b47122a..16cfdaf2fdf7c3b1373ae2735024a364fc9edebc 100644 --- a/source/Lib/DecoderLib/DecLib.h +++ b/source/Lib/DecoderLib/DecLib.h @@ -54,6 +54,9 @@ #include "CommonLib/SEI.h" #include "CommonLib/Unit.h" #include "CommonLib/Reshape.h" +#if CNN_FILTERING +#include "CommonLib/CNNFilter.h" +#endif class InputNALUnit; @@ -68,6 +71,12 @@ bool tryDecodePicture( Picture* pcPic, const int expectedPoc, const std::string& class DecLib { private: +#if CNN_FILTERING + std::string m_cnnlfInterLumaModelName; 
///<inter luma cnnlf model + std::string m_cnnlfInterChromaModelName; ///<inter chroma cnnlf model + std::string m_cnnlfIntraLumaModelName; ///<intra luma cnnlf model + std::string m_cnnlfIntraChromaModelName; ///<intra chroma cnnlf model +#endif int m_iMaxRefPicNum; bool m_isFirstGeneralHrd; GeneralHrdParams m_prevGeneralHrdParams; @@ -114,6 +123,9 @@ private: AdaptiveLoopFilter m_cALF; Reshape m_cReshaper; ///< reshaper class HRD m_HRD; +#if CNN_FILTERING + CNNFilter* m_pcCNNFilter; +#endif // decoder side RD cost computation RdCost m_cRdCost; ///< RD cost computation class #if JVET_J0090_MEMORY_BANDWITH_MEASURE @@ -209,6 +221,17 @@ public: void create (); void destroy (); + +#if CNN_FILTERING + std::string getCnnlfInterLumaModelName() { return m_cnnlfInterLumaModelName; } + std::string getCnnlfInterChromaModelName() { return m_cnnlfInterChromaModelName; } + std::string getCnnlfIntraLumaModelName() { return m_cnnlfIntraLumaModelName; } + std::string getCnnlfIntraChromaModelName() { return m_cnnlfIntraChromaModelName; } + void setCnnlfInterLumaModelName(std::string s) { m_cnnlfInterLumaModelName = s; } + void setCnnlfInterChromaModelName(std::string s) { m_cnnlfInterChromaModelName = s; } + void setCnnlfIntraLumaModelName(std::string s) { m_cnnlfIntraLumaModelName = s; } + void setCnnlfIntraChromaModelName(std::string s) { m_cnnlfIntraChromaModelName = s; } +#endif void setDecodedPictureHashSEIEnabled(int enabled) { m_decodedPictureHashSEIEnabled=enabled; } diff --git a/source/Lib/DecoderLib/DecSlice.cpp b/source/Lib/DecoderLib/DecSlice.cpp index 6a9a1e1d27e318912c22961487b8c7ecf6d3c90f..166e1e226ce60594976cf8f593ab0b84914815f2 100644 --- a/source/Lib/DecoderLib/DecSlice.cpp +++ b/source/Lib/DecoderLib/DecSlice.cpp @@ -41,6 +41,10 @@ #include <vector> +#if CNN_FILTERING +#include <chrono> +#endif + //! \ingroup DecoderLib //! 
\{ @@ -75,6 +79,10 @@ void DecSlice::decompressSlice( Slice* slice, InputBitstream* bitstream, int deb //-- For time output for each slice slice->startProcessingTimer(); +#if CNN_FILTERING + auto startTime = std::chrono::steady_clock::now(); +#endif + const SPS* sps = slice->getSPS(); Picture* pic = slice->getPic(); CABACReader& cabacReader = *m_CABACDecoder->getCABACReader( 0 ); @@ -102,7 +110,9 @@ void DecSlice::decompressSlice( Slice* slice, InputBitstream* bitstream, int deb cs.picture->resizeAlfCtbFilterIndex(cs.pcv->sizeInCtus); cs.picture->resizeAlfCtuAlternative( cs.pcv->sizeInCtus ); } - +#if CNN_FILTERING + cs.picture->resizeCnnlfParamIdx( cs.pcv->sizeInCnnlfInferSize ); +#endif const unsigned numSubstreams = slice->getNumberOfSubstreamSizes() + 1; // init each couple {EntropyDecoder, Substream} @@ -294,7 +304,13 @@ void DecSlice::decompressSlice( Slice* slice, InputBitstream* bitstream, int deb { delete substr; } +#if CNN_FILTERING + auto endTime = std::chrono::steady_clock::now(); + auto encTime = std::chrono::duration_cast<std::chrono::milliseconds>( endTime - startTime).count(); + slice->stopProcessingTimer(encTime/1000.0); +#else slice->stopProcessingTimer(); +#endif } //! \} diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp index 97651347a617ba1718e8b8c0c5eeaba7c9283c0a..77dcd22ae470ec4d1d5092e2d89894788029693a 100644 --- a/source/Lib/DecoderLib/VLCReader.cpp +++ b/source/Lib/DecoderLib/VLCReader.cpp @@ -1728,6 +1728,20 @@ void HLSyntaxReader::parseSPS(SPS* pcSPS) { pcSPS->setCCALFEnabledFlag(false); } + +#if CNN_FILTERING + READ_FLAG( uiCode, "sps_cnnlf_enabled_flag" ); pcSPS->setCnnlfEnabledFlag ( uiCode ? 
true : false ); + if (pcSPS->getCnnlfEnabledFlag()) + { + READ_UVLC( uiCode, "sps_cnnlf_infer_size_base" ); + unsigned cnnlfInferSize[] = {uiCode >> 1, uiCode, uiCode << 1}; + pcSPS->setCnnlfInferSize (cnnlfInferSize ); + READ_UVLC( uiCode, "sps_cnnlf_infer_size_extension" ); + pcSPS->setCnnlfInferSizeExtension ( uiCode ); + READ_UVLC( uiCode, "sps_cnnlf_max_num_params" ); + pcSPS->setCnnlfMaxNumParams (uiCode ); + } +#endif READ_FLAG(uiCode, "sps_lmcs_enable_flag"); pcSPS->setUseLmcs(uiCode == 1); @@ -4162,6 +4176,73 @@ void HLSyntaxReader::parseSliceHeader (Slice* pcSlice, PicHeader* picHeader, Par pcSlice->setUseChromaQpAdj(false); } +#if CNN_FILTERING + if (sps->getCnnlfEnabledFlag()) + { + READ_UVLC(uiCode, "slice_luma_cnnlf_mode"); pcSlice->setCnnlfMode(CHANNEL_TYPE_LUMA, uiCode); + READ_UVLC(uiCode, "slice_chroma_cnnlf_mode"); pcSlice->setCnnlfMode(CHANNEL_TYPE_CHROMA, uiCode); + +#if SCALE_NN_RESIDUE + for (int chal = 0; chal < MAX_NUM_CHANNEL_TYPE; chal++) + { + ChannelType chType = ChannelType(chal); + if (pcSlice->getCnnlfMode(chType)) + { + int numParams = sps->getCnnlfMaxNumParams(); + if (pcSlice->getCnnlfMode(chType) == numParams + 1) + { + for (int paramIdx = 0; paramIdx < numParams; paramIdx++) + { + READ_FLAG(uiCode, "slice_cnnlf_scale_flag"); + pcSlice->setNnScaleFlag(uiCode != 0, paramIdx, chType); + if (uiCode) + { + READ_SCODE(NN_RESIDUE_SCALE_SHIFT + 1, iCode, "nnScale"); + pcSlice->setNnScale(iCode + (1 << NN_RESIDUE_SCALE_SHIFT), paramIdx, chType); + } + } + } + else + { + READ_FLAG(uiCode, "slice_cnnlf_scale_flag"); + pcSlice->setNnScaleFlag(uiCode != 0, pcSlice->getCnnlfMode(chType) - 1, chType); + if (uiCode) + { + READ_SCODE(NN_RESIDUE_SCALE_SHIFT + 1, iCode, "nnScale"); + pcSlice->setNnScale(iCode + (1 << NN_RESIDUE_SCALE_SHIFT), pcSlice->getCnnlfMode(chType) - 1, chType); + } + } + } + } +#endif + + CnnlfInferGranularity cnnlfInferGranularityLuma = CNNLF_INFER_GRANULARITY_BASE; + CnnlfInferGranularity cnnlfInferGranularityChroma = 
CNNLF_INFER_GRANULARITY_BASE; + if (pcSlice ->getSliceType() == I_SLICE) + { + cnnlfInferGranularityLuma = CNNLF_INFER_GRANULARITY_LARGE; + cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_LARGE; + } + else if (pcSlice->getSliceQp() < 23) + { + cnnlfInferGranularityLuma = CNNLF_INFER_GRANULARITY_BASE; + cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_BASE; + } + else if (pcSlice->getSliceQp() < 29) + { + cnnlfInferGranularityLuma = CNNLF_INFER_GRANULARITY_BASE; + cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_LARGE; + } + else + { + cnnlfInferGranularityLuma = pps->getPicWidthInLumaSamples() <=832 ? CNNLF_INFER_GRANULARITY_BASE : CNNLF_INFER_GRANULARITY_LARGE; + cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_LARGE; + } + pcSlice->setCnnlfInferGranularity(CHANNEL_TYPE_LUMA, cnnlfInferGranularityLuma); + pcSlice->setCnnlfInferGranularity(CHANNEL_TYPE_CHROMA, cnnlfInferGranularityChroma); + } +#endif + if (sps->getSAOEnabledFlag() && !pps->getSaoInfoInPhFlag()) { READ_FLAG(uiCode, "sh_sao_luma_used_flag"); pcSlice->setSaoEnabledFlag(CHANNEL_TYPE_LUMA, (bool)uiCode); diff --git a/source/Lib/EncoderLib/CABACWriter.cpp b/source/Lib/EncoderLib/CABACWriter.cpp index 087cb0c179cb7c632cbebc16f4f94fac166108d1..b72a82fb49f5deca5e74de96efdf3b3868f2e634 100644 --- a/source/Lib/EncoderLib/CABACWriter.cpp +++ b/source/Lib/EncoderLib/CABACWriter.cpp @@ -166,7 +166,18 @@ void CABACWriter::coding_tree_unit( CodingStructure& cs, const UnitArea& area, i { sao( *cs.slice, ctuRsAddr ); } - +#if CNN_FILTERING + if (cs.sps->getCnnlfEnabledFlag() && !skipSao && ctuRsAddr == 0) + { + int numParams = cs.sps->getCnnlfMaxNumParams(); + for (int chal = 0; chal < MAX_NUM_CHANNEL_TYPE; chal++) + { + if (cs.slice->getCnnlfMode(ChannelType(chal)) < (numParams + 1)) + continue; + codeCnnlfParamIdx(cs, ChannelType(chal)); + } + } +#endif if (!skipAlf) { for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) @@ -3234,7 +3245,36 @@ void CABACWriter::exp_golomb_eqprob( 
unsigned symbol, unsigned count ) m_BinEncoder.encodeBinsEP(bins, numBins); m_BinEncoder.encodeBinsEP(symbol, count); } +#if CNN_FILTERING +void CABACWriter::codeCnnlfParamIdx( CodingStructure& cs, ChannelType chType) +{ + CnnlfInferGranularity cnnlfInferGranularity = cs.slice->getCnnlfInferGranularity(chType); + for( int unitIdx = 0; unitIdx < cs.pcv->sizeInCnnlfInferSize[cnnlfInferGranularity]; unitIdx++ ) + { + codeCnnlfParamIdx( cs, unitIdx, chType ); + } +} + +void CABACWriter::codeCnnlfParamIdx( CodingStructure& cs, uint32_t unitRsAddr, const int chal) +{ + CnnlfInferGranularity cnnlfInferGranularity = cs.slice->getCnnlfInferGranularity(ChannelType(chal)); + uint8_t* cnnlfParamIdx = cs.slice->getPic()->getCnnlfParamIdx( cnnlfInferGranularity, chal ); + uint8_t uiCode = cnnlfParamIdx[unitRsAddr]; + uint8_t cnnlfMaxNumParams = cs.sps->getCnnlfMaxNumParams(); + + m_BinEncoder.encodeBin( uiCode > 0, Ctx::cnnlfParamIdx( chal * 2 + 0 ) ); + + if (cnnlfMaxNumParams > 1 && uiCode > 0) + { + m_BinEncoder.encodeBin( uiCode == 1, Ctx::cnnlfParamIdx( chal * 2 + 1 ) ); + } + if (cnnlfMaxNumParams > 2 && uiCode > 1) + { + xWriteTruncBinCode(uiCode - 2, cnnlfMaxNumParams - 1); + } +} +#endif void CABACWriter::codeAlfCtuEnableFlags( CodingStructure& cs, ChannelType channel, AlfParam* alfParam) { if( isLuma( channel ) ) diff --git a/source/Lib/EncoderLib/CABACWriter.h b/source/Lib/EncoderLib/CABACWriter.h index 62a39d7a6d29c9fbc2b9755412fae7abfbc2d7ad..5a078c017adf7acfa6a712118b6ffc09d6be01a2 100644 --- a/source/Lib/EncoderLib/CABACWriter.h +++ b/source/Lib/EncoderLib/CABACWriter.h @@ -164,6 +164,11 @@ public: void codeCcAlfFilterControlIdc(uint8_t idcVal, CodingStructure &cs, const ComponentID compID, const int curIdx, const uint8_t *filterControlIdc, Position lumaPos, const int filterCount); +#if CNN_FILTERING + void codeCnnlfParamIdx ( CodingStructure& cs, ChannelType chType); + void codeCnnlfParamIdx ( CodingStructure& cs, uint32_t ctuRsAddr, const int chal ); +#endif + 
private: void unary_max_symbol ( unsigned symbol, unsigned ctxId0, unsigned ctxIdN, unsigned maxSymbol ); void unary_max_eqprob ( unsigned symbol, unsigned maxSymbol ); diff --git a/source/Lib/EncoderLib/EncCNNFilter.cpp b/source/Lib/EncoderLib/EncCNNFilter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5373d7c16b7f1f9bbc8752ef099615318f394a0a --- /dev/null +++ b/source/Lib/EncoderLib/EncCNNFilter.cpp @@ -0,0 +1,505 @@ +/** \file EncCNNFilter.cpp + \brief encoder convolutional neural network-based filter class +*/ + +#include "EncCNNFilter.h" +#if CNN_FILTERING +#include "UnitTools.h" +#include "UnitPartitioner.h" +#include "CodingStructure.h" +#include "CommonLib/dtrace_codingstruct.h" +#include "CommonLib/dtrace_buffer.h" +#include <chrono> +#include <string.h> +#include <stdlib.h> +#include <stdio.h> +#include <math.h> +#include <ctime> +#include <algorithm> +#include <numeric> +//! \ingroup CommonLib +//! \{ + +#define CNNLFCtx(c) SubCtx( Ctx::cnnlfParamIdx, c ) + +EncCNNFilter::EncCNNFilter() +{ + m_CABACEstimator = NULL; + m_singleModelISlice = false; +} + +EncCNNFilter::~EncCNNFilter() +{ +} +void EncCNNFilter::initCABACEstimator( CABACEncoder* cabacEncoder, CtxCache* ctxCache, Slice* pcSlice ) +{ + m_CABACEstimator = cabacEncoder->getCABACEstimator( pcSlice->getSPS() ); + m_CtxCache = ctxCache; + m_CABACEstimator->initCtxModels( *pcSlice ); + m_CABACEstimator->resetBits(); +} + +double getDistortion(PelBuf buf1, PelBuf buf2, int width, int height) +{ + Pel* p1 = buf1.buf; + int stride1 = buf1.stride; + Pel* p2 = buf2.buf; + int stride2 = buf2.stride; + double dist = 0; + for (int y = 0; y < height; y++) + { + for (int x = 0; x < width; x++) + { + int diff = p1[y * stride1 + x] - p2[y * stride2 + x]; + dist += diff * diff; + } + } + return dist; +} +void setCnnlfInferGranularity(Picture* pic) +{ + Slice* pcSlice = pic->cs->slice; + CnnlfInferGranularity cnnlfInferGranularityLuma = CNNLF_INFER_GRANULARITY_BASE; + 
CnnlfInferGranularity cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_BASE; + if (pcSlice ->getSliceType() == I_SLICE) + { + pcSlice->setCnnlfInferGranularity(CHANNEL_TYPE_LUMA, CNNLF_INFER_GRANULARITY_LARGE); + pcSlice->setCnnlfInferGranularity(CHANNEL_TYPE_CHROMA, CNNLF_INFER_GRANULARITY_LARGE); + return; + } + if (pcSlice->getSliceQp() < 23) + { + cnnlfInferGranularityLuma = CNNLF_INFER_GRANULARITY_BASE; + cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_BASE; + } + else if (pcSlice->getSliceQp() < 29) + { + cnnlfInferGranularityLuma = CNNLF_INFER_GRANULARITY_BASE; + cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_LARGE; + } + else + { + cnnlfInferGranularityLuma = pic->getPicWidthInLumaSamples() <=832 ? CNNLF_INFER_GRANULARITY_BASE : CNNLF_INFER_GRANULARITY_LARGE; + cnnlfInferGranularityChroma = CNNLF_INFER_GRANULARITY_LARGE; + } + pcSlice->setCnnlfInferGranularity(CHANNEL_TYPE_LUMA, cnnlfInferGranularityLuma); + pcSlice->setCnnlfInferGranularity(CHANNEL_TYPE_CHROMA, cnnlfInferGranularityChroma); +} + +void EncCNNFilter::cnnFilterPicture(Picture* pic, int numParams) +{ + CodingStructure& cs = *pic->cs; + Slice* pcSlice = cs.slice; + const PreCalcValues& pcv = *cs.pcv; + int extension = cs.sps->getCnnlfInferSizeExtension(); + const int numValidChannels = getNumberValidChannels( cs.area.chromaFormat ); + for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + for( int chal = 0; chal < numValidChannels; chal++ ) + { + const ChannelType chType = ChannelType( chal ); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(chType); + int blockSize = cs.sps->getCnnlfInferSize (cnnlfInferGranularity); + for( int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++ ) + { + int xPosInBlocks = blockIdx % pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int yPosInBlocks = blockIdx / pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int xPos = xPosInBlocks * blockSize; + int yPos = 
yPosInBlocks * blockSize; + int width = (xPos + blockSize > pcv.lumaWidth) ? (pcv.lumaWidth - xPos) : blockSize; + int height = (yPos + blockSize > pcv.lumaHeight) ? (pcv.lumaHeight - yPos) : blockSize; + + int extLeft = xPos > 0 ? extension : 0; + int extRight = (xPos + width + extension > pcv.lumaWidth) ? (pcv.lumaWidth - xPos - width) : extension; + int extTop = yPos > 0 ? extension : 0; + int extBottom = (yPos + height + extension > pcv.lumaHeight) ? (pcv.lumaHeight - yPos - height) : extension; + + int extXPos = xPos - extLeft; + int extYPos = yPos - extTop; + int extWidth = width + extLeft + extRight; + int extHeight = height + extTop + extBottom; + + const UnitArea extBlock( cs.area.chromaFormat, Area( extXPos, extYPos, extWidth, extHeight ) ); + + if (chal == 0) + { +#if NN_FIXED_POINT_IMPLEMENTATION + cnnFilterLumaBlock<int16_t>(pic, extBlock, extLeft, extRight, extTop, extBottom, paramIdx, pcSlice->getSliceType() != I_SLICE); +#else + cnnFilterLumaBlock<float>(pic, extBlock, extLeft, extRight, extTop, extBottom, paramIdx, pcSlice->getSliceType() != I_SLICE); +#endif + } + else + { +#if NN_FIXED_POINT_IMPLEMENTATION + cnnFilterChromaBlock<int16_t>(pic, extBlock, extLeft >> 1, extRight >> 1, extTop >> 1, extBottom >> 1, paramIdx, pcSlice->getSliceType() != I_SLICE); +#else + cnnFilterChromaBlock<float>(pic, extBlock, extLeft >> 1, extRight >> 1, extTop >> 1, extBottom >> 1, paramIdx, pcSlice->getSliceType() != I_SLICE); +#endif + } + } + } + } +} +#if SCALE_NN_RESIDUE +void EncCNNFilter::calcRDCost(Picture *pic, std::vector<PelStorage>& tempBuf, int numParams, double* minCost, bool scaled) +#else +void EncCNNFilter::calcRDCost(Picture *pic, std::vector<PelStorage>& tempBuf, int numParams, double* minCost) +#endif +{ + CodingStructure& cs = *pic->cs; + Slice* pcSlice = cs.slice; + const PreCalcValues& pcv = *cs.pcv; + const int numValidChannels = getNumberValidChannels( cs.area.chromaFormat ); + for( int chal = 0; chal < numValidChannels; chal++ ) + { + 
const ChannelType chType = ChannelType( chal ); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(chType); + int blockSize = cs.sps->getCnnlfInferSize (cnnlfInferGranularity); + std::vector<uint8_t>* cnnlfParamIdx = pic->getCnnlfParamIdx(cnnlfInferGranularity); + + int64_t distSliceLevelOff = 0, distCtuLevelOn = 0; std::vector<int64_t> distSliceLevelOn(numParams); /* std::vector instead of a variable-length array, which is only a compiler extension in C++ */ + for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + distSliceLevelOn[paramIdx] = 0; + } + + // calculate distortion + for( int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++ ) + { + int xPosInBlocks = blockIdx % pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int yPosInBlocks = blockIdx / pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int xPos = xPosInBlocks * blockSize; + int yPos = yPosInBlocks * blockSize; + int width = (xPos + blockSize > pcv.lumaWidth) ? (pcv.lumaWidth - xPos) : blockSize; + int height = (yPos + blockSize > pcv.lumaHeight) ? (pcv.lumaHeight - yPos) : blockSize; + + const UnitArea block( cs.area.chromaFormat, Area( xPos, yPos, width, height ) ); + + // block distortion w/o cnnlf + double distRec = chal == 0 ? getDistortion(pic->getOrigBuf(block).get(COMPONENT_Y), pic->getRecoBuf(block).get(COMPONENT_Y), width, height) : getDistortion(pic->getOrigBuf(block).get(COMPONENT_Cb), pic->getRecoBuf(block).get(COMPONENT_Cb), width >> 1, height >> 1) + getDistortion(pic->getOrigBuf(block).get(COMPONENT_Cr), pic->getRecoBuf(block).get(COMPONENT_Cr), width >> 1, height >> 1); + + // slice distortion w/o cnnlf + distSliceLevelOff += distRec; + + double distBest = distRec; + int paramIdxBest = -1; + for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + // block distortion w/ cnnlf paramIdx + double distTemp = chal == 0 ? 
getDistortion(pic->getOrigBuf(block).get(COMPONENT_Y), tempBuf[paramIdx].getBuf(block).get(COMPONENT_Y), width, height) : getDistortion(pic->getOrigBuf(block).get(COMPONENT_Cb), tempBuf[paramIdx].getBuf(block).get(COMPONENT_Cb), width >> 1, height >> 1) + getDistortion(pic->getOrigBuf(block).get(COMPONENT_Cr), tempBuf[paramIdx].getBuf(block).get(COMPONENT_Cr), width >> 1, height >> 1); + + // slice distortion w/ cnnlf paramIdx + distSliceLevelOn[paramIdx] += distTemp; + + // the best for the block + paramIdxBest = distTemp < distBest ? paramIdx : paramIdxBest; + distBest = distTemp < distBest ? distTemp : distBest; + } + + // the best for the slice + distCtuLevelOn += distBest; + cnnlfParamIdx[chal][blockIdx] = paramIdxBest + 1; + } + + // calculate RD cost + const TempCtx ctxStart ( m_CtxCache, CNNLFCtx( m_CABACEstimator->getCtx() ) ); + m_CABACEstimator->getCtx() = CNNLFCtx( ctxStart ); + m_CABACEstimator->resetBits(); + m_CABACEstimator->codeCnnlfParamIdx(cs, chType); + double rate = FRAC_BITS_SCALE * m_CABACEstimator->getEstFracBits(); + double costSliceLevelOff = (double)distSliceLevelOff; + double costCtuLevelOn; + std::vector<double> costSliceLevelOn(numParams); +#if SCALE_NN_RESIDUE + if (scaled) + { + costCtuLevelOn = (double) distCtuLevelOn + m_lambda[chal] * (rate + 3* (NN_RESIDUE_SCALE_SHIFT + 1)); + for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + costSliceLevelOn[paramIdx] = (double) distSliceLevelOn[paramIdx] + m_lambda[chal] * (NN_RESIDUE_SCALE_SHIFT + 1); + } + } + else +#endif + { + costCtuLevelOn = (double)distCtuLevelOn + m_lambda[chal]*rate; + for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + costSliceLevelOn[paramIdx] = (double)distSliceLevelOn[paramIdx]; + } + } + + // find the best cost + double costBest; + if (costSliceLevelOff < costCtuLevelOn) + { + costBest = costSliceLevelOff; + pcSlice->setCnnlfMode(chType, 0); + } + else + { + costBest = costCtuLevelOn; + pcSlice->setCnnlfMode(chType, numParams + 1); + } + 
+ for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + if (costSliceLevelOn[paramIdx] < costBest) + { + costBest = costSliceLevelOn[paramIdx]; + pcSlice->setCnnlfMode(chType, paramIdx + 1); + } + } + + minCost[chal] = costBest; + } +} +void EncCNNFilter::cnnFilterEncoder(Picture *pic, const double *lambdas) +{ + CodingStructure& cs = *pic->cs; + Slice* pcSlice = cs.slice; + const PreCalcValues& pcv = *cs.pcv; + memcpy(m_lambda, lambdas, sizeof(m_lambda)); + const int numValidComponents = getNumberValidComponents( cs.area.chromaFormat ); + const int numValidChannels = getNumberValidChannels( cs.area.chromaFormat ); + + setCnnlfInferGranularity(pic); + + int numParams = cs.sps->getCnnlfMaxNumParams(); + + cnnFilterPicture(pic, numParams); + + double minCost[2] = {MAX_DOUBLE, MAX_DOUBLE}; + +#if SCALE_NN_RESIDUE + calcRDCost(pic, m_tempBuf, numParams, minCost, false); + scalePicture(pic, numParams); + double minScaledCost[2] = {MAX_DOUBLE, MAX_DOUBLE}; + double sliceCnnlfMode[2]; + for( int chal = 0; chal < numValidChannels; chal++ ) + { + const ChannelType chType = ChannelType( chal ); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(chType); + std::vector<uint8_t> *cnnlfParamIdx = pic->getCnnlfParamIdx(cnnlfInferGranularity); + std::vector<uint8_t> *cnnlfBackupParamIdx = pic->getCnnlfBackupParamIdx(cnnlfInferGranularity); + for (int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++) + { + cnnlfBackupParamIdx[chal][blockIdx] = cnnlfParamIdx[chal][blockIdx]; + } + sliceCnnlfMode[chal] = pcSlice->getCnnlfMode(chType); + } + calcRDCost(pic, m_tempScaledBuf, numParams, minScaledCost, true); + for (int chal = 0; chal < numValidChannels; chal++) + { + const ChannelType chType = ChannelType( chal ); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(chType); + if (minCost[chal] <= minScaledCost[chal]) + { + std::vector<uint8_t> *cnnlfParamIdx = 
pic->getCnnlfParamIdx(cnnlfInferGranularity); + std::vector<uint8_t> *cnnlfBackupParamIdx = pic->getCnnlfBackupParamIdx(cnnlfInferGranularity); + for (int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++) + { + cnnlfParamIdx[chal][blockIdx] = cnnlfBackupParamIdx[chal][blockIdx]; + } + pcSlice->setCnnlfMode(chType, sliceCnnlfMode[chal]); + for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + pcSlice->setNnScaleFlag(false, paramIdx, chType); + } + } + else + { + for (int paramIdx = 0; paramIdx < numParams; paramIdx ++) + { + pcSlice->setNnScaleFlag(true, paramIdx, chType); + } + } + } +#else + calcRDCost(pic, m_tempBuf, numParams, minCost); +#endif + + for( int comp = 0; comp < numValidComponents; comp++ ) + { + const ComponentID compID = ComponentID( comp ); + int chal = toChannelType(compID); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(toChannelType(compID)); + int blockSize = cs.sps->getCnnlfInferSize (cnnlfInferGranularity); + int sliceCnnlfMode = pcSlice->getCnnlfMode(toChannelType(compID)); + + for( int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++ ) + { + int xPosInBlocks = blockIdx % pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int yPosInBlocks = blockIdx / pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int xPos = xPosInBlocks * blockSize; + int yPos = yPosInBlocks * blockSize; + int width = (xPos + blockSize > pcv.lumaWidth) ? (pcv.lumaWidth - xPos) : blockSize; + int height = (yPos + blockSize > pcv.lumaHeight) ? 
(pcv.lumaHeight - yPos) : blockSize; + + const UnitArea block( cs.area.chromaFormat, Area( xPos, yPos, width, height ) ); + + std::vector<uint8_t>* cnnlfParamIdx = pic->getCnnlfParamIdx(cnnlfInferGranularity); + if (sliceCnnlfMode < (numParams + 1)) + cnnlfParamIdx[chal][blockIdx] = sliceCnnlfMode; + if (cnnlfParamIdx[chal][blockIdx]) + { +#if SCALE_NN_RESIDUE + if (pcSlice->getNnScaleFlag(cnnlfParamIdx[chal][blockIdx]-1, toChannelType(compID))) + pic->getRecoBuf(block).get(compID).copyFrom(m_tempScaledBuf[cnnlfParamIdx[chal][blockIdx] - 1].getBuf(block).get(compID)); + else +#endif + pic->getRecoBuf(block).get(compID).copyFrom(m_tempBuf[cnnlfParamIdx[chal][blockIdx] - 1].getBuf(block).get(compID)); + } + } + } +} + + +#if SCALE_NN_RESIDUE +#define NN_SCALE_STABLIZING_FACTOR (0.1 * (1 << NN_RESIDUE_ADDITIONAL_SHIFT)) +void EncCNNFilter::scaleFactorDerivation(Picture *pic, int paramIdx) +{ + CodingStructure &cs = *pic->cs; + Slice * pcSlice = cs.slice; + + PelUnitBuf recoBuf = pic->getRecoBuf(); + PelUnitBuf origBuf = pic->getOrigBuf(); + PelUnitBuf filteredBuf = m_tempScaledBuf[paramIdx]; + + PelBuf origBufY = origBuf.get(COMPONENT_Y); + PelBuf origBufCb = origBuf.get(COMPONENT_Cb); + PelBuf origBufCr = origBuf.get(COMPONENT_Cr); + PelBuf reconBufY = recoBuf.get(COMPONENT_Y); + PelBuf reconBufCb = recoBuf.get(COMPONENT_Cb); + PelBuf reconBufCr = recoBuf.get(COMPONENT_Cr); + PelBuf nnFilteredBufY = filteredBuf.get(COMPONENT_Y); + PelBuf nnFilteredBufCb = filteredBuf.get(COMPONENT_Cb); + PelBuf nnFilteredBufCr = filteredBuf.get(COMPONENT_Cr); + + int pic_w = reconBufY.width; + int pic_h = reconBufY.height; + int pic_wh = pic_w * pic_h; + + int scaleY, scaleC; + int shiftY = NN_RESIDUE_SCALE_SHIFT; + int shiftC = NN_RESIDUE_SCALE_SHIFT; + + int scaleUpBoundY = int(NN_RESIDUE_SCALE_DEVIATION_UP_BOUND * (1 << shiftY)); + int scaleLowBoundY = int(NN_RESIDUE_SCALE_DEVIATION_BOT_BOUND * (1 << shiftY)); + + int scaleUpBoundC = int(NN_RESIDUE_SCALE_DEVIATION_UP_BOUND * (1 << 
shiftC)); + int scaleLowBoundC = int(NN_RESIDUE_SCALE_DEVIATION_BOT_BOUND * (1 << shiftC)); + + double selfMulti[MAX_NUM_CHANNEL_TYPE] = { 0., 0.}; + double crossMulti[MAX_NUM_CHANNEL_TYPE] = { 0., 0.}; + double sumOriResi[MAX_NUM_CHANNEL_TYPE] = { 0., 0.}; + double sumNnResi[MAX_NUM_CHANNEL_TYPE] = { 0., 0.}; + + for (int y = 0; y < pic_h; y++) + { + for (int x = 0; x < pic_w; x++) + { + int oriResi = (origBufY.at(x, y) - reconBufY.at(x, y)) << NN_RESIDUE_ADDITIONAL_SHIFT; + int nnResi = nnFilteredBufY.at(x, y) - (reconBufY.at(x, y) << NN_RESIDUE_ADDITIONAL_SHIFT); + + selfMulti[CHANNEL_TYPE_LUMA] += nnResi * nnResi; + crossMulti[CHANNEL_TYPE_LUMA] += nnResi * oriResi; + sumOriResi[CHANNEL_TYPE_LUMA] += oriResi; + sumNnResi[CHANNEL_TYPE_LUMA] += nnResi; + } + } + + scaleY = int(((pic_wh * crossMulti[CHANNEL_TYPE_LUMA] - sumOriResi[CHANNEL_TYPE_LUMA] * sumNnResi[CHANNEL_TYPE_LUMA] + + NN_SCALE_STABLIZING_FACTOR * pic_wh * pic_wh /* factor first so the product is formed in double; int pic_wh squared overflows */) + / (pic_wh * selfMulti[CHANNEL_TYPE_LUMA] - sumNnResi[CHANNEL_TYPE_LUMA] * sumNnResi[CHANNEL_TYPE_LUMA] + + NN_SCALE_STABLIZING_FACTOR * pic_wh * pic_wh /* same stabilising term as the numerator */)) + * (1 << shiftY) + + 0.5); + + if (scaleY > scaleUpBoundY) + { + scaleY = scaleUpBoundY; + } + if (scaleY < scaleLowBoundY) + { + scaleY = scaleLowBoundY; + } + + for (int y = 0; y < pic_h / 2; y++) + { + for (int x = 0; x < pic_w / 2; x++) + { + + int oriResiCb = (origBufCb.at(x, y) - reconBufCb.at(x, y)) << NN_RESIDUE_ADDITIONAL_SHIFT; + int nnResiCb = nnFilteredBufCb.at(x, y) - (reconBufCb.at(x, y) << NN_RESIDUE_ADDITIONAL_SHIFT); + + int oriResiCr = (origBufCr.at(x, y) - reconBufCr.at(x, y)) << NN_RESIDUE_ADDITIONAL_SHIFT; + int nnResiCr = nnFilteredBufCr.at(x, y) - (reconBufCr.at(x, y) << NN_RESIDUE_ADDITIONAL_SHIFT); + + selfMulti[CHANNEL_TYPE_CHROMA] += nnResiCb * nnResiCb + nnResiCr * nnResiCr; + crossMulti[CHANNEL_TYPE_CHROMA] += nnResiCb * oriResiCb + nnResiCr * oriResiCr; + sumOriResi[CHANNEL_TYPE_CHROMA] += oriResiCb + oriResiCr; + sumNnResi[CHANNEL_TYPE_CHROMA] += nnResiCb 
+ nnResiCr; + } + } + + scaleC = int(((pic_wh / 4 * 2 * crossMulti[CHANNEL_TYPE_CHROMA] - sumOriResi[CHANNEL_TYPE_CHROMA] * sumNnResi[CHANNEL_TYPE_CHROMA] + + NN_SCALE_STABLIZING_FACTOR * (pic_wh / 4 * 2) * (pic_wh / 4 * 2) /* factor first so the squared sample count is formed in double, not int */) + / (pic_wh / 4 * 2 * selfMulti[CHANNEL_TYPE_CHROMA] - sumNnResi[CHANNEL_TYPE_CHROMA] * sumNnResi[CHANNEL_TYPE_CHROMA] + + NN_SCALE_STABLIZING_FACTOR * (pic_wh / 4 * 2) * (pic_wh / 4 * 2) /* stabilising factor as in the numerator; was the enum CHANNEL_TYPE_CHROMA */)) + * (1 << shiftC) + + 0.5); + + if (scaleC > scaleUpBoundC) + { + scaleC = scaleUpBoundC; + } + if (scaleC < scaleLowBoundC) + { + scaleC = scaleLowBoundC; + } + + pcSlice->setNnScale(scaleY, paramIdx, CHANNEL_TYPE_LUMA); + + pcSlice->setNnScale(scaleC, paramIdx, CHANNEL_TYPE_CHROMA); + +} + +void EncCNNFilter::scalePicture(Picture* pic, int numParams) +{ + CodingStructure& cs = *pic->cs; + Slice* pcSlice = cs.slice; + const PreCalcValues& pcv = *cs.pcv; + const int numValidComponents = getNumberValidComponents( cs.area.chromaFormat ); + for (int paramIdx = 0; paramIdx < numParams; paramIdx++) + { + scaleFactorDerivation(pic, paramIdx); + } + + for (int paramIdx = 0; paramIdx < numParams; paramIdx++) + { + for (int comp = 0; comp < numValidComponents; comp++) + { + const ComponentID compID = ComponentID(comp); + CnnlfInferGranularity cnnlfInferGranularity = pcSlice->getCnnlfInferGranularity(toChannelType(compID)); + int blockSize = cs.sps->getCnnlfInferSize (cnnlfInferGranularity); + for (int blockIdx = 0; blockIdx < pcv.sizeInCnnlfInferSize[cnnlfInferGranularity]; blockIdx++) + { + int xPosInBlocks = blockIdx % pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int yPosInBlocks = blockIdx / pcv.widthInCnnlfInferSize[cnnlfInferGranularity]; + int xPos = xPosInBlocks * blockSize; + int yPos = yPosInBlocks * blockSize; + int width = (xPos + blockSize > pcv.lumaWidth) ? (pcv.lumaWidth - xPos) : blockSize; + int height = (yPos + blockSize > pcv.lumaHeight) ? 
(pcv.lumaHeight - yPos) : blockSize; + const UnitArea block( cs.area.chromaFormat, Area( xPos, yPos, width, height ) ); + scaleResidualBlock(pic, block, paramIdx, compID); + } + } + + } +} +#endif +#endif + +//! \} diff --git a/source/Lib/EncoderLib/EncCNNFilter.h b/source/Lib/EncoderLib/EncCNNFilter.h new file mode 100644 index 0000000000000000000000000000000000000000..9804c428ce6f7763d72f872976b1bffe81c864a7 --- /dev/null +++ b/source/Lib/EncoderLib/EncCNNFilter.h @@ -0,0 +1,47 @@ +/** \file EncCNNFilter.h + \brief encoder convolutional neural network-based filter class (header) +*/ + +#ifndef __ENCCNNFILTER__ +#define __ENCCNNFILTER__ + +#include "CommonDef.h" +#include "Unit.h" +#include "Picture.h" +#include "Reshape.h" +#include "CABACWriter.h" +#include "CommonLib/CNNFilter.h" +//! \ingroup CommonLib +//! \{ + + +class EncCNNFilter : public CNNFilter +{ +public: + EncCNNFilter(); + virtual ~EncCNNFilter(); + + //for RDO + CABACWriter* m_CABACEstimator; + CtxCache* m_CtxCache; + double m_lambda[MAX_NUM_COMPONENT]; + bool m_singleModelISlice; + + void initCABACEstimator( CABACEncoder* cabacEncoder, CtxCache* ctxCache, Slice* pcSlice ); + +#if SCALE_NN_RESIDUE + void scaleFactorDerivation(Picture *pic, int paramIdx); + void scalePicture(Picture* pic, int numParams); +#endif + void cnnFilterEncoder(Picture *pic, const double *lambdas); + void cnnFilterPicture(Picture* pic, int numParams); +#if SCALE_NN_RESIDUE + void calcRDCost(Picture *pic, std::vector<PelStorage>& tempBuf, int numParams, double* minCost, bool scaled); +#else + void calcRDCost(Picture *pic,std::vector<PelStorage>& tempBuf, int numParams, double* minCost); /* variant without the scaled flag, declared when SCALE_NN_RESIDUE is off */ +#endif +}; + +//! 
\} +#endif + diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h index fa16f2895268bf28700edff711e19d0ead12c15d..7134aeed9f17eeb427445488ec681b7ee51357b4 100644 --- a/source/Lib/EncoderLib/EncCfg.h +++ b/source/Lib/EncoderLib/EncCfg.h @@ -158,6 +158,12 @@ class EncCfg { protected: //==== File I/O ======== +#if CNN_FILTERING + std::string m_cnnlfInterLumaModelName; ///<inter luma cnnlf model + std::string m_cnnlfInterChromaModelName; ///<inter chroma cnnlf model + std::string m_cnnlfIntraLumaModelName; ///<intra luma cnnlf model + std::string m_cnnlfIntraChromaModelName; ///<intra chroma cnnlf model +#endif int m_iFrameRate; int m_FrameSkip; uint32_t m_temporalSubsampleRatio; @@ -745,6 +751,12 @@ protected: #endif bool m_ccalf; int m_ccalfQpThreshold; +#if CNN_FILTERING + bool m_cnnlf; + unsigned m_cnnlfInferSizeBase; + unsigned m_cnnlfInferSizeExtension; + unsigned m_cnnlfMaxNumParams; +#endif #if JVET_O0756_CALCULATE_HDRMETRICS double m_whitePointDeltaE[hdrtoolslib::NB_REF_WHITE]; double m_maxSampleValue; @@ -776,6 +788,16 @@ public: virtual ~EncCfg() {} +#if CNN_FILTERING + std::string getCnnlfInterLumaModelName() { return m_cnnlfInterLumaModelName; } + std::string getCnnlfInterChromaModelName() { return m_cnnlfInterChromaModelName; } + std::string getCnnlfIntraLumaModelName() { return m_cnnlfIntraLumaModelName; } + std::string getCnnlfIntraChromaModelName() { return m_cnnlfIntraChromaModelName; } + void setCnnlfInterLumaModelName(std::string s) { m_cnnlfInterLumaModelName = s; } + void setCnnlfInterChromaModelName(std::string s) { m_cnnlfInterChromaModelName = s; } + void setCnnlfIntraLumaModelName(std::string s) { m_cnnlfIntraLumaModelName = s; } + void setCnnlfIntraChromaModelName(std::string s) { m_cnnlfIntraChromaModelName = s; } +#endif void setProfile(Profile::Name profile) { m_profile = profile; } void setLevel(Level::Tier tier, Level::Name level) { m_levelTier = tier; m_level = level; } bool getFrameOnlyConstraintFlag() const { 
return m_frameOnlyConstraintFlag; } @@ -1940,6 +1962,16 @@ public: bool getUseCCALF() const { return m_ccalf; } void setCCALFQpThreshold( int b ) { m_ccalfQpThreshold = b; } int getCCALFQpThreshold() const { return m_ccalfQpThreshold; } +#if CNN_FILTERING + void setUseCnnlf( bool b ) { m_cnnlf = b; } + bool getUseCnnlf() const { return m_cnnlf; } + void setCnnlfInferSizeBase( unsigned s ) { m_cnnlfInferSizeBase = s; } + unsigned getCnnlfInferSizeBase() const { return m_cnnlfInferSizeBase; } + void setCnnlfInferSizeExtension( unsigned s ) { m_cnnlfInferSizeExtension = s; } + unsigned getCnnlfInferSizeExtension() const { return m_cnnlfInferSizeExtension; } + void setCnnlfMaxNumParams( unsigned s ) { m_cnnlfMaxNumParams = s; } + unsigned getCnnlfMaxNumParams() const { return m_cnnlfMaxNumParams; } +#endif #if JVET_O0756_CALCULATE_HDRMETRICS void setWhitePointDeltaE( uint32_t index, double value ) { m_whitePointDeltaE[ index ] = value; } double getWhitePointDeltaE( uint32_t index ) const { return m_whitePointDeltaE[ index ]; } diff --git a/source/Lib/EncoderLib/EncGOP.cpp b/source/Lib/EncoderLib/EncGOP.cpp index 5ccc0d3602ef602716fac32087dc4df5a827d223..75654047f42bed61c06c12a07daac820e9feacb0 100644 --- a/source/Lib/EncoderLib/EncGOP.cpp +++ b/source/Lib/EncoderLib/EncGOP.cpp @@ -178,6 +178,9 @@ void EncGOP::create() { m_bLongtermTestPictureHasBeenCoded = 0; m_bLongtermTestPictureHasBeenCoded2 = 0; +#if CNN_FILTERING + m_pcCNNFilter = new EncCNNFilter; +#endif } void EncGOP::destroy() @@ -202,6 +205,14 @@ void EncGOP::destroy() delete m_picOrig; m_picOrig = NULL; } +#if CNN_FILTERING + if (m_pcCNNFilter) + { + m_pcCNNFilter->destroy(); + delete m_pcCNNFilter; + m_pcCNNFilter = NULL; + } +#endif } void EncGOP::init ( EncLib* pcEncLib ) @@ -222,6 +233,10 @@ void EncGOP::init ( EncLib* pcEncLib ) m_AUWriterIf = pcEncLib->getAUWriterIf(); +#if CNN_FILTERING + m_pcCNNFilter->create(m_pcCfg->getSourceWidth(), m_pcCfg->getSourceHeight(), m_pcCfg->getChromaFormatIdc(), 
m_pcCfg->getCnnlfMaxNumParams()); +#endif + #if WCG_EXT if (m_pcCfg->getLmcs()) { @@ -2744,6 +2759,10 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, pcPic->resizeAlfCtbFilterIndex(numberOfCtusInFrame); } +#if CNN_FILTERING + pcPic->resizeCnnlfParamIdx( pcPic->cs->pcv->sizeInCnnlfInferSize ); +#endif + bool decPic = false; bool encPic = false; // test if we can skip the picture entirely or decode instead of encoding @@ -3021,7 +3040,7 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, #endif CS::setRefinedMotionField(cs); - + if( pcSlice->getSPS()->getSAOEnabledFlag() ) { bool sliceEnabled[MAX_NUM_COMPONENT]; @@ -3051,6 +3070,22 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, } } + +#if COMBINE_NN_WITH_LF && !FUSE_NN_AND_LF + if ( cs.sps->getCnnlfEnabledFlag() ) + { + pcPic->getRecoBuf().copyFrom(pcPic->getUnfilteredRecBuf()); + } +#endif + +#if CNN_FILTERING + if ( cs.sps->getCnnlfEnabledFlag() ) + { + m_pcCNNFilter->init(m_pcEncLib->getCnnlfInterLumaModelName(), m_pcEncLib->getCnnlfInterChromaModelName(), m_pcEncLib->getCnnlfIntraLumaModelName(), m_pcEncLib->getCnnlfIntraChromaModelName()); + m_pcCNNFilter->initCABACEstimator( m_pcEncLib->getCABACEncoder(), m_pcEncLib->getCtxCache(), pcSlice ); + m_pcCNNFilter->cnnFilterEncoder(pcPic, pcSlice->getLambdas()); + } +#endif if( pcSlice->getSPS()->getALFEnabledFlag() ) { diff --git a/source/Lib/EncoderLib/EncGOP.h b/source/Lib/EncoderLib/EncGOP.h index 5a89d4898a9a3e225e068491317fefa9c12c06de..43c87bf25e9285ce74377f6bd787f6313311270f 100644 --- a/source/Lib/EncoderLib/EncGOP.h +++ b/source/Lib/EncoderLib/EncGOP.h @@ -71,6 +71,10 @@ #include <chrono> #endif +#if CNN_FILTERING +#include "EncoderLib/EncCNNFilter.h" +#endif + //! \ingroup EncoderLib //! 
\{ @@ -136,6 +140,10 @@ private: EncSlice* m_pcSliceEncoder; PicList* m_pcListPic; +#if CNN_FILTERING + EncCNNFilter* m_pcCNNFilter; +#endif + HLSWriter* m_HLSWriter; LoopFilter* m_pcLoopFilter; diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp index bb5e51f65e59e0bfc00ba9105d6f17209808746f..c6bf1b5afb7da85bb193ce92c7743b81a7329755 100644 --- a/source/Lib/EncoderLib/EncLib.cpp +++ b/source/Lib/EncoderLib/EncLib.cpp @@ -1293,7 +1293,7 @@ void EncLib::xInitSPS( SPS& sps ) sps.setMaxCUHeight ( m_maxCUHeight ); sps.setLog2MinCodingBlockSize ( m_log2MinCUSize ); sps.setChromaFormatIdc ( m_chromaFormatIDC ); - + sps.setCTUSize ( m_CTUSize ); sps.setSplitConsOverrideEnabledFlag ( m_useSplitConsOverride ); sps.setMinQTSizes ( m_uiMinQT ); @@ -1412,6 +1412,24 @@ void EncLib::xInitSPS( SPS& sps ) sps.setCCALFEnabledFlag( m_ccalf ); sps.setFieldSeqFlag(false); sps.setVuiParametersPresentFlag(getVuiParametersPresentFlag()); + +#if CNN_FILTERING + sps.setCnnlfEnabledFlag(m_cnnlf); + if (sps.getCnnlfEnabledFlag()) + { + unsigned cnnlfInferSize[] = {m_cnnlfInferSizeBase >> 1, m_cnnlfInferSizeBase, m_cnnlfInferSizeBase << 1}; + sps.setCnnlfInferSize(cnnlfInferSize); + sps.setCnnlfInferSizeExtension(m_cnnlfInferSizeExtension); + if (m_intraPeriod == 1) + { + sps.setCnnlfMaxNumParams(1); // to be removed + } + else + { + sps.setCnnlfMaxNumParams(m_cnnlfMaxNumParams); + } + } +#endif if (sps.getVuiParametersPresentFlag()) { diff --git a/source/Lib/EncoderLib/EncSampleAdaptiveOffset.cpp b/source/Lib/EncoderLib/EncSampleAdaptiveOffset.cpp index 90583a5a4b8a0446a23d3896789dd455e8837cdd..e14f09ffb74ab0f247846324b5bca5e057b5817f 100644 --- a/source/Lib/EncoderLib/EncSampleAdaptiveOffset.cpp +++ b/source/Lib/EncoderLib/EncSampleAdaptiveOffset.cpp @@ -213,6 +213,14 @@ void EncSampleAdaptiveOffset::SAOProcess( CodingStructure& cs, bool* sliceEnable #endif const bool bTestSAODisableAtPictureLevel, const double saoEncodingRate, const double saoEncodingRateChroma, 
const bool isPreDBFSamplesUsed, bool isGreedyMergeEncoding ) { +#if CNN_FILTERING + if (cs.sps->getCnnlfEnabledFlag()) + { + sliceEnabled[0] = sliceEnabled[1] = sliceEnabled[2] = false; + return; + } +#endif + PelUnitBuf org = cs.getOrgBuf(); PelUnitBuf res = cs.getRecoBuf(); PelUnitBuf src = m_tempBuf; diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp index e57585fa193d71ae5359fbdf1801c2c28b62a11f..86a2bbbbbe8f210be4847c25b36058ee61bd95e9 100644 --- a/source/Lib/EncoderLib/VLCWriter.cpp +++ b/source/Lib/EncoderLib/VLCWriter.cpp @@ -1006,6 +1006,15 @@ void HLSWriter::codeSPS( const SPS* pcSPS ) { WRITE_FLAG( pcSPS->getCCALFEnabledFlag(), "sps_ccalf_enabled_flag" ); } +#if CNN_FILTERING + WRITE_FLAG( pcSPS->getCnnlfEnabledFlag(), "sps_cnnlf_enabled_flag" ); + if (pcSPS->getCnnlfEnabledFlag()) + { + WRITE_UVLC( pcSPS->getCnnlfInferSize(CNNLF_INFER_GRANULARITY_BASE), "sps_cnnlf_infer_size_base" ); + WRITE_UVLC( pcSPS->getCnnlfInferSizeExtension(), "sps_cnnlf_infer_size_extension" ); + WRITE_UVLC( pcSPS->getCnnlfMaxNumParams(), "sps_cnnlf_max_num_params" ); + } +#endif WRITE_FLAG(pcSPS->getUseLmcs() ? 1 : 0, "sps_lmcs_enable_flag"); WRITE_FLAG(pcSPS->getUseWP() ? 1 : 0, "sps_weighted_pred_flag"); // Use of Weighting Prediction (P_SLICE) WRITE_FLAG(pcSPS->getUseWPBiPred() ? 
1 : 0, "sps_weighted_bipred_flag"); // Use of Weighting Bi-Prediction (B_SLICE) @@ -2473,6 +2482,44 @@ void HLSWriter::codeSliceHeader ( Slice* pcSlice, PicHeader *picHeader ) WRITE_FLAG(pcSlice->getUseChromaQpAdj(), "sh_cu_chroma_qp_offset_enabled_flag"); } +#if CNN_FILTERING + if (pcSlice->getSPS()->getCnnlfEnabledFlag()) + { + WRITE_UVLC( pcSlice->getCnnlfMode( CHANNEL_TYPE_LUMA ), "slice_luma_cnnlf_mode" ); + WRITE_UVLC( pcSlice->getCnnlfMode( CHANNEL_TYPE_CHROMA ), "slice_chroma_cnnlf_mode" ); + +#if SCALE_NN_RESIDUE + for (int chal = 0; chal < MAX_NUM_CHANNEL_TYPE; chal++) + { + ChannelType chType = ChannelType(chal); + if (pcSlice->getCnnlfMode(chType)) + { + int numParams = pcSlice->getSPS()->getCnnlfMaxNumParams(); + if (pcSlice->getCnnlfMode(chType) == numParams + 1) + { + for (int paramIdx = 0; paramIdx < numParams; paramIdx++) + { + WRITE_FLAG(pcSlice->getNnScaleFlag( paramIdx, chType) , "slice_cnnlf_scale_flag"); + if (pcSlice->getNnScaleFlag(paramIdx, chType)) + { + WRITE_SCODE(pcSlice->getNnScale(paramIdx, chType) - (1 << NN_RESIDUE_SCALE_SHIFT), NN_RESIDUE_SCALE_SHIFT + 1, "nnScale"); + } + } + } + else + { + WRITE_FLAG(pcSlice->getNnScaleFlag( pcSlice->getCnnlfMode(chType) - 1, chType) , "slice_cnnlf_scale_flag"); + if (pcSlice->getNnScaleFlag(pcSlice->getCnnlfMode(chType) - 1, chType)) + { + WRITE_SCODE(pcSlice->getNnScale( pcSlice->getCnnlfMode(chType) - 1, chType) - (1 << NN_RESIDUE_SCALE_SHIFT), NN_RESIDUE_SCALE_SHIFT + 1, "nnScale"); + } + } + } + } +#endif + } +#endif + if (pcSlice->getSPS()->getSAOEnabledFlag() && !pcSlice->getPPS()->getSaoInfoInPhFlag()) { WRITE_FLAG( pcSlice->getSaoEnabledFlag( CHANNEL_TYPE_LUMA ), "sh_sao_luma_used_flag" ); @@ -3077,4 +3124,4 @@ void HLSWriter::alfFilter( const AlfParam& alfParam, const bool isChroma, const } -//! \} \ No newline at end of file +//! \}