diff --git a/source/Lib/CommonLib/CommonDef.h b/source/Lib/CommonLib/CommonDef.h index b2d3086e60b15d39d7e1da3e07b4b4e3a5f2ddd1..ed69ac29de581ec9ae1da35aca239b11cb35273b 100644 --- a/source/Lib/CommonLib/CommonDef.h +++ b/source/Lib/CommonLib/CommonDef.h @@ -195,9 +195,9 @@ static const int QP_OFFSET_NUM = 2; static const int NNQPOFFSET[QP_OFFSET_NUM] = { -5, 5 }; #endif #endif -#if NN_FILTERING_SET_1 && NN_FIXED_POINT_IMPLEMENTATION -static const int NN_INPUT_PRECISION= 13; -static const int NN_OUPUTPUT_PRECISION= 13; +#if NN_FILTERING_SET_1 +static const int NN_INPUT_PRECISION= 13; +static const int NN_OUTPUT_PRECISION= 13; #endif diff --git a/source/Lib/CommonLib/NNFilterSet1.cpp b/source/Lib/CommonLib/NNFilterSet1.cpp index 03c3da2f3575133b7069d40961c56e12ff5dda4f..dec53e08c6c4e9fe9b90603da2a29d24e975f80e 100644 --- a/source/Lib/CommonLib/NNFilterSet1.cpp +++ b/source/Lib/CommonLib/NNFilterSet1.cpp @@ -227,8 +227,8 @@ void extractOutputsLuma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, P #if NN_FIXED_POINT_IMPLEMENTATION int log2InputScale = 10; int log2OutputScale = 10; - int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale; - int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale; + int shiftInput = NN_OUTPUT_PRECISION - log2InputScale; + int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale; int offset = (1 << shiftOutput) / 2; #else double inputScale = 1024; @@ -277,10 +277,10 @@ template<typename T> void extractOutputsChroma (Picture* pic, sadl::Model<T> &m, PelStorage& tempBuf, PelStorage& tempScaledBuf, UnitArea inferArea, int extLeft, int extRight, int extTop, int extBottom, bool inter) { #if NN_FIXED_POINT_IMPLEMENTATION -int log2InputScale = 10; + int log2InputScale = 10; int log2OutputScale = 10; - int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale; - int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale; + int shiftInput = NN_OUTPUT_PRECISION - log2InputScale; + int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale; int offset = (1 << shiftOutput) / 2; #else double inputScale = 1024; @@ -338,6 +338,11 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL { //at::init_num_threads(); // use all available threads + double inputScale = 1024; + double qpScale = 64; + int log2InputScale = 10; + int log2QpScale = 6; + const int border_to_skip = 0; if (border_to_skip>0) sadl::Tensor<float>::skip_border = true; @@ -360,10 +365,10 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL std::vector<InputData> listInputData; if (inter) { - InputData inputRec = {NN_INPUT_REC, 0, 1024, 3, true, false}; - InputData inputPred = {NN_INPUT_PRED, 1, 1024, 3, true, false}; - InputData inputBs = {NN_INPUT_BS, 2, 1024, 3, true, false}; - InputData inputQp = {NN_INPUT_LOCAL_QP, 3, 64, 7, true, false}; + InputData inputRec = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputPred = {NN_INPUT_PRED, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputBs = {NN_INPUT_BS, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 3, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false}; listInputData.push_back(inputRec); listInputData.push_back(inputPred); listInputData.push_back(inputBs); @@ -371,15 +376,15 @@ void NNFilterSet1::cnnFilterLumaBlock(Picture* pic, UnitArea inferArea, int extL } else { - InputData inputRec = {NN_INPUT_REC, 0, 1024, 3, true, false}; - InputData inputPred = {NN_INPUT_PRED, 1, 1024, 3, true, false}; + InputData inputRec = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputPred = {NN_INPUT_PRED, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; #if JVET_AB0053_NO_PART_NO_ATTN - InputData inputBs = {NN_INPUT_BS, 2, 1024, 3, true, false}; - InputData inputQp = {NN_INPUT_LOCAL_QP, 3, 64, 7, true, false}; + InputData inputBs = {NN_INPUT_BS, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 3, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false}; #else - InputData inputPartition = {NN_INPUT_PARTITION, 2, 1024, 3, true, false}; - InputData inputBs = {NN_INPUT_BS, 3, 1024, 3, true, false}; - InputData inputQp = {NN_INPUT_LOCAL_QP, 4, 64, 7, true, false}; + InputData inputPartition = {NN_INPUT_PARTITION, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputBs = {NN_INPUT_BS, 3, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 4, qpScale, NN_INPUT_PRECISION - log2QpScale, true, false}; #endif listInputData.push_back(inputRec); listInputData.push_back(inputPred); @@ -404,6 +409,11 @@ void NNFilterSet1::cnnFilterChromaBlock(Picture* pic, UnitArea inferArea, int ex //at::init_num_threads(); + double inputScale = 1024; + double qpScale = 64; + int log2InputScale = 10; + int log2QpScale = 6; + const int border_to_skip = 0; if (border_to_skip>0) sadl::Tensor<float>::skip_border = true; @@ -427,11 +437,11 @@ void NNFilterSet1::cnnFilterChromaBlock(Picture* pic, UnitArea inferArea, int ex if (inter) { - InputData inputRecCrossComponent = {NN_INPUT_REC, 0, 1024, 3, true, false}; - InputData inputRec = {NN_INPUT_REC, 1, 1024, 3, false, true}; - InputData inputPred = {NN_INPUT_PRED, 2, 1024, 3, false, true}; - InputData inputBs = {NN_INPUT_BS, 3, 1024, 3, false, true}; - InputData inputQp = {NN_INPUT_LOCAL_QP, 4, 64, 7, false, true}; + InputData inputRecCrossComponent = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputRec = {NN_INPUT_REC, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true}; + InputData inputPred = {NN_INPUT_PRED, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true}; + InputData inputBs = {NN_INPUT_BS, 3, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 4, qpScale, NN_INPUT_PRECISION - log2QpScale, false, true}; listInputData.push_back(inputRecCrossComponent); listInputData.push_back(inputRec); listInputData.push_back(inputPred); @@ -440,12 +450,12 @@ void NNFilterSet1::cnnFilterChromaBlock(Picture* pic, UnitArea inferArea, int ex } else { - InputData inputRecCrossComponent = {NN_INPUT_REC, 0, 1024, 3, true, false}; - InputData inputRec = {NN_INPUT_REC, 1, 1024, 3, false, true}; - InputData inputPred = {NN_INPUT_PRED, 2, 1024, 3, false, true}; - InputData inputPartition = {NN_INPUT_PARTITION, 3, 1024, 3, false, true}; - InputData inputBs = {NN_INPUT_BS, 4, 1024, 3, false, true}; - InputData inputQp = {NN_INPUT_LOCAL_QP, 5, 64, 7, false, true}; + InputData inputRecCrossComponent = {NN_INPUT_REC, 0, inputScale, NN_INPUT_PRECISION - log2InputScale, true, false}; + InputData inputRec = {NN_INPUT_REC, 1, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true}; + InputData inputPred = {NN_INPUT_PRED, 2, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true}; + InputData inputPartition = {NN_INPUT_PARTITION, 3, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true}; + InputData inputBs = {NN_INPUT_BS, 4, inputScale, NN_INPUT_PRECISION - log2InputScale, false, true}; + InputData inputQp = {NN_INPUT_LOCAL_QP, 5, qpScale, NN_INPUT_PRECISION - log2QpScale, false, true}; listInputData.push_back(inputRecCrossComponent); listInputData.push_back(inputRec); listInputData.push_back(inputPred); diff --git a/source/Lib/EncoderLib/EncNNFilterSet1.cpp b/source/Lib/EncoderLib/EncNNFilterSet1.cpp index 72c4b9d77060c971b0b22c7f481d5ef4d53020fb..c6c8100856feb6c6fd771a70ed80a81629d6aa5f 100644 --- a/source/Lib/EncoderLib/EncNNFilterSet1.cpp +++ b/source/Lib/EncoderLib/EncNNFilterSet1.cpp @@ -273,8 +273,8 @@ void EncNNFilterSet1::extractOutputsLumaRd (Picture* pic, sadl::Model<T> &m, Pel #if NN_FIXED_POINT_IMPLEMENTATION int log2InputScale = 10; int log2OutputScale = 10; - int shiftInput = NN_OUPUTPUT_PRECISION - log2InputScale; - int shiftOutput = NN_OUPUTPUT_PRECISION - log2OutputScale; + int shiftInput = NN_OUTPUT_PRECISION - log2InputScale; + int shiftOutput = NN_OUTPUT_PRECISION - log2OutputScale; int offset = (1 << shiftOutput) / 2; #else double inputScale = 1024;