Skip to content
Snippets Groups Projects
Commit 3fda99ff authored by Yue Li
Browse files

add a small common API for neural network inference

parent fd901837
No related branches found
No related tags found
No related merge requests found
/** \file NNInference.h
\brief neural network-based inference class (header)
*/
#ifndef __NNINFERENCE__
#define __NNINFERENCE__
#include "CommonDef.h"
#include "Unit.h"
#include "Picture.h"
#include "Reshape.h"
#include <sadl/model.h>
using namespace std;
//! \ingroup CommonLib
//! \{
//! Description of one network input, consumed by NNInference::prepareInputs().
struct InputData {
NNInputType cnnlfInputType; //!< which data source fills the tensor (selects the case in prepareInputs())
int index; //!< position of the destination tensor in the `inputs` vector
double scale; //!< divisor applied to every sample when NN_FIXED_POINT_IMPLEMENTATION is off
int shift; //!< left shift applied to every sample when NN_FIXED_POINT_IMPLEMENTATION is on
bool luma; //!< true: the luma plane is used and the tensor covers the full luma size of the area
bool chroma; //!< true: the Cb and Cr planes are used (only read for buffer-based inputs)
};
class NNInference
{
public:
NNInference();
template<typename T>
static void fillInputFromBuf (Picture* pic, UnitArea inferArea, sadl::Tensor<T> &input, PelUnitBuf buf, bool luma, bool chroma, double scale, int shift)
{
PelBuf bufY, bufCb, bufCr;
if (luma)
{
bufY = buf.get(COMPONENT_Y);
}
if (chroma)
{
bufCb = buf.get(COMPONENT_Cb);
bufCr = buf.get(COMPONENT_Cr);
}
int hor, ver;
if (luma)
{
hor = inferArea.lwidth();
ver = inferArea.lheight();
}
else
{
hor = inferArea.lwidth() >> 1;
ver = inferArea.lheight() >> 1;
}
#if NN_FIXED_POINT_IMPLEMENTATION
for (int yy = 0; yy < ver; yy++)
{
for (int xx = 0; xx < hor; xx++)
{
if (luma && !chroma)
{
input(0, yy, xx, 0) = bufY.at(xx, yy) << shift;
}
else if (!luma && chroma)
{
input(0, yy, xx, 0) = bufCb.at(xx, yy) << shift;
input(0, yy, xx, 1) = bufCr.at(xx, yy) << shift;
}
else if (luma && chroma)
{
input(0, yy, xx, 0) = bufY.at(xx, yy) << shift;
input(0, yy, xx, 1) = bufCb.at(xx >> 1, yy >> 1) << shift;
input(0, yy, xx, 2) = bufCr.at(xx >> 1, yy >> 1) << shift;
}
}
}
#else
for (int yy = 0; yy < ver; yy++)
{
for (int xx = 0; xx < hor; xx++)
{
if (luma && !chroma)
{
input(0, yy, xx, 0) = bufY.at(xx, yy) / scale;
}
else if (!luma && chroma)
{
input(0, yy, xx, 0) = bufCb.at(xx, yy) / scale;
input(0, yy, xx, 1) = bufCr.at(xx, yy) / scale;
}
else if (luma && chroma)
{
input(0, yy, xx, 0) = bufY.at(xx, yy) / scale;
input(0, yy, xx, 1) = bufCb.at(xx >> 1, yy >> 1) / scale;
input(0, yy, xx, 2) = bufCr.at(xx >> 1, yy >> 1) / scale;
}
}
}
#endif
}
template<typename T>
static void fillInputFromConstant (Picture* pic, UnitArea inferArea, sadl::Tensor<T> &input, int c, bool luma, double scale, int shift)
{
int hor, ver;
if (luma)
{
hor = inferArea.lwidth();
ver = inferArea.lheight();
}
else
{
hor = inferArea.lwidth() >> 1;
ver = inferArea.lheight() >> 1;
}
#if NN_FIXED_POINT_IMPLEMENTATION
for (int yy = 0; yy < ver; yy++)
{
for (int xx = 0; xx < hor; xx++)
{
input(0, yy, xx, 0) = c << shift;
}
}
#else
for (int yy = 0; yy < ver; yy++)
{
for (int xx = 0; xx < hor; xx++)
{
input(0, yy, xx, 0) = c / scale;
}
}
#endif
}
template<typename T>
static void prepareInputs (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int globalQp, int localQp, bool inter, std::vector<InputData> listInputData)
{
for (auto &inputData : listInputData)
{
switch (inputData.cnnlfInputType)
{
case NN_INPUT_REC:
fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getRecBeforeDbfBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift);
break;
case NN_INPUT_PRED:
fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getPredBufCustom(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift);
break;
case NN_INPUT_PARTITION:
fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getCuAverageBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift);
break;
case NN_INPUT_BS:
fillInputFromBuf<T>(pic, inferArea, inputs[inputData.index], pic->getBsMapBuf(inferArea), inputData.luma, inputData.chroma, inputData.scale, inputData.shift);
break;
case NN_INPUT_GLOBAL_QP:
fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], globalQp, inputData.luma, inputData.scale, inputData.shift);
break;
case NN_INPUT_LOCAL_QP:
fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], localQp, inputData.luma, inputData.scale, inputData.shift);
break;
case NN_INPUT_SLICE_TYPE:
fillInputFromConstant<T>(pic, inferArea, inputs[inputData.index], inter ? 1 : 0, inputData.luma, inputData.scale, inputData.shift);
break;
default:
THROW("invalid input data");
break;
}
}
}
template<typename T>
static void infer(sadl::Model<T> &model, vector<sadl::Tensor<T>> &inputs)
{
if (!model.apply(inputs))
{
cerr << "[ERROR] issue during luma model inference" << endl;
exit(-1);
}
}
};
//! \}
#endif
......@@ -51,6 +51,8 @@
#include <cassert>
// clang-format off
// Enables the declarations guarded by NN_COMMON_API, such as the NNInputType enum
// used by the common neural-network inference helpers.
#define NN_COMMON_API 1
#define NNVC_INFO_ENCODER 1 // add some info in encoder logs necessary to extract data
// NOTE(review): presumably enables dumping of data for offline use - confirm against its users.
#define NNVC_DUMP_DATA 1
......@@ -433,6 +435,20 @@ enum ComponentID
MAX_NUM_TBLOCKS = MAX_NUM_COMPONENT
};
#if NN_COMMON_API
//! Kind of data fed to one neural-network input tensor (see NNInference::prepareInputs()).
enum NNInputType
{
  NN_INPUT_REC        = 0,   //!< samples from Picture::getRecBeforeDbfBuf()
  NN_INPUT_PRED       = 1,   //!< samples from Picture::getPredBufCustom()
  NN_INPUT_PARTITION  = 2,   //!< samples from Picture::getCuAverageBuf()
  NN_INPUT_BS         = 3,   //!< samples from Picture::getBsMapBuf()
  NN_INPUT_GLOBAL_QP  = 4,   //!< constant plane holding the global QP
  NN_INPUT_LOCAL_QP   = 5,   //!< constant plane holding the local QP
  NN_INPUT_SLICE_TYPE = 6,   //!< constant plane: 1 for inter, 0 otherwise
  MAX_NUM_NN_INPUT    = 7    //!< number of input types above
};
#endif
// Converts the value c to a ComponentID with a function-style cast.
#define MAP_CHROMA(c) (ComponentID(c))
enum InputColourSpaceConversion // defined in terms of conversion prior to input of encoder.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment