diff --git a/source/Lib/CommonLib/NNInference.h b/source/Lib/CommonLib/NNInference.h new file mode 100644 index 0000000000000000000000000000000000000000..ad72de96a4d53af33f6512e5c5c140ae0006423e --- /dev/null +++ b/source/Lib/CommonLib/NNInference.h @@ -0,0 +1,180 @@ +/** \file NNInference.h + \brief neural network-based inference class (header) +*/ + +#ifndef __NNINFERENCE__ +#define __NNINFERENCE__ + +#include "CommonDef.h" +#include "Unit.h" +#include "Picture.h" +#include "Reshape.h" +#include <sadl/model.h> +using namespace std; +//! \ingroup CommonLib +//! \{ + +struct InputData { + NNInputType cnnlfInputType; + int index; + double scale; + int shift; + bool luma; + bool chroma; +}; + +class NNInference +{ +public: + NNInference(); + template<typename T> + static void fillInputFromBuf (Picture* pic, UnitArea inferArea, sadl::Tensor<T> &input, PelUnitBuf buf, bool luma, bool chroma, double scale, int shift) + { + PelBuf bufY, bufCb, bufCr; + + if (luma) + { + bufY = buf.get(COMPONENT_Y); + } + if (chroma) + { + bufCb = buf.get(COMPONENT_Cb); + bufCr = buf.get(COMPONENT_Cr); + } + + int hor, ver; + if (luma) + { + hor = inferArea.lwidth(); + ver = inferArea.lheight(); + } + else + { + hor = inferArea.lwidth() >> 1; + ver = inferArea.lheight() >> 1; + } + #if NN_FIXED_POINT_IMPLEMENTATION + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + if (luma && !chroma) + { + input(0, yy, xx, 0) = bufY.at(xx, yy) << shift; + } + else if (!luma && chroma) + { + input(0, yy, xx, 0) = bufCb.at(xx, yy) << shift; + input(0, yy, xx, 1) = bufCr.at(xx, yy) << shift; + } + else if (luma && chroma) + { + input(0, yy, xx, 0) = bufY.at(xx, yy) << shift; + input(0, yy, xx, 1) = bufCb.at(xx >> 1, yy >> 1) << shift; + input(0, yy, xx, 2) = bufCr.at(xx >> 1, yy >> 1) << shift; + } + } + } + #else + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + if (luma && !chroma) + { + input(0, yy, xx, 0) = bufY.at(xx, yy) / scale; + } + else if (!luma && chroma) + { + input(0, yy, xx, 0) = bufCb.at(xx, yy) / scale; + input(0, yy, xx, 1) = bufCr.at(xx, yy) / scale; + } + else if (luma && chroma) + { + input(0, yy, xx, 0) = bufY.at(xx, yy) / scale; + input(0, yy, xx, 1) = bufCb.at(xx >> 1, yy >> 1) / scale; + input(0, yy, xx, 2) = bufCr.at(xx >> 1, yy >> 1) / scale; + } + } + } + #endif + } + template<typename T> + static void fillInputFromConstant (Picture* pic, UnitArea inferArea, sadl::Tensor<T> &input, int c, bool luma, double scale, int shift) + { + int hor, ver; + if (luma) + { + hor = inferArea.lwidth(); + ver = inferArea.lheight(); + } + else + { + hor = inferArea.lwidth() >> 1; + ver = inferArea.lheight() >> 1; + } + #if NN_FIXED_POINT_IMPLEMENTATION + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + input(0, yy, xx, 0) = c << shift; + } + } + #else + for (int yy = 0; yy < ver; yy++) + { + for (int xx = 0; xx < hor; xx++) + { + input(0, yy, xx, 0) = c / scale; + } + } + #endif + } + template<typename T> + static void prepareInputs (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int globalQp, int localQp, bool inter, std::vector<InputData> listInputData) + { + for (auto &inputData : listInputData) + { + switch (inputData.cnnlfInputType) + { + case NN_INPUT_REC: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getRecBeforeDbfBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift); + break; + case NN_INPUT_PRED: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getPredBufCustom(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift); + break; + case NN_INPUT_PARTITION: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getCuAverageBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift); + break; + case NN_INPUT_BS: + fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getBsMapBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift); + break; + case NN_INPUT_GLOBAL_QP: + fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], globalQp, inputData.luma, inputData.scale, inputData.shift); + break; + case NN_INPUT_LOCAL_QP: + fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], localQp, inputData.luma, inputData.scale, inputData.shift); + break; + case NN_INPUT_SLICE_TYPE: + fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], inter ? 1 : 0, inputData.luma, inputData.scale, inputData.shift); + break; + default: + THROW("invalid input data"); + break; + } + } + } + template<typename T> + static void infer(sadl::Model<T> &model, vector<sadl::Tensor<T>> &inputs) + { + if (!model.apply(inputs)) + { + cerr << "[ERROR] issue during luma model inference" << endl; + exit(-1); + } + } +}; +//! \} +#endif + diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index e976ee3fcc64a19e7ebf3f8aa38ae62cefaf6331..aadb4a4281627dad0327599558e63874e16b9d59 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -51,6 +51,8 @@ #include <cassert> // clang-format off +#define NN_COMMON_API 1 + #define NNVC_INFO_ENCODER 1 // add some info in encoder logs necessary to extract data #define NNVC_DUMP_DATA 1 @@ -433,6 +435,20 @@ enum ComponentID MAX_NUM_TBLOCKS = MAX_NUM_COMPONENT }; +#if NN_COMMON_API +enum NNInputType +{ + NN_INPUT_REC = 0, + NN_INPUT_PRED = 1, + NN_INPUT_PARTITION = 2, + NN_INPUT_BS = 3, + NN_INPUT_GLOBAL_QP = 4, + NN_INPUT_LOCAL_QP = 5, + NN_INPUT_SLICE_TYPE = 6, + MAX_NUM_NN_INPUT = 7 +}; +#endif + #define MAP_CHROMA(c) (ComponentID(c)) enum InputColourSpaceConversion // defined in terms of conversion prior to input of encoder.