From cd9026865ccae728b82425ec4c9c625f5b3e047c Mon Sep 17 00:00:00 2001 From: Liqiang Wang <liqiangwang@tencent.com> Date: Tue, 1 Nov 2022 15:03:18 +0800 Subject: [PATCH] JVET-AB0083: EE1-1.8: More refinements on NN based in-loop filter with a single model (Test 1) --- source/Lib/CommonLib/CommonDef.h | 4 + source/Lib/CommonLib/NNFilterSet0.cpp | 92 +++++++++++++++++++++- source/Lib/CommonLib/NNFilterSet0.h | 15 ++++ source/Lib/CommonLib/Slice.h | 10 +++ source/Lib/CommonLib/TypeDef.h | 1 + source/Lib/DecoderLib/VLCReader.cpp | 22 ++++++ source/Lib/DecoderLib/VLCReader.h | 3 + source/Lib/EncoderLib/EncNNFilterSet0.cpp | 94 +++++++++++++++++++++++ source/Lib/EncoderLib/VLCWriter.cpp | 13 ++++ source/Lib/EncoderLib/VLCWriter.h | 3 + 10 files changed, 256 insertions(+), 1 deletion(-) diff --git a/source/Lib/CommonLib/CommonDef.h b/source/Lib/CommonLib/CommonDef.h index 9370aa6649..b2d3086e60 100644 --- a/source/Lib/CommonLib/CommonDef.h +++ b/source/Lib/CommonLib/CommonDef.h @@ -190,6 +190,10 @@ static const float NN_RESIDUE_SCALE_DEVIATION_BOT_BOUND= 0.0625f; #if NN_FILTERING_SET_0 static const int MAX_NUM_CNN = 4; static constexpr float NN_SCALE_SAFETY_FACTOR= (0.1f * (1<< NN_RESIDUE_ADDITIONAL_SHIFT)); +#if JVET_AB0083_QPADJ +static const int QP_OFFSET_NUM = 2; +static const int NNQPOFFSET[QP_OFFSET_NUM] = { -5, 5 }; +#endif #endif #if NN_FILTERING_SET_1 && NN_FIXED_POINT_IMPLEMENTATION static const int NN_INPUT_PRECISION= 13; diff --git a/source/Lib/CommonLib/NNFilterSet0.cpp b/source/Lib/CommonLib/NNFilterSet0.cpp index 28c6f06b8a..1807fa37cd 100644 --- a/source/Lib/CommonLib/NNFilterSet0.cpp +++ b/source/Lib/CommonLib/NNFilterSet0.cpp @@ -70,7 +70,18 @@ void NNFilterSet0::PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSlice PelUnitBuf cnnYuv = m_tempCnnBuf[0].getBuf(cs.area); // run DRNLF +#if JVET_AB0083_QPADJ + if (cs.slice->getNNQPOffsetEnable()) + { + const int offsetIdx = cs.slice->getNNBestQPOffset(); + const int QPoffset = NNQPOFFSET[offsetIdx]; + 
runCNNLF(pic, cnnYuv, cnnYuv, QPoffset, true); + } + else + runCNNLF(pic, cnnYuv, cnnYuv, 0, true); +#else runCNNLF(pic, cnnYuv, 0, true); +#endif for (int i = 1; i < MAX_NUM_CNN; i++) { @@ -154,6 +165,21 @@ void NNFilterSet0::create( const int picWidth, const int picHeight, const Chroma m_tempCnnBuf[i].destroy(); m_tempCnnBuf[i].create(format, Area(0, 0, picWidth, picHeight), maxCUWidth, 0, false); } +#if JVET_AB0083_QPADJ + m_bestCnnAdjBuf.destroy(); + m_bestCnnAdjBuf.create(format, Area(0, 0, picWidth, picHeight), maxCUWidth, 0, false); + m_tempCnnAdjBuf.destroy(); + m_tempCnnAdjBuf.create(format, Area(0, 0, picWidth, picHeight), maxCUWidth, 0, false); + m_origCnnAdjBuf.destroy(); + m_origCnnAdjBuf.create(format, Area(0, 0, picWidth, picHeight), maxCUWidth, 0, false); + m_beforeCnnBuftemp.destroy(); + m_beforeCnnBuftemp.create(format, Area(0, 0, picWidth, picHeight), maxCUWidth, 0, false); + + m_wtScaleCnnBuf.destroy(); + m_wtScaleCnnBuf.create(format, Area(0, 0, picWidth, picHeight), maxCUWidth, 0, false); + m_wtScaleAdjBuf.destroy(); + m_wtScaleAdjBuf.create(format, Area(0, 0, picWidth, picHeight), maxCUWidth, 0, false); +#endif } void NNFilterSet0::destroy() @@ -162,6 +188,15 @@ void NNFilterSet0::destroy() { m_tempCnnBuf[i].destroy(); } +#if JVET_AB0083_QPADJ + m_bestCnnAdjBuf.destroy(); + m_tempCnnAdjBuf.destroy(); + m_origCnnAdjBuf.destroy(); + m_beforeCnnBuftemp.destroy(); + + m_wtScaleCnnBuf.destroy(); + m_wtScaleAdjBuf.destroy(); +#endif } @@ -222,7 +257,11 @@ void NNFilterSet0::filterBlk( PelUnitBuf &recUnitBuf, const CPelUnitBuf& cnnUnit recBlk.copyFrom(cnnBlk); } +#if JVET_AB0083_QPADJ +bool NNFilterSet0::skipPatch(int x, int y, int ctuRsAddr, bool is_dec, const bool qp_opt, const bool isSec) +#else bool NNFilterSet0::skipPatch(int x, int y, int ctuRsAddr, bool is_dec) +#endif { if (is_dec) { @@ -231,10 +270,33 @@ bool NNFilterSet0::skipPatch(int x, int y, int ctuRsAddr, bool is_dec) return true; } } +#if JVET_AB0083_QPADJ + else if (qp_opt) + { + if 
(isSec) + { + if (!(x % 2 != 0 || y % 2 != 0)) + { + return true; + } + } + else + { + if (x % 2 != 0 || y % 2 != 0) + { + return true; + } + } + } +#endif return false; } -void NNFilterSet0::runCNNLF(Picture *pic, PelUnitBuf &cnnUnitBuf, const int baseQPoffset, bool is_dec) +#if JVET_AB0083_QPADJ +void NNFilterSet0::runCNNLF(Picture* pic, PelUnitBuf& cnnUnitBuf, PelUnitBuf& cnnUnitBufOrg, const int baseQPoffset, bool is_dec, const bool qp_opt, const bool isSec) +#else +void NNFilterSet0::runCNNLF(Picture* pic, PelUnitBuf& cnnUnitBuf, const int baseQPoffset, bool is_dec) +#endif { CodingStructure &cs = *pic->cs; const int slice_type = cs.slice->getSliceType() != I_SLICE ? 1023 : 0; @@ -260,7 +322,11 @@ void NNFilterSet0::runCNNLF(Picture *pic, PelUnitBuf &cnnUnitBuf, const int base { for (int x = 0; x < picWidthInPatchs; x++) { +#if JVET_AB0083_QPADJ + if (skipPatch(x, y, ctuRsAddr, is_dec, qp_opt, isSec)) +#else if (skipPatch(x, y, ctuRsAddr, is_dec)) +#endif { ctuRsAddr++; continue; @@ -300,14 +366,22 @@ void NNFilterSet0::runCNNLF(Picture *pic, PelUnitBuf &cnnUnitBuf, const int base NNInference::infer<TypeSadl>(m_Module, m_Input); // extract the results +#if JVET_AB0083_QPADJ + extractOutputs(pic, pix_x, pix_y, pix_x_end, pix_y_end, st_w, st_h, cnnUnitBuf, cnnUnitBufOrg, qp_opt); +#else extractOutputs(pic, pix_x, pix_y, pix_x_end, pix_y_end, st_w, st_h, cnnUnitBuf); +#endif ctuRsAddr++; } } } +#if JVET_AB0083_QPADJ +void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_end, int pix_y_end, int st_w, int st_h, PelUnitBuf &cnnUnitBuf, PelUnitBuf &cnnUnitBufOrg, const bool qp_opt) +#else void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_end, int pix_y_end, int st_w, int st_h, PelUnitBuf &cnnUnitBuf) +#endif { CodingStructure &cs = *pic->cs; PelUnitBuf recUnitBuf = cs.getRecoBuf(); @@ -317,6 +391,10 @@ void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_ int nn_scale_shift = 0; #endif 
+#if JVET_AB0083_QPADJ + int nn_scale_offset = (nn_scale_shift > 0) ? (1 << (nn_scale_shift - 1)) : 0; /* rounding offset for the right shift by nn_scale_shift; 0 when nn_scale_shift is 0, because a negative shift count is undefined behaviour */ +#endif + #if !NN_FIXED_POINT_IMPLEMENTATION double in_maxValue = 1023; #endif @@ -348,6 +426,12 @@ void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_ cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w) >> 1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) << in_left_shift)) << out_left_shift)); #else cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w)>>1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) / in_maxValue)) * out_maxValue + 0.5))); +#endif +#if JVET_AB0083_QPADJ + if (qp_opt) + { + cnnUnitBufOrg.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, real_maxValue, (cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) + nn_scale_offset) >> nn_scale_shift)); + } #endif } } @@ -379,6 +463,12 @@ void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_ cnnUnitBuf.get(comp).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(sample + (recUnitBuf.get(comp).at(id_x, id_y) << in_left_shift)) << out_left_shift)); #else cnnUnitBuf.get(comp).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((sample + (recUnitBuf.get(comp).at(id_x, id_y) / in_maxValue)) * out_maxValue + 0.5))); +#endif +#if JVET_AB0083_QPADJ + if (qp_opt) + { + cnnUnitBufOrg.get(comp).at(id_x, id_y) = Pel(Clip3<int>(0, real_maxValue, (cnnUnitBuf.get(comp).at(id_x, id_y) + nn_scale_offset) >> nn_scale_shift)); + } #endif } } diff --git a/source/Lib/CommonLib/NNFilterSet0.h b/source/Lib/CommonLib/NNFilterSet0.h index b39d1450fc..b8837ce27e 100644 --- a/source/Lib/CommonLib/NNFilterSet0.h +++ b/source/Lib/CommonLib/NNFilterSet0.h @@ -64,6 +64,15 @@ public: protected: uint8_t* m_ctuEnableFlag[MAX_NUM_COMPONENT]; PelStorage m_tempCnnBuf[MAX_NUM_CNN]; +#if JVET_AB0083_QPADJ + PelStorage m_tempCnnAdjBuf; + PelStorage
m_bestCnnAdjBuf; + PelStorage m_origCnnAdjBuf; + PelStorage m_beforeCnnBuftemp; + + PelStorage m_wtScaleCnnBuf; + PelStorage m_wtScaleAdjBuf; +#endif int m_inputBitDepth[MAX_NUM_CHANNEL_TYPE]; int m_picWidth; @@ -88,9 +97,15 @@ protected: void initCnnModel(); void initPatch(const int PatchWidth, const int PatchHeight); +#if JVET_AB0083_QPADJ + bool skipPatch(int x, int y, int ctuRsAddr, bool is_dec, const bool qp_opt, const bool isSec); + void runCNNLF(Picture* pic, PelUnitBuf& cnnUnitBuf, PelUnitBuf& cnnUnitBufOrg, const int baseQPoffset, bool is_dec, const bool qp_opt = false, const bool isSec = false); + void extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_end, int pix_y_end, int st_w, int st_h, PelUnitBuf &cnnUnitBuf, PelUnitBuf &cnnUnitBufOrg, const bool qp_opt); +#else bool skipPatch(int x, int y, int ctuRsAddr, bool is_dec); void runCNNLF(Picture* pic, PelUnitBuf& cnnUnitBuf, const int baseQPoffset, bool is_dec); void extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_end, int pix_y_end, int st_w, int st_h, PelUnitBuf &cnnUnitBuf); +#endif #if NN_SCALE void scaleResidue(CodingStructure& cs, PelUnitBuf recUnitBuf, PelUnitBuf cnnYuv, int *scale_list, bool is_dec); diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h index 51cc3d934a..231684dbf7 100644 --- a/source/Lib/CommonLib/Slice.h +++ b/source/Lib/CommonLib/Slice.h @@ -2763,6 +2763,10 @@ private: #if NN_SCALE int nn_scale[MAX_NUM_COMPONENT]; #endif +#if JVET_AB0083_QPADJ + bool useOffset = false; /* true when the NN filter is run with one of the NNQPOFFSET QP offsets for this slice; defaulted so a slice that never calls the setter reads as disabled */ + int nnQPOffsetIdx = 0; /* index into NNQPOFFSET; only set by the decoder when useOffset is true */ +#endif #endif public: @@ -3051,6 +3055,12 @@ public: void setNnScale(int sc, ComponentID id) { nn_scale[id] = sc; } int getNnScale(ComponentID id) const { return nn_scale[id]; } #endif +#if JVET_AB0083_QPADJ + void setNNQPOffsetEnable(bool use) { useOffset = use; } + bool getNNQPOffsetEnable() const { return useOffset; } + void setNNBestQPOffset(int idx) { nnQPOffsetIdx = idx; } + int getNNBestQPOffset() const { return nnQPOffsetIdx; } +#endif #endif void
resetTileGroupAlfEnabledFlag() { memset(m_tileGroupAlfEnabledFlag, 0, sizeof(m_tileGroupAlfEnabledFlag)); } diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 7b024f1f79..566508253d 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -81,6 +81,7 @@ using TypeSadl = float; // options set 0 #if NN_FILTERING_SET_0 #define NN_SCALE 1 +#define JVET_AB0083_QPADJ 1 // JVET-AB0083: EE1-1.8: More refinements on NN based in-loop filter with a single model (Test 1) #endif diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp index b2634bb024..fbeccce480 100644 --- a/source/Lib/DecoderLib/VLCReader.cpp +++ b/source/Lib/DecoderLib/VLCReader.cpp @@ -4225,6 +4225,14 @@ void HLSyntaxReader::parseSliceHeader (Slice* pcSlice, PicHeader* picHeader, Par READ_SCODE(NN_RESIDUE_SCALE_SHIFT + 1, iCode, "nn scale Cr"); pcSlice->setNnScale(iCode + (1 << NN_RESIDUE_SCALE_SHIFT), COMPONENT_Cr); } +#endif +#if JVET_AB0083_QPADJ + cnnlfFrameQP(*pcSlice); + if (pcSlice->getNNQPOffsetEnable()) + { + READ_FLAG(uiCode, "nn QP offset index"); + pcSlice->setNNBestQPOffset(uiCode); + } #endif } #endif @@ -5281,6 +5289,20 @@ void HLSyntaxReader::cnnlf(CnnlfSliceParam& cnnlfSliceParam) } } } + +#if JVET_AB0083_QPADJ +void HLSyntaxReader::cnnlfFrameQP(Slice& slice) +{ + uint32_t code; + if (slice.getSliceType() != I_SLICE) + { + READ_FLAG(code, "nn QP offset enable flag"); + slice.setNNQPOffsetEnable(code ? true : false); + } + else + slice.setNNQPOffsetEnable(false); +} +#endif #endif //! 
\} diff --git a/source/Lib/DecoderLib/VLCReader.h b/source/Lib/DecoderLib/VLCReader.h index 9e107e3380..9829be57ff 100644 --- a/source/Lib/DecoderLib/VLCReader.h +++ b/source/Lib/DecoderLib/VLCReader.h @@ -190,6 +190,9 @@ public: private: #if NN_FILTERING_SET_0 void cnnlf(CnnlfSliceParam& cnnlfSliceParam); +#if JVET_AB0083_QPADJ + void cnnlfFrameQP(Slice& slice); +#endif #endif protected: diff --git a/source/Lib/EncoderLib/EncNNFilterSet0.cpp b/source/Lib/EncoderLib/EncNNFilterSet0.cpp index 61b1b77019..f503337819 100644 --- a/source/Lib/EncoderLib/EncNNFilterSet0.cpp +++ b/source/Lib/EncoderLib/EncNNFilterSet0.cpp @@ -89,12 +89,106 @@ void EncNNFilterSet0::PreCNNLFProcess(Picture* pic, CodingStructure& cs) CHECK(orgUnitBuf.bufs.size() != cnnUnitBuf.bufs.size(), "Error buf size."); // run CNNLF +#if JVET_AB0083_QPADJ + m_wtScaleCnnBuf.copyFrom(recUnitBuf); + m_wtScaleAdjBuf.copyFrom(recUnitBuf); + PelUnitBuf cnnUnitwtScale = m_wtScaleCnnBuf.getBuf(cs.area); + PelUnitBuf cnnUnitwtScaleAdj = m_wtScaleAdjBuf.getBuf(cs.area); + + if (cs.slice->getSliceType() != I_SLICE) + runCNNLF(pic, cnnUnitwtScale, cnnUnitBuf, 0, false, true); + else + runCNNLF(pic, cnnUnitBuf, cnnUnitBuf, 0, false); +#else runCNNLF(pic, cnnUnitBuf, 0, false); +#endif + + +#if JVET_AB0083_QPADJ + if (cs.slice->getSliceType() != I_SLICE) + { + m_bestCnnAdjBuf.copyFrom(cnnUnitBuf); + m_tempCnnAdjBuf.copyFrom(cnnUnitBuf); + m_origCnnAdjBuf.copyFrom(cnnUnitBuf); + m_beforeCnnBuftemp.copyFrom(cnnUnitBuf); + PelUnitBuf cnnUnitAdjBufOrig = m_origCnnAdjBuf.getBuf(cs.area); + PelUnitBuf cnnUnitAdjBufBest = m_bestCnnAdjBuf.getBuf(cs.area); + PelUnitBuf cnnUnitAdjBuf = m_tempCnnAdjBuf.getBuf(cs.area); + + const TempCtx ctxStart(m_CtxCache, CnnlfCtx(m_CABACEstimator->getCtx())); + TempCtx ctxBest(m_CtxCache); + double NNcost = 0., NNcostbest = MAX_DOUBLE; + for (int channel = 0; channel < 2; channel++) + { + NNcost += getCnnCost(orgUnitBuf, cnnUnitAdjBufOrig, ChannelType(channel)); + } + if (NNcost < 
NNcostbest) + { + NNcostbest = NNcost; + } + + double AdjBestcost = MAX_DOUBLE; + int bestOffsetIdx = -1; + + for (int offsetIdx = 0; offsetIdx < QP_OFFSET_NUM; offsetIdx++) + { + double Adjcost = 0.; + const int QPoffset = NNQPOFFSET[offsetIdx]; + + runCNNLF(pic, cnnUnitwtScaleAdj, cnnUnitAdjBuf, QPoffset, false, true); + + for (int channel = 0; channel < 2; channel++) + { + Adjcost += getCnnCost(orgUnitBuf, cnnUnitAdjBuf, ChannelType(channel)); + } + if (Adjcost < AdjBestcost) + { + AdjBestcost = Adjcost; + bestOffsetIdx = offsetIdx; + m_bestCnnAdjBuf.copyFrom(cnnUnitwtScaleAdj); + } + if (AdjBestcost < NNcostbest) + { + break; + } + } + if (AdjBestcost < NNcostbest) + { + cs.slice->setNNQPOffsetEnable(true); + cs.slice->setNNBestQPOffset(bestOffsetIdx); + + const int QPoffset = NNQPOFFSET[bestOffsetIdx]; + runCNNLF(pic, cnnUnitAdjBufBest, cnnUnitwtScaleAdj, QPoffset, false, true, true); + for (int i = 0; i < MAX_NUM_CNN; i++) + { + m_tempCnnBuf[i].copyFrom(cnnUnitAdjBufBest); + } + } + else + { + cs.slice->setNNQPOffsetEnable(false); + runCNNLF(pic, cnnUnitwtScale, cnnUnitBuf, 0, false, true, true); + for (int i = 0; i < MAX_NUM_CNN; i++) + { + m_tempCnnBuf[i].copyFrom(cnnUnitwtScale); + } + } + } + else + { + for (int i = 1; i < MAX_NUM_CNN; i++) + { + m_tempCnnBuf[i].copyFrom(cnnUnitBuf); + } + } +#else for (int i = 1; i < MAX_NUM_CNN; i++) { m_tempCnnBuf[i].copyFrom(cnnUnitBuf); } +#endif + } void EncNNFilterSet0::CNNLFProcess(CodingStructure& cs, const double *lambdas, CnnlfSliceParam& cnnlfSliceParam) diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp index 35aa22113b..10dba722dd 100644 --- a/source/Lib/EncoderLib/VLCWriter.cpp +++ b/source/Lib/EncoderLib/VLCWriter.cpp @@ -2526,6 +2526,11 @@ void HLSWriter::codeSliceHeader ( Slice* pcSlice, PicHeader *picHeader ) { WRITE_SCODE(pcSlice->getNnScale(COMPONENT_Cr) - (1 << NN_RESIDUE_SCALE_SHIFT), NN_RESIDUE_SCALE_SHIFT + 1, "nn cale Cr"); } +#endif +#if JVET_AB0083_QPADJ + 
cnnlfFrameQP(*pcSlice); + if (pcSlice->getNNQPOffsetEnable() && pcSlice->getSliceType() != I_SLICE) + WRITE_FLAG(pcSlice->getNNBestQPOffset(), "nn QP offset index"); #endif } #endif @@ -3182,6 +3187,14 @@ void HLSWriter::cnnlf(const CnnlfSliceParam& cnnlfSliceParam) WRITE_UVLC(map_list[code], "nn-filter mode"); } } + +#if JVET_AB0083_QPADJ +void HLSWriter::cnnlfFrameQP(Slice& slice) +{ + if (slice.getSliceType() != I_SLICE) + WRITE_FLAG(slice.getNNQPOffsetEnable() ? 1 : 0, "nn QP adjustment"); +} +#endif #endif //! \} diff --git a/source/Lib/EncoderLib/VLCWriter.h b/source/Lib/EncoderLib/VLCWriter.h index fb5ab35870..f55d2e0ef6 100644 --- a/source/Lib/EncoderLib/VLCWriter.h +++ b/source/Lib/EncoderLib/VLCWriter.h @@ -156,6 +156,9 @@ public: private: #if NN_FILTERING_SET_0 void cnnlf(const CnnlfSliceParam& cnnlfSliceParam); +#if JVET_AB0083_QPADJ + void cnnlfFrameQP(Slice& slice); +#endif #endif }; -- GitLab