diff --git a/source/App/DecoderApp/DecApp.cpp b/source/App/DecoderApp/DecApp.cpp index d3296a34f7bb4a193d23d1890ce6060e4e0e1cae..d1031389e440bbda03d833205b55a952829c8ddb 100644 --- a/source/App/DecoderApp/DecApp.cpp +++ b/source/App/DecoderApp/DecApp.cpp @@ -620,6 +620,9 @@ void DecApp::xCreateDecLib() m_cDecLib.setNnlfSet1InterChromaModelName (m_nnlfSet1InterChromaModelName); m_cDecLib.setNnlfSet1IntraLumaModelName (m_nnlfSet1IntraLumaModelName); m_cDecLib.setNnlfSet1IntraChromaModelName (m_nnlfSet1IntraChromaModelName); +#if JVET_AC0177_MULTI_FRAME + m_cDecLib.setNnlfSet1AlternativeInterLumaModelName (m_nnlfSet1AlternativeInterLumaModelName); +#endif #endif if (!m_outputDecodedSEIMessagesFilename.empty()) diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp index 027d4a3e58beb7e32d3a168524d9316d72fa051a..8202fbc76c50e25f300456b2d1f01ac9ae30f6ce 100644 --- a/source/App/DecoderApp/DecAppCfg.cpp +++ b/source/App/DecoderApp/DecAppCfg.cpp @@ -90,6 +90,9 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) ( "NnlfSet1InterChromaModel", m_nnlfSet1InterChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter chroma model name") ( "NnlfSet1IntraLumaModel", m_nnlfSet1IntraLumaModelName, string("models/NnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra luma model name") ( "NnlfSet1IntraChromaModel", m_nnlfSet1IntraChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra chroma model name") +#if JVET_AC0177_MULTI_FRAME + ( "NnlfSet1AlternativeInterLumaModel", m_nnlfSet1AlternativeInterLumaModelName, string("models/NnlfSet1_LumaCNNFilter_InterSlice_MultiframePrior_Tid345_int16.sadl"), "NnlfSet1 alternative inter luma model name") +#endif #endif ("OplFile,-opl", m_oplFilename , string(""), "opl-file name without extension for conformance testing\n") diff --git a/source/App/DecoderApp/DecAppCfg.h b/source/App/DecoderApp/DecAppCfg.h index 7dbb765f403657bd0be5138cd1f3c5c8bfdc547a..e8a5dad49462dd56f0735de8e11359a8478927d6 100644 --- a/source/App/DecoderApp/DecAppCfg.h +++ b/source/App/DecoderApp/DecAppCfg.h @@ -72,6 +72,9 @@ protected: std::string m_nnlfSet1InterChromaModelName; ///<inter chroma nnlf set1 model std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlf set1 model std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlf set1 model +#if JVET_AC0177_MULTI_FRAME + std::string m_nnlfSet1AlternativeInterLumaModelName; ///<alternative inter luma nnlf set1 model +#endif #endif int m_iSkipFrame; ///< counter for frames prior to the random access point to skip diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp index 03f3e3cca136aab2459b699402e3dce050eecc87..d39de5428f18e5fcaad1e513f1b6dfb81c7d46ea 100644 --- a/source/App/EncoderApp/EncApp.cpp +++ b/source/App/EncoderApp/EncApp.cpp @@ -248,6 +248,9 @@ void EncApp::xInitLibCfg() m_cEncLib.setNnlfSet1InterChromaModelName (m_nnlfSet1InterChromaModelName); m_cEncLib.setNnlfSet1IntraLumaModelName (m_nnlfSet1IntraLumaModelName); m_cEncLib.setNnlfSet1IntraChromaModelName (m_nnlfSet1IntraChromaModelName); +#if JVET_AC0177_MULTI_FRAME + m_cEncLib.setNnlfSet1AlternativeInterLumaModelName (m_nnlfSet1AlternativeInterLumaModelName); +#endif #endif #if JVET_AB0068_RD m_cEncLib.setUseEncNnlfOpt (m_encNnlfOpt); @@ -1087,6 +1090,9 @@ void EncApp::xInitLibCfg() m_cEncLib.setNnlfSet1InferSizeBase (m_nnlfSet1InferSizeBase); m_cEncLib.setNnlfSet1InferSizeExtension (m_nnlfSet1InferSizeExtension); m_cEncLib.setNnlfSet1MaxNumParams (m_nnlfSet1MaxNumParams); +#if JVET_AC0177_MULTI_FRAME + m_cEncLib.setNnlfSet1UseMultiframe (m_nnlfSet1Multiframe); +#endif #endif m_cEncLib.setLmcs ( m_lmcsEnabled ); m_cEncLib.setReshapeSignalType ( m_reshapeSignalType ); diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp index 8ba5c3cf29be3cdc8b459fc01376c5cee846cf86..629b53b137536e797fa2e11a338cf04d60b0aff0 100644 --- a/source/App/EncoderApp/EncAppCfg.cpp +++ b/source/App/EncoderApp/EncAppCfg.cpp @@ -1443,6 +1443,10 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) ( "NnlfSet1InterChromaModel", m_nnlfSet1InterChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_InterSlice_int16.sadl"), "NnlfSet1 inter chroma model name") ( "NnlfSet1IntraLumaModel", m_nnlfSet1IntraLumaModelName, string("models/NnlfSet1_LumaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra luma model name") ( "NnlfSet1IntraChromaModel", m_nnlfSet1IntraChromaModelName, string("models/NnlfSet1_ChromaCNNFilter_IntraSlice_int16.sadl"), "NnlfSet1 intra chroma model name") +#if JVET_AC0177_MULTI_FRAME + ( "NnlfSet1AlternativeInterLumaModel", m_nnlfSet1AlternativeInterLumaModelName, string("models/NnlfSet1_LumaCNNFilter_InterSlice_MultiframePrior_Tid345_int16.sadl"), "NnlfSet1 alternative inter luma model name") + ( "NnlfSet1Multiframe", m_nnlfSet1Multiframe, false, "Input multiple frames in NN-based loop filter set 1" ) +#endif #endif #if JVET_AB0068_RD ( "EncNnlfOpt", m_encNnlfOpt, false, "Encoder optimization with NN-based loop filter") diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h index 5d59c2ab366a57df7907921ea173779fad99244f..ba685eced6e7143cd8b1f926333d03d13bf9fe25 100644 --- a/source/App/EncoderApp/EncAppCfg.h +++ b/source/App/EncoderApp/EncAppCfg.h @@ -94,6 +94,9 @@ protected: std::string m_nnlfSet1InterChromaModelName; ///<inter chroma nnlf set1 model std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlf set1 model std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlf set1 model +#if JVET_AC0177_MULTI_FRAME + std::string m_nnlfSet1AlternativeInterLumaModelName; ///<alternative inter luma nnlf set1 model +#endif #endif #if JVET_AB0068_RD std::string m_rdoCnnlfInterLumaModelName; ///<inter luma cnnlf model @@ -733,6 +736,9 @@ protected: unsigned m_nnlfSet1InferSizeBase; unsigned m_nnlfSet1InferSizeExtension; unsigned m_nnlfSet1MaxNumParams; +#if JVET_AC0177_MULTI_FRAME + bool m_nnlfSet1Multiframe; +#endif #endif #if JVET_AB0068_RD diff --git a/source/Lib/CommonLib/CommonDef.h b/source/Lib/CommonLib/CommonDef.h index a8c42e17eb4de8c6d90fdeddfa7d81bf4ee9fd5d..c037e9e3d5c84dd20e024e2314872bc5b11b6357 100644 --- a/source/Lib/CommonLib/CommonDef.h +++ b/source/Lib/CommonLib/CommonDef.h @@ -198,6 +198,9 @@ static const int NNQPOFFSET[QP_OFFSET_NUM] = { -5, 5 }; #if NN_FILTERING_SET_1 static const int NN_INPUT_PRECISION= 13; static const int NN_OUTPUT_PRECISION= 13; +#if JVET_AC0177_MULTI_FRAME +static const int MINIMUM_TID_ENABLING_TEMPORAL_INPUTS = 3; // JVET-AC0177: minimum temporal layer id enabling temporal inputs +#endif #endif diff --git a/source/Lib/CommonLib/NNFilterSet1.cpp b/source/Lib/CommonLib/NNFilterSet1.cpp index 4cac99ccf4b6670b074240b2219f9b1ac5ae7c5f..e2606d203c30536847b5a64554f1725d54059e1e 100644 --- a/source/Lib/CommonLib/NNFilterSet1.cpp +++ b/source/Lib/CommonLib/NNFilterSet1.cpp @@ -80,12 +80,19 @@ void NNFilterSet1::create( const int picWidth, const int picHeight, const Chroma } #endif } +#if JVET_AC0177_MULTI_FRAME +void NNFilterSet1::init(std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma, std::string alternativeInterLuma) +#else void NNFilterSet1::init(std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma) +#endif { m_interLuma = interLuma; m_interChroma = interChroma; m_intraLuma = intraLuma; m_intraChroma = intraChroma; +#if JVET_AC0177_MULTI_FRAME + m_alternativeInterLuma = alternativeInterLuma; +#endif } void NNFilterSet1::destroy() { @@ -108,6 +115,9 @@ struct ModelData { vector<sadl::Tensor<T>> inputs; int hor,ver; bool luma,inter; +#if JVET_AC0177_MULTI_FRAME + bool alternative; +#endif }; template<typename T> @@ -120,12 +130,20 @@ static std::vector<ModelData<T>> initSpace() { static std::vector<ModelData<TypeSadl>> models=initSpace<TypeSadl>(); template<typename T> +#if JVET_AC0177_MULTI_FRAME +static ModelData<T> &getModel(int ver, int hor, bool luma, bool inter, const string modelName, bool alternative=false) +#else static ModelData<T> &getModel(int ver, int hor, bool luma, bool inter, const string modelName) +#endif { ModelData<T> *ptr = nullptr; for(auto &m: models) { +#if JVET_AC0177_MULTI_FRAME + if (m.luma == luma && m.inter == inter && m.alternative == alternative) +#else if (m.luma == luma && m.inter == inter) +#endif { ptr = &m; break; @@ -151,6 +169,9 @@ static ModelData<T> &getModel(int ver, int hor, bool luma, bool inter, const str m.inter = inter; m.ver = 0; m.hor = 0; +#if JVET_AC0177_MULTI_FRAME + m.alternative = alternative; +#endif } ModelData<T> &m = *ptr; if (m.ver != ver || m.hor != hor) @@ -279,6 +300,60 @@ void extractOutputsLuma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, P } } +#if JVET_AC0177_FLIP_INPUT +template<typename T> +void extractOutputsLuma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, PelStorage& tempScaledBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, bool inter, bool flip) +{ +#if NN_FIXED_POINT_IMPLEMENTATION + int log2InputScale = 10; + int log2OutputScale = 10; + int shiftInput = NN_OUTPUT_PRECISION - log2InputScale; + int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale; + int offset = (1 << shiftOutput) / 2; +#else + double inputScale = 1024; + double outputScale = 1024; +#endif + auto output = m.result(0); + PelBuf bufDst = tempBuf.getBuf(inferArea).get(COMPONENT_Y); + +#if SCALE_NN_RESIDUE + PelBuf bufScaledDst = tempScaledBuf.getBuf(inferArea).get(COMPONENT_Y); +#endif + + int hor = inferArea.lwidth(); + int ver = inferArea.lheight(); + + PelBuf bufRec = pic->getRecBeforeDbfBuf(inferArea).get(COMPONENT_Y); + + for (int c = 0; c < 4; c++) // output includes 4 sub images + { + for (int y = 0; y < ver >> 1; y++) + { + for (int x = 0; x < hor >> 1; x++) + { + int yy = (y << 1) + c / 2; + int xx = (x << 1) + c % 2; + NNInference::geometricTransform(ver, hor, yy, xx, flip); + if (xx < extLeft || yy < extTop || xx >= hor - extRight || yy >= ver - extBottom) + { + continue; + } +#if NN_FIXED_POINT_IMPLEMENTATION + int out = ( output(0, y, x, c) + (bufRec.at(xx, yy) << shiftInput) + offset) >> shiftOutput; +#else + int out = ( output(0, y, x, c) + bufRec.at(xx, yy) / inputScale ) * outputScale + 0.5; +#endif + bufDst.at(xx, yy) = Pel(Clip3<int>( 0, 1023, out)); + + #if SCALE_NN_RESIDUE + bufScaledDst.at(xx, yy) = Pel(Clip3<int>(0, 1023 << NN_RESIDUE_ADDITIONAL_SHIFT, out * (1 << NN_RESIDUE_ADDITIONAL_SHIFT) ) ); + #endif + } + } + } +} +#endif template<typename T> void extractOutputsChroma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, PelStorage& tempScaledBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, bool inter) { @@ -353,12 +428,17 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL if (border_to_skip>0) sadl::Tensor<float>::skip_border = true; // get model +#if JVET_AC0177_MULTI_FRAME + bool alternative = pic->slices[0]->getSPS()->getNnlfSet1MultiframeEnabledFlag() && inter && pic->slices[0]->getTLayer() >= MINIMUM_TID_ENABLING_TEMPORAL_INPUTS ? true : false; + // get model + ModelData<T> &m = getModel<T>(inferArea.lheight(), inferArea.lwidth(), true, inter, inter ? (alternative ? m_alternativeInterLuma : m_interLuma) : m_intraLuma, alternative); +#else ModelData<T> &m = getModel<T>(inferArea.lheight(), inferArea.lwidth(), true, inter, inter ? m_interLuma : m_intraLuma); +#endif sadl::Model<T> &model = m.model; // get inputs vector<sadl::Tensor<T>> &inputs = m.inputs; - int seqQp = pic->slices[0]->getPPS()->getPicInitQPMinus26() + 26; int sliceQp = pic->slices[0]->getSliceQp(); int delta = inter ? paramIdx * 5 : paramIdx * 2; @@ -366,20 +446,47 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL { delta = 5 - delta; } +#if JVET_AC0177_FLIP_INPUT + bool flip = pic->slices[0]->getSPS()->getNnlfSet1MultiframeEnabledFlag() && (paramIdx == 2) ? true : false; + if (flip) + { + delta = 0; + } +#endif int qp = inter ? seqQp - delta : sliceQp - delta; std::vector<InputData> listInputData; #if JVET_AC0089_COMBINE_INTRA_INTER - InputData inputRec = { NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false }; - InputData inputPred = { NN_INPUT_PRED, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false }; - InputData inputBs = { NN_INPUT_BS, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false }; - InputData inputIpb = { NN_INPUT_IPB, 3, 1, NN_INPUT_PRECISION, true, false }; - InputData inputQp = { NN_INPUT_LOCAL_QP, 4, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false }; - listInputData.push_back(inputRec); - listInputData.push_back(inputPred); - listInputData.push_back(inputBs); - listInputData.push_back(inputIpb); - listInputData.push_back(inputQp); +#if JVET_AC0177_MULTI_FRAME + if (alternative) + { + InputData inputRec = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputPred = {NN_INPUT_PRED, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputRefList0 = {NN_INPUT_REF_LIST_0, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputRefList1 = {NN_INPUT_REF_LIST_1, 3, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 4, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false}; + listInputData.push_back(inputRec); + listInputData.push_back(inputPred); + listInputData.push_back(inputRefList0); + listInputData.push_back(inputRefList1); + listInputData.push_back(inputQp); + } + else + { +#endif + InputData inputRec = { NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false }; + InputData inputPred = { NN_INPUT_PRED, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false }; + InputData inputBs = { NN_INPUT_BS, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false }; + InputData inputIpb = { NN_INPUT_IPB, 3, 1, NN_INPUT_PRECISION, true, false }; + InputData inputQp = { NN_INPUT_LOCAL_QP, 4, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false }; + listInputData.push_back(inputRec); + listInputData.push_back(inputPred); + listInputData.push_back(inputBs); + listInputData.push_back(inputIpb); + listInputData.push_back(inputQp); +#if JVET_AC0177_MULTI_FRAME + } +#endif #else if (inter) { @@ -414,11 +521,19 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL listInputData.push_back(inputQp); } #endif +#if JVET_AC0177_FLIP_INPUT + NNInference::prepareInputs<T>(pic, inferArea, inputs, -1, qp, -1, listInputData, flip); +#else NNInference::prepareInputs<T>(pic, inferArea, inputs, -1, qp, -1, listInputData); +#endif NNInference::infer<T>(model, inputs); // get outputs +#if JVET_AC0177_FLIP_INPUT + extractOutputsLuma(pic, model, m_tempBuf[paramIdx], m_tempScaledBuf[paramIdx], inferArea, extLeft, extRight, extTop, extBottom, inter, flip); +#else extractOutputsLuma(pic, model, m_tempBuf[paramIdx], m_tempScaledBuf[paramIdx], inferArea, extLeft, extRight, extTop, extBottom, inter); +#endif } diff --git a/source/Lib/CommonLib/NNFilterSet1.h b/source/Lib/CommonLib/NNFilterSet1.h index b97214a6830f79383afa2df55595debd0848d89a..6a6dfa4e6394637f6305a6b9949a285fbb46cf65 100644 --- a/source/Lib/CommonLib/NNFilterSet1.h +++ b/source/Lib/CommonLib/NNFilterSet1.h @@ -53,6 +53,9 @@ public: std::vector<PelStorage> m_tempBuf; std::string m_interLuma, m_interChroma, m_intraLuma, m_intraChroma; +#if JVET_AC0177_MULTI_FRAME + std::string m_alternativeInterLuma; +#endif #if SCALE_NN_RESIDUE std::vector<PelStorage> m_tempScaledBuf; #endif @@ -63,7 +66,11 @@ public: void scaleResidualBlock(Picture *pic, UnitArea inferArea, int paramIdx, ComponentID compID); #endif void create(const int picWidth, const int picHeight, const ChromaFormat format, const int nnlfSet1NumParams); +#if JVET_AC0177_MULTI_FRAME + void init(std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma, std::string alternativeInterLuma); +#else void init(std::string interLuma, std::string interChroma, std::string intraLuma, std::string intraChroma); +#endif void destroy(); }; diff --git a/source/Lib/CommonLib/NNInference.h b/source/Lib/CommonLib/NNInference.h index ed40b606367a245c3e120aa2749f6ca850c82f51..84ffa13b630993584ef788b608f9026584aa2f3c 100644 --- a/source/Lib/CommonLib/NNInference.h +++ b/source/Lib/CommonLib/NNInference.h @@ -135,6 +135,90 @@ public: } #endif } +#if JVET_AC0177_FLIP_INPUT + static void geometricTransform (int ver, int hor, int& y, int& x, bool flip) + { + x = flip ? hor - 1 - x : x; + } + template<typename T> + static void fillInputFromBuf (Picture* pic, UnitArea inferArea, sadl::Tensor<T> &input, PelUnitBuf buf, bool luma, bool chroma, double scale, int shift, bool flip) + { + PelBuf bufY, bufCb, bufCr; + + if (luma) + { + bufY = buf.get(COMPONENT_Y); + } + if (chroma) + { + bufCb = buf.get(COMPONENT_Cb); + bufCr = buf.get(COMPONENT_Cr); + } + + int hor, ver; + if (luma) + { + hor = inferArea.lwidth(); + ver = inferArea.lheight(); + } + else + { + hor = inferArea.lwidth() >> 1; + ver = inferArea.lheight() >> 1; + } +#if NN_FIXED_POINT_IMPLEMENTATION + for (int y = 0; y < ver; y++) + { + for (int x = 0; x < hor; x++) + { + int yT = y; + int xT = x; + geometricTransform(ver, hor, yT, xT, flip); + if (luma && !chroma) + { + input(0, yT, xT, 0) = bufY.at(x, y) << shift; + } + else if (!luma && chroma) + { + input(0, yT, xT, 0) = bufCb.at(x, y) << shift; + input(0, yT, xT, 1) = bufCr.at(x, y) << shift; + } + else if (luma && chroma) + { + input(0, yT, xT, 0) = bufY.at(x, y) << shift; + input(0, yT, xT, 1) = bufCb.at(x >> 1, y >> 1) << shift; + input(0, yT, xT, 2) = bufCr.at(x >> 1, y >> 1) << shift; + } + } + } +#else + for (int y = 0; y < ver; y++) + { + for (int x = 0; x < hor; x++) + { + int yT = y; + int xT = x; + geometricTransform(ver, hor, yT, xT, flip); + if (luma && !chroma) + { + input(0, yT, xT, 0) = bufY.at(x, y) / scale; + } + else if (!luma && chroma) + { + input(0, yT, xT, 0) = bufCb.at(x, y) / scale; + input(0, yT, xT, 1) = bufCr.at(x, y) / scale; + } + else if (luma && chroma) + { + input(0, yT, xT, 0) = bufY.at(x, y) / scale; + input(0, yT, xT, 1) = bufCb.at(x >> 1, y >> 1) / scale; + input(0, yT, xT, 2) = bufCr.at(x >> 1, y >> 1) / scale; + } + } + } +#endif + } +#endif template<typename T> static void fillInputFromConstant (Picture* pic, UnitArea inferArea, sadl::Tensor<T> &input, int c, bool luma, double scale, int shift) { @@ -242,6 +326,87 @@ public: } #endif } +#if JVET_AC0177_FLIP_INPUT + template<typename T> + static void fillInputFromBufIpb(Picture *pic, UnitArea inferArea, sadl::Tensor<T> &input, PelUnitBuf buf, bool luma, + bool chroma, double scale, int shift, bool flip) + { + PelBuf bufY, bufCb, bufCr; + + if (luma) + { + bufY = buf.get(COMPONENT_Y); + } + if (chroma) + { + bufCb = buf.get(COMPONENT_Cb); + bufCr = buf.get(COMPONENT_Cr); + } + + int hor, ver; + if (luma) + { + hor = inferArea.lwidth(); + ver = inferArea.lheight(); + } + else + { + hor = inferArea.lwidth() >> 1; + ver = inferArea.lheight() >> 1; + } +#if NN_FIXED_POINT_IMPLEMENTATION + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + int yT = yy; + int xT = xx; + geometricTransform(ver, hor, yT, xT, flip); + if (luma && !chroma) + { + input(0, yT, xT, 0) = ((bufY.at(xx, yy) >> 1) + ((bufY.at(xx, yy) >> 1) == 0) - 1) << (shift - 1); + } + else if (!luma && chroma) + { + input(0, yT, xT, 0) = ((bufCb.at(xx, yy) >> 1) + ((bufCb.at(xx, yy) >> 1) == 0) - 1) << (shift - 1); + input(0, yT, xT, 1) = ((bufCr.at(xx, yy) >> 1) + ((bufCr.at(xx, yy) >> 1) == 0) - 1) << (shift - 1); + } + else if (luma && chroma) + { + input(0, yT, xT, 0) = ((bufY.at(xx, yy) >> 1) + ((bufY.at(xx, yy) >> 1) == 0) - 1) << (shift - 1); + input(0, yT, xT, 1) = ((bufCb.at(xx, yy) >> 1) + ((bufCb.at(xx, yy) >> 1) == 0) - 1) << (shift - 1); + input(0, yT, xT, 2) = ((bufCr.at(xx, yy) >> 1) + ((bufCr.at(xx, yy) >> 1) == 0) - 1) << (shift - 1); + } + } + } +#else + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + int yT = yy; + int xT = xx; + geometricTransform(ver, hor, yT, xT, flip); + if (luma && !chroma) + { + input(0, yT, xT, 0) = ((bufY.at(xx, yy) >> 1) + ((bufY.at(xx, yy) >> 1) == 0) - 1) / 2.0; + } + else if (!luma && chroma) + { + input(0, yT, xT, 0) = ((bufCb.at(xx, yy) >> 1) + ((bufCb.at(xx, yy) >> 1) == 0) - 1) / 2.0; + input(0, yT, xT, 1) = ((bufCr.at(xx, yy) >> 1) + ((bufCr.at(xx, yy) >> 1) == 0) - 1) / 2.0; + } + else if (luma && chroma) + { + input(0, yT, xT, 0) = ((bufY.at(xx, yy) >> 1) + ((bufY.at(xx, yy) >> 1) == 0) - 1) / 2.0; + input(0, yT, xT, 1) = ((bufCb.at(xx, yy) >> 1) + ((bufCb.at(xx, yy) >> 1) == 0) - 1) / 2.0; + input(0, yT, xT, 2) = ((bufCr.at(xx, yy) >> 1) + ((bufCr.at(xx, yy) >> 1) == 0) - 1) / 2.0; + } + } + } +#endif + } +#endif #endif template<typename T> @@ -272,6 +437,14 @@ public: case NN_INPUT_SLICE_TYPE: fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], sliceType, inputData.luma, inputData.scale, inputData.shift); break; +#if JVET_AC0177_MULTI_FRAME + case NN_INPUT_REF_LIST_0: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->slices[0]->getRefPic(REF_PIC_LIST_0, 0)->getRecoBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift); + break; + case NN_INPUT_REF_LIST_1: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->slices[0]->getRefPic(REF_PIC_LIST_1, 0)->getRecoBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift); + break; +#endif #if JVET_AC0089_COMBINE_INTRA_INTER case NN_INPUT_IPB: fillInputFromBufIpb<T>(pic, inferArea, inputs[inputData.index], pic->getBlockPredModeBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift); @@ -283,6 +456,55 @@ public: } } } +#if JVET_AC0177_FLIP_INPUT + template<typename T> + static void prepareInputs (Picture* pic, UnitArea inferArea, std::vector<sadl::Tensor<T>> &inputs, int globalQp, int localQp, int sliceType, const std::vector<InputData> &listInputData, bool flip) + { + for (auto inputData : listInputData) + { + switch (inputData.nnInputType) + { + case NN_INPUT_REC: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getRecBeforeDbfBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift, flip); + break; + case NN_INPUT_PRED: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getPredBufCustom(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift, flip); + break; + case NN_INPUT_PARTITION: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getCuAverageBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift, flip); + break; + case NN_INPUT_BS: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getBsMapBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift, flip); + break; + case NN_INPUT_GLOBAL_QP: + fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], globalQp, inputData.luma, inputData.scale, inputData.shift); + break; + case NN_INPUT_LOCAL_QP: + fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], localQp, inputData.luma, inputData.scale, inputData.shift); + break; + case NN_INPUT_SLICE_TYPE: + fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], sliceType, inputData.luma, inputData.scale, inputData.shift); + break; +#if JVET_AC0177_MULTI_FRAME + case NN_INPUT_REF_LIST_0: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->slices[0]->getRefPic(REF_PIC_LIST_0, 0)->getRecoBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift, flip); + break; + case NN_INPUT_REF_LIST_1: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->slices[0]->getRefPic(REF_PIC_LIST_1, 0)->getRecoBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift, flip); + break; +#endif +#if JVET_AC0089_COMBINE_INTRA_INTER + case NN_INPUT_IPB: + fillInputFromBufIpb<T>(pic, inferArea, inputs[inputData.index], pic->getBlockPredModeBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift, flip); + break; +#endif + default: + THROW("invalid input data"); + break; + } + } + } +#endif template<typename T> static void infer(sadl::Model<T> &model, std::vector<sadl::Tensor<T>> &inputs); }; diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h index 231684dbf7c4acce8cfa6ed410140b416c2a5f06..20eb48bd3101f4abd46b666ba8af600d05bf6d50 100644 --- a/source/Lib/CommonLib/Slice.h +++ b/source/Lib/CommonLib/Slice.h @@ -1492,6 +1492,9 @@ private: unsigned m_nnlfSet1InferSize[MAX_NUM_CNNLF_INFER_GRANULARITY]; unsigned m_nnlfSet1InferSizeExtension; unsigned m_nnlfSet1MaxNumParams; +#if JVET_AC0177_MULTI_FRAME + bool m_nnlfSet1MultiframeEnabledFlag; +#endif #endif bool m_wrapAroundEnabledFlag; unsigned m_IBCFlag; @@ -1765,6 +1768,10 @@ public: #if NN_FILTERING_SET_1 bool getNnlfSet1EnabledFlag() const { return m_nnlfSet1EnabledFlag; } void setNnlfSet1EnabledFlag( bool b ) { m_nnlfSet1EnabledFlag = b; } +#if JVET_AC0177_MULTI_FRAME + bool getNnlfSet1MultiframeEnabledFlag() const { return m_nnlfSet1MultiframeEnabledFlag; } + void setNnlfSet1MultiframeEnabledFlag( bool b ) { m_nnlfSet1MultiframeEnabledFlag = b; } +#endif #endif void setJointCbCrEnabledFlag(bool bVal) { m_JointCbCrEnabledFlag = bVal; } bool getJointCbCrEnabledFlag() const { return m_JointCbCrEnabledFlag; } diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 9e968498f79e8e452b99b14eb76f7e7d00739413..1de614434fb173c9e466d03451aa559af9f9d7c0 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -122,10 +122,10 @@ using TypeSadl = float; #define BYPASS_INTER_SLICE 0 // only used for training data generation #define JVET_AC0089_COMBINE_INTRA_INTER 1 // JVET-AC0089: EE1-1.5.3 Use combined inter/intra models. Luma model uses IPB input. Chroma model does not use IPB input. - +#define JVET_AC0177_MULTI_FRAME 1 // JVET-AC0177: EE1-1.7: Deep In-Loop Filter with Additional Input Information +#define JVET_AC0177_FLIP_INPUT 1 // JVET-AC0177: flip input and output of NN filter model #endif - #define JVET_AB0068_RD 1 // JVET-AB0068: EE1-1.6: RDO Considering Deep In-Loop Filtering //########### place macros to be removed in next cycle below this line ############### @@ -519,8 +519,15 @@ enum NNInputType NN_INPUT_LOCAL_QP = 5, NN_INPUT_SLICE_TYPE = 6, #if JVET_AC0089_NNVC_USE_BPM_INFO +#if JVET_AC0177_MULTI_FRAME + NN_INPUT_IPB = 7, + NN_INPUT_REF_LIST_0 = 8, + NN_INPUT_REF_LIST_1 = 9, + MAX_NUM_NN_INPUT = 10 +#else NN_INPUT_IPB = 7, MAX_NUM_NN_INPUT = 8 +#endif #else MAX_NUM_NN_INPUT = 7 #endif diff --git a/source/Lib/DecoderLib/DecLib.cpp b/source/Lib/DecoderLib/DecLib.cpp index 67eb2769700d60e7982af10ddef41d888576047f..75f3a1decf592fcd2fa1ad49c848ddb2af2ad690 100644 --- a/source/Lib/DecoderLib/DecLib.cpp +++ b/source/Lib/DecoderLib/DecLib.cpp @@ -651,7 +651,11 @@ void DecLib::executeLoopFilters() if (cs.sps->getNnlfSet1EnabledFlag()) { m_pcNNFilterSet1.create(cs.pcv->lumaWidth, cs.pcv->lumaHeight, cs.pcv->chrFormat, cs.sps->getNnlfSet1MaxNumParams()); +#if JVET_AC0177_MULTI_FRAME + m_pcNNFilterSet1.init(getNnlfSet1InterLumaModelName(), getNnlfSet1InterChromaModelName(), getNnlfSet1IntraLumaModelName(), getNnlfSet1IntraChromaModelName(), getNnlfSet1AlternativeInterLumaModelName()); +#else m_pcNNFilterSet1.init(getNnlfSet1InterLumaModelName(), getNnlfSet1InterChromaModelName(), getNnlfSet1IntraLumaModelName(), getNnlfSet1IntraChromaModelName()); +#endif } #endif diff --git a/source/Lib/DecoderLib/DecLib.h b/source/Lib/DecoderLib/DecLib.h index 3c4e673400402a8c0cf5715714ef0ba723f4d5a5..f85c4d2b8ca4e4c55568e8611c59be6d3e20edcc 100644 --- a/source/Lib/DecoderLib/DecLib.h +++ b/source/Lib/DecoderLib/DecLib.h @@ -77,10 +77,13 @@ class DecLib { private: #if NN_FILTERING_SET_1 - std::string m_nnlfSet1InterLumaModelName; ///<inter luma nnlfSet1 model - std::string m_nnlfSet1InterChromaModelName; ///<inter chroma nnlfSet1 model - std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlfSet1 model - std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlfSet1 model + std::string m_nnlfSet1InterLumaModelName; ///<inter luma nnlf set1 model + std::string m_nnlfSet1InterChromaModelName; ///<inter chroma nnlf set1 model + std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlf set1 model + std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlf set1 model +#if JVET_AC0177_MULTI_FRAME + std::string m_nnlfSet1AlternativeInterLumaModelName; ///<alternative inter luma nnlf set1 model +#endif #endif int m_iMaxRefPicNum; bool m_isFirstGeneralHrd; @@ -252,6 +255,10 @@ public: void setNnlfSet1InterChromaModelName(std::string s) { m_nnlfSet1InterChromaModelName = s; } void setNnlfSet1IntraLumaModelName(std::string s) { m_nnlfSet1IntraLumaModelName = s; } void setNnlfSet1IntraChromaModelName(std::string s) { m_nnlfSet1IntraChromaModelName = s; } +#if JVET_AC0177_MULTI_FRAME + std::string getNnlfSet1AlternativeInterLumaModelName() { return m_nnlfSet1AlternativeInterLumaModelName; } + void setNnlfSet1AlternativeInterLumaModelName(std::string s) { m_nnlfSet1AlternativeInterLumaModelName = s; } +#endif #endif void setDecodedPictureHashSEIEnabled(int enabled) { m_decodedPictureHashSEIEnabled=enabled; } diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp index fbeccce480ce75f7f4e6cfc3da9329ea7f8a836e..143c49511264e5bff9c4ac72f571174bd2dc620e 100644 --- a/source/Lib/DecoderLib/VLCReader.cpp +++ b/source/Lib/DecoderLib/VLCReader.cpp @@ -1748,6 +1748,9 @@ void HLSyntaxReader::parseSPS(SPS* pcSPS) pcSPS->setNnlfSet1InferSizeExtension ( uiCode ); READ_UVLC( uiCode, "sps_nnlf_set1_max_num_params" ); pcSPS->setNnlfSet1MaxNumParams (uiCode ); +#if JVET_AC0177_MULTI_FRAME + READ_FLAG( uiCode, "sps_nnlf_set1_multi_frame_enabled_flag" ); pcSPS->setNnlfSet1MultiframeEnabledFlag ( uiCode ? true : false ); +#endif } #endif diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h index 016d0fad20876f2e0b47819058a611ec655205f5..3ff3f7d294ecdc66dcaa600da7fa4f383796c287 100644 --- a/source/Lib/EncoderLib/EncCfg.h +++ b/source/Lib/EncoderLib/EncCfg.h @@ -163,6 +163,9 @@ protected: std::string m_nnlfSet1InterChromaModelName; ///<inter chroma nnlf set1 model std::string m_nnlfSet1IntraLumaModelName; ///<intra luma nnlf set1 model std::string m_nnlfSet1IntraChromaModelName; ///<inra chroma nnlf set1 model +#if JVET_AC0177_MULTI_FRAME + std::string m_nnlfSet1AlternativeInterLumaModelName; ///<alternative inter luma nnlf set1 model +#endif #endif #if JVET_AB0068_RD bool m_encNnlfOpt; @@ -776,6 +779,9 @@ protected: unsigned m_nnlfSet1InferSizeBase; unsigned m_nnlfSet1InferSizeExtension; unsigned m_nnlfSet1MaxNumParams; +#if JVET_AC0177_MULTI_FRAME + bool m_nnlfSet1Multiframe; +#endif #endif #if JVET_O0756_CALCULATE_HDRMETRICS double m_whitePointDeltaE[hdrtoolslib::NB_REF_WHITE]; @@ -842,6 +848,10 @@ public: void setNnlfSet1InterChromaModelName(std::string s) { m_nnlfSet1InterChromaModelName = s; } void setNnlfSet1IntraLumaModelName(std::string s) { m_nnlfSet1IntraLumaModelName = s; } void setNnlfSet1IntraChromaModelName(std::string s) { m_nnlfSet1IntraChromaModelName = s; } +#if JVET_AC0177_MULTI_FRAME + std::string getNnlfSet1AlternativeInterLumaModelName() { return m_nnlfSet1AlternativeInterLumaModelName; } + void setNnlfSet1AlternativeInterLumaModelName(std::string s) { m_nnlfSet1AlternativeInterLumaModelName = s; } +#endif #endif #if JVET_AB0068_RD std::string getRdoCnnlfInterLumaModelName() { return m_rdoCnnlfInterLumaModelName; } @@ -2043,6 +2053,10 @@ public: unsigned getNnlfSet1InferSizeExtension() const { return m_nnlfSet1InferSizeExtension; } void setNnlfSet1MaxNumParams( unsigned s ) { m_nnlfSet1MaxNumParams = s; } unsigned getNnlfSet1MaxNumParams() const { return m_nnlfSet1MaxNumParams; } +#if JVET_AC0177_MULTI_FRAME + void setNnlfSet1UseMultiframe( bool b ) { m_nnlfSet1Multiframe = b; } + bool getNnlfSet1UseMultiframe() const { return m_nnlfSet1Multiframe; } +#endif #endif #if JVET_O0756_CALCULATE_HDRMETRICS void setWhitePointDeltaE( uint32_t index, double value ) { m_whitePointDeltaE[ index ] = value; } diff --git a/source/Lib/EncoderLib/EncGOP.cpp b/source/Lib/EncoderLib/EncGOP.cpp index 609e63118a3bda676a21feaf4a70c917070cce2f..996fc82098d416776852387fe4bd8284a1a7b747 100644 --- a/source/Lib/EncoderLib/EncGOP.cpp +++ b/source/Lib/EncoderLib/EncGOP.cpp @@ -3106,7 +3106,11 @@ void EncGOP::compressGOP( int iPOCLast, int iNumPicRcvd, PicList& rcListPic, #if NN_FILTERING_SET_1 if ( cs.sps->getNnlfSet1EnabledFlag() ) { +#if JVET_AC0177_MULTI_FRAME + m_pcNNFilterSet1.init(m_pcEncLib->getNnlfSet1InterLumaModelName(), m_pcEncLib->getNnlfSet1InterChromaModelName(), m_pcEncLib->getNnlfSet1IntraLumaModelName(), m_pcEncLib->getNnlfSet1IntraChromaModelName(), m_pcEncLib->getNnlfSet1AlternativeInterLumaModelName()); +#else m_pcNNFilterSet1.init(m_pcEncLib->getNnlfSet1InterLumaModelName(), m_pcEncLib->getNnlfSet1InterChromaModelName(), m_pcEncLib->getNnlfSet1IntraLumaModelName(), m_pcEncLib->getNnlfSet1IntraChromaModelName()); +#endif m_pcNNFilterSet1.initCABACEstimator( m_pcEncLib->getCABACEncoder(), m_pcEncLib->getCtxCache(), pcSlice ); m_pcNNFilterSet1.cnnFilterEncoder(pcPic, pcSlice->getLambdas()); } diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp index 2421b93953bc56f3f039d291002c90f25e84e345..6523da53b8cfa28bd1eb7941f7926de31ba45d3b 100644 --- a/source/Lib/EncoderLib/EncLib.cpp +++ b/source/Lib/EncoderLib/EncLib.cpp @@ -1456,6 +1456,9 @@ void EncLib::xInitSPS( SPS& sps ) { sps.setNnlfSet1MaxNumParams(m_nnlfSet1MaxNumParams); } +#if JVET_AC0177_MULTI_FRAME + sps.setNnlfSet1MultiframeEnabledFlag(m_nnlfSet1Multiframe); +#endif } #endif diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp index 10dba722dd7908a4d3cbe651738e240bc02d2b09..aad5d30b34f4175be0f3a5acda3f5b48b18be022 100644 --- a/source/Lib/EncoderLib/VLCWriter.cpp +++ b/source/Lib/EncoderLib/VLCWriter.cpp @@ -1022,6 +1022,9 @@ void HLSWriter::codeSPS( const SPS* pcSPS ) WRITE_UVLC( pcSPS->getNnlfSet1InferSize(CNNLF_INFER_GRANULARITY_BASE), "sps_nnlf_set1_infer_size_base" ); WRITE_UVLC( pcSPS->getNnlfSet1InferSizeExtension(), "sps_nnlf_set1_infer_size_extension" ); WRITE_UVLC( pcSPS->getNnlfSet1MaxNumParams(), "sps_nnlf_set1_max_num_params" ); +#if JVET_AC0177_MULTI_FRAME + WRITE_FLAG( pcSPS->getNnlfSet1MultiframeEnabledFlag(), "sps_nnlf_set1_multi_frame_enabled_flag" ); +#endif } #endif diff --git a/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Conversion/convert.py b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Conversion/convert.py new file mode 100755 index 0000000000000000000000000000000000000000..fb68905594bb382149faf5ba3af659dbe030eee4 --- /dev/null +++ b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Conversion/convert.py @@ -0,0 +1,61 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import torch.nn as nn +from net import ConditionalNet +import numpy as np +import os + +# input +yuv = np.ones((1, 1, 32, 32), dtype=np.float32) +pred = np.ones((1, 1, 32, 32), dtype=np.float32) +forw = np.ones((1, 1, 32, 32), dtype=np.float32) +bacw = np.ones((1, 1, 32, 32), dtype=np.float32) +qp = np.ones((1, 1, 32, 32), dtype=np.float32) + +# model +# model = nn.DataParallel(ConditionalNet(96, 8)) # if model is trained on multiple GPUs +model = ConditionalNet(96, 8) # if model is trained with single GPU +state = torch.load('50.ckpt', map_location=torch.device('cpu')) +model.load_state_dict(state) + +dummy_input = (torch.from_numpy(yuv), torch.from_numpy(pred), torch.from_numpy(forw), torch.from_numpy(bacw), torch.from_numpy(qp)) +torch.onnx.export(model.module, dummy_input, "NnlfSet1_LumaCNNFilter_InterSlice_MultiframePrior_Tid345_int16.onnx") + + + + + + diff --git a/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Conversion/net.py b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Conversion/net.py new file mode 100644 index 0000000000000000000000000000000000000000..eebcdb75d4ce31096d34c64a67f9a3c2e3d486a8 --- /dev/null +++ b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Conversion/net.py @@ -0,0 +1,133 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import torch.nn as nn + + +def conv3x3(in_channels, out_channels, stride=1, padding=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=stride, padding=padding) + + +def conv1x1(in_channels, out_channels, stride=1, padding=0): + return nn.Conv2d(in_channels, out_channels, kernel_size=1, + stride=stride, padding=padding) + + +# Conv3x3 + PReLU +class conv3x3_f(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(conv3x3_f, self).__init__() + self.conv = conv3x3(in_channels, out_channels, stride) + self.relu = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + return x + + +# Conv1x1 + PReLU +class conv1x1_f(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(conv1x1_f, self).__init__() + self.conv = conv1x1(in_channels, out_channels, stride) + self.relu = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + return x + + +# Residual Block +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ResidualBlock, self).__init__() + self.conv1 = conv3x3(in_channels, out_channels) + self.relu = nn.PReLU() + self.conv2 = conv3x3(out_channels, out_channels) + + def forward(self, x): + out = self.conv1(x) + out = self.relu(out) + out = self.conv2(out) + return out + + +class ConditionalNet(nn.Module): + def __init__(self, f, rbn): + super(ConditionalNet, self).__init__() + self.rbn = rbn + self.convRec = conv3x3_f(1, f) + self.convPred = conv3x3_f(1, f) + self.convTemp = conv3x3_f(2, f) + self.convQp = conv3x3_f(1, f) + self.fuse = conv1x1_f(4 * f, f) + self.transitionH = conv3x3_f(f, f, 2) + self.backbone = nn.ModuleList([ResidualBlock(f, f)]) + for _ in range(self.rbn - 1): + self.backbone.append(ResidualBlock(f, f)) + self.last_layer = nn.Sequential( + nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=3, + stride=1, + padding=1), + nn.PReLU(), + nn.Conv2d( + in_channels=f, + out_channels=4, + kernel_size=3, + stride=1, + padding=1), + # nn.PixelShuffle(2) + ) + + def forward(self, rec, pred, forw, bacw, qp): + rec_f = self.convRec(rec) + pred_f = self.convPred(pred) + temp_f = self.convTemp(torch.cat((forw, bacw), 1)) + qp_f = self.convQp(qp) + xh = torch.cat((rec_f, pred_f, temp_f, qp_f), 1) + xh = self.fuse(xh) + x = self.transitionH(xh) + for i in range(self.rbn): + x = self.backbone[i](x) + x + # output + x = self.last_layer(x) + # x = x + rec + return x diff --git a/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/dataset.py b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d2899bf5cd2d07d896cdcd7f2397589c3c88b2d1 --- /dev/null +++ b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/dataset.py @@ -0,0 +1,151 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +from __future__ import print_function, division +import torch +import numpy as np +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms, utils +import json +import matplotlib.pyplot as plt +import torchvision.transforms.functional as F + + +class CnnlfDataset(Dataset): + """MultiType Tree partition prediction dataset.""" + + def __init__(self, data, transform=None): + """ + Args: + blk_size: block size. + data: dataset bin file. There should be a json description file in the same folder named + data.json + transform (callable, optional): Optional transform to be applied + on a sample. + """ + self.transform = transform + + with open(data+'.json') as file: + description = json.loads(file.read()) + + self.data_file = data + self.patch_size = description['patch_size'] + self.border_size = description['border_size'] + self.components = dict(zip(description['components'], range(len(description['components'])))) + self.nb_comp = len(self.components) + self.len = description['nb_patches'] + self.block_size = self.patch_size + 2 * self.border_size + self.block_volume = self.block_size * self.block_size * self.nb_comp + + def __len__(self): + return self.len + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + with open(self.data_file) as file: + block = np.fromfile(file, dtype='float32', count=self.block_volume, offset=self.block_volume * idx * 4).\ + reshape((self.block_size, self.block_size, self.nb_comp)) + + org = block[:, :, self.components['org_Y']] + rec = block[:, :, self.components['rec_before_dbf_Y']] + pred = block[:, :, self.components['pred_Y']] + forw = block[:, :, self.components['ref_list_0_Y']] + bacw = block[:, :, self.components['ref_list_1_Y']] + qp = block[:, :, self.components['qp_base']] + training_sample = {'org': org, 'rec': rec, 'pred': pred, 'forw': forw, 'bacw': bacw, 'qp': qp} + + if self.transform: + training_sample = self.transform(training_sample) + + return training_sample + + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + + def __call__(self, sample): + org, rec, pred, forw, bacw, qp = sample['org'], sample['rec'], sample['pred'], sample['forw'], sample['bacw'], sample['qp'] + + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + org = org[np.newaxis, :, :] + rec = rec[np.newaxis, :, :] + pred = pred[np.newaxis, :, :] + forw = forw[np.newaxis, :, :] + bacw = bacw[np.newaxis, :, :] + qp = qp[np.newaxis, :, :] + return {'org': torch.from_numpy(org), + 'rec': torch.from_numpy(rec), + 'pred': torch.from_numpy(pred), + 'forw': torch.from_numpy(forw), + 'bacw': torch.from_numpy(bacw), + 'qp': torch.from_numpy(qp)} + + +def preprocess(sample, device): + org, rec, pred = sample['org'].to(device), sample['rec'].to(device), sample['pred'].to(device) + forw, bacw, qp = sample['forw'].to(device), sample['bacw'].to(device), sample['qp'].to(device) + + if torch.rand(1) < 0.5: + org[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.hflip(org[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + rec[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.hflip(rec[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + pred[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.hflip(pred[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + forw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.hflip(forw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + bacw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.hflip(bacw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + if torch.rand(1) < 0.5: + org[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.vflip(org[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + rec[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.vflip(rec[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + pred[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.vflip(pred[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + forw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.vflip(forw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + bacw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :] = F.vflip(bacw[qp[:, 0, 0, 0] > 30.0/64.0, :, :, :]) + + return {'org': org, 'rec': rec, 'pred': pred, 'forw': forw, 'bacw': bacw, 'qp': qp} + + +def display_batch(batch): + fig, ax = plt.subplots(nrows=8, ncols=8) + i = 0 + for row in ax: + for col in row: + col.imshow(batch[i, 0, :, :], cmap='gray', vmin=0., vmax=1.) + i += 1 + plt.show() + + +def get_loader(data, batch_size, shuffle, num_workers=4): + dataset = CnnlfDataset(data=data, transform=transforms.Compose([ToTensor()])) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) + return data_loader diff --git a/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/main.py b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/main.py new file mode 100755 index 0000000000000000000000000000000000000000..db523c733da081b3cad0cca37de7ed5206cc3e56 --- /dev/null +++ b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/main.py @@ -0,0 +1,176 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import argparse +import torch +import torch.nn as nn +from net import ConditionalNet +from dataset import get_loader, display_batch, preprocess +from math import log10 + + +# learning policy +def adjust_learning_rate(optimizer, decay_rate): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + print('update learing rate') + for param_group in optimizer.param_groups: + param_group['lr'] *= decay_rate + + +def train(opt): + # Device configuration + if torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + # load training and validation data + train_loader = get_loader(data=opt.train_data, batch_size=opt.train_batch_size, shuffle=True, + num_workers=opt.num_workers) + if opt.validation_data: + validation_loader = get_loader(data=opt.validation_data, batch_size=opt.validation_batch_size, shuffle=False, + num_workers=opt.num_workers) + # Construct network + CnnlfNet = ConditionalNet(opt.feature_maps, opt.rbn) + + if torch.cuda.device_count() > 1: + print("Let's use", torch.cuda.device_count(), "GPUs!") + # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs + CnnlfNet = nn.DataParallel(CnnlfNet) + CnnlfNet.to(device) + + if opt.pretrained_model: + CnnlfNet.load_state_dict(torch.load(opt.pretrained_model, map_location=device)) + + optimizer = torch.optim.Adam(CnnlfNet.parameters(), lr=opt.lr, weight_decay=opt.weight_decay) + + # Loss and optimizer + criterion = nn.MSELoss() + if opt.MSE == 1: + print('Loss function: MSE') + criterionB = nn.MSELoss() + else: + print('Loss function: SAD') + criterionB = nn.L1Loss() + + # Train the model + step_per_epoch = len(train_loader) + for epoch in range(0, opt.epoches): + print('Epoch {}'.format(epoch)) + for i, sample_t in enumerate(train_loader): + global_step = epoch * step_per_epoch + i + 1 + if global_step == opt.decay_epoch1 * step_per_epoch or global_step == opt.decay_epoch2 * step_per_epoch: + adjust_learning_rate(optimizer, opt.decay_rate) + if global_step == opt.mse_epoch * step_per_epoch: + criterionB = nn.MSELoss() + adjust_learning_rate(optimizer, opt.decay_rate) + sample_t = preprocess(sample_t, device) + org, rec, pred = sample_t['org'], sample_t['rec'], sample_t['pred'] + forw, bacw, qp = sample_t['forw'], sample_t['bacw'], sample_t['qp'] + # Forward pass + outputs = CnnlfNet(rec, pred, forw, bacw, qp) + lossB = criterionB(outputs, org) + loss = criterion(outputs, org) + + # Backward and optimize + optimizer.zero_grad() + lossB.backward() + optimizer.step() + if global_step % opt.loss_interval == 0: + print('Global Step {}, Loss: {:.4f}'.format(global_step, loss.item())) + print('PSNR of the model on a batch: {}'.format(10 * log10(1 / loss.item()))) + if opt.validation_data and global_step % opt.validation_interval == 0: + # validate the model + with torch.no_grad(): + avg_psnr_ref = 0 + avg_psnr_val = 0 + _len = 0 + for sample_v in validation_loader: + sample_v = preprocess(sample_v, device) + org, rec, pred = sample_v['org'], sample_v['rec'], sample_v['pred'] + forw, bacw, qp = sample_v['forw'], sample_v['bacw'], sample_v['qp'] + qp = sample_v['qp'].to(device) + outputs = CnnlfNet(rec, pred, forw, bacw, qp) + loss_ref = criterion(rec, org) + loss_val = criterion(outputs, org) + if loss_ref.item() == 0: + continue + _len += 1 + avg_psnr_ref += 10 * log10(1 / loss_ref.item()) + avg_psnr_val += 10 * log10(1 / loss_val.item()) + print('PSNR of the anchor on validation: {}'.format(avg_psnr_ref / _len)) + print('PSNR of the model on validation: {}'.format(avg_psnr_val / _len)) + if (epoch + 1) % opt.checkpoint_interval == 0: + # Save the model checkpoint + torch.save(CnnlfNet.state_dict(), str(epoch+1) + '.ckpt') + + +if __name__ == '__main__': + # parse arguments + parser = argparse.ArgumentParser() + + # training/validation data + parser.add_argument('--num_workers', type=int, default=4, help='number of workers') + parser.add_argument('--train_batch_size', type=int, default=64, help='train batch size') + parser.add_argument('--validation_batch_size', type=int, default=64, help='validation batch size') + parser.add_argument('--train_data', type=str, default='luma_data.bin', help='training data') + parser.add_argument('--validation_data', type=str, default='', help='validation data') + + # network + parser.add_argument('--feature_maps', type=int, default=96, help='number of feature maps') + parser.add_argument('--rbn', type=int, default=8, help='number of residual blocks') + + # loss + parser.add_argument('--weight_decay', type=float, default=1e-8, help='weight decay') + parser.add_argument('--MSE', type=int, default=0, help='loss function, default=SAD') + + # optimizer configurations + parser.add_argument('--lr', type=float, default=1e-4, help='learning rate, default=0.0001') + parser.add_argument('--decay_epoch1', type=int, default=20, help='first milestone to decay lr') + parser.add_argument('--decay_epoch2', type=int, default=30, help='second milestone to decay lr') + parser.add_argument('--mse_epoch', type=int, default=40, help='switch to the mse loss and decrease learning rate at mse_epoch') + parser.add_argument('--decay_rate', type=float, default=0.1, help='the multiplier to decay lr') + parser.add_argument('--epoches', type=int, default=50, help='number of epoches to train for') + + # dump configurations + parser.add_argument('--loss_interval', type=int, default=1000, help='number of iteration to print loss') + parser.add_argument('--validation_interval', type=int, default=10000, help='number of iteration to validate') + parser.add_argument('--checkpoint_interval', type=int, default=10, help='number of epochs to save checkpoint') + + # finetune + parser.add_argument('--pretrained_model', type=str, default="", help='pretrained model') + + options = parser.parse_args() + print(options) + train(options) + diff --git a/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/net.py b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/net.py new file mode 100644 index 0000000000000000000000000000000000000000..03f6880d8fc82f9519ad17e21ebbf0f51ec028df --- /dev/null +++ b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/net.py @@ -0,0 +1,133 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import torch.nn as nn + + +def conv3x3(in_channels, out_channels, stride=1, padding=1): + return nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=stride, padding=padding) + + +def conv1x1(in_channels, out_channels, stride=1, padding=0): + return nn.Conv2d(in_channels, out_channels, kernel_size=1, + stride=stride, padding=padding) + + +# Conv3x3 + PReLU +class conv3x3_f(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(conv3x3_f, self).__init__() + self.conv = conv3x3(in_channels, out_channels, stride) + self.relu = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + return x + + +# Conv1x1 + PReLU +class conv1x1_f(nn.Module): + def __init__(self, in_channels, out_channels, stride=1): + super(conv1x1_f, self).__init__() + self.conv = conv1x1(in_channels, out_channels, stride) + self.relu = nn.PReLU() + + def forward(self, x): + x = self.conv(x) + x = self.relu(x) + return x + + +# Residual Block +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(ResidualBlock, self).__init__() + self.conv1 = conv3x3(in_channels, out_channels) + self.relu = nn.PReLU() + self.conv2 = conv3x3(out_channels, out_channels) + + def forward(self, x): + out = self.conv1(x) + out = self.relu(out) + out = self.conv2(out) + return out + + +class ConditionalNet(nn.Module): + def __init__(self, f, rbn): + super(ConditionalNet, self).__init__() + self.rbn = rbn + self.convRec = conv3x3_f(1, f) + self.convPred = conv3x3_f(1, f) + self.convTemp = conv3x3_f(2, f) + self.convQp = conv3x3_f(1, f) + self.fuse = conv1x1_f(4 * f, f) + self.transitionH = conv3x3_f(f, f, 2) + self.backbone = nn.ModuleList([ResidualBlock(f, f)]) + for _ in range(self.rbn - 1): + self.backbone.append(ResidualBlock(f, f)) + self.last_layer = nn.Sequential( + nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=3, + stride=1, + padding=1), + nn.PReLU(), + nn.Conv2d( + in_channels=f, + out_channels=4, + kernel_size=3, + stride=1, + padding=1), + nn.PixelShuffle(2) + ) + + def forward(self, rec, pred, forw, bacw, qp): + rec_f = self.convRec(rec) + pred_f = self.convPred(pred) + temp_f = self.convTemp(torch.cat((forw, bacw), 1)) + qp_f = self.convQp(qp) + xh = torch.cat((rec_f, pred_f, temp_f, qp_f), 1) + xh = self.fuse(xh) + x = self.transitionH(xh) + for i in range(self.rbn): + x = self.backbone[i](x) + x + # output + x = self.last_layer(x) + x = x + rec + return x diff --git a/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/train.sh b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/train.sh new file mode 100755 index 0000000000000000000000000000000000000000..53f5719c1aebe226ce4993bb13853d1aabbe4fb0 --- /dev/null +++ b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/Training/train.sh @@ -0,0 +1,35 @@ +# The copyright in this software is being made available under the BSD +# License, included below. This software may be subject to other third party +# and contributor rights, including patent rights, and no such rights are +# granted under this license. +# +# Copyright (c) 2010-2022, ITU/ISO/IEC +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +# be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. + +#!/bin/bash +python3 main.py + diff --git a/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/instruction.pdf b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/instruction.pdf new file mode 100644 index 0000000000000000000000000000000000000000..5b3626c1366eded884276de65c7f8860fe8f6177 Binary files /dev/null and b/training/training_scripts/Nn_Filtering_Set_1/Scripts/AdditionalLumaInter/instruction.pdf differ