diff --git a/source/Lib/CommonLib/NNInference.h b/source/Lib/CommonLib/NNInference.h index ad72de96a4d53af33f6512e5c5c140ae0006423e..257584619e1ba4c6da96ab1dc841cd6d7d9d517d 100644 --- a/source/Lib/CommonLib/NNInference.h +++ b/source/Lib/CommonLib/NNInference.h @@ -132,7 +132,7 @@ public: #endif } template<typename T> - static void prepareInputs (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int globalQp, int localQp, bool inter, std::vector<InputData> listInputData) + static void prepareInputs (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int globalQp, int localQp, int sliceType, std::vector<InputData> listInputData) { for (auto &inputData : listInputData) { @@ -157,7 +157,7 @@ public: fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], localQp, inputData.luma, inputData.scale, inputData.shift); break; case NN_INPUT_SLICE_TYPE: - fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], inter ? 1 : 0, inputData.luma, inputData.scale, inputData.shift); + fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], sliceType, inputData.luma, inputData.scale, inputData.shift); break; default: THROW("invalid input data"); @@ -170,7 +170,7 @@ public: { if (!model.apply(inputs)) { - cerr << "[ERROR] issue during luma model inference" << endl; + cerr << "[ERROR] issue during inference" << endl; exit(-1); } }