From 9b4695e1e51a0d2ddc5854145a54fc70424d0752 Mon Sep 17 00:00:00 2001
From: "takeshi.tsukuba" <takeshi.tsukuba@sony.com>
Date: Thu, 6 May 2021 16:05:52 +0900
Subject: [PATCH] JVET-V0066: Encoder improvements to palette coding for high
 bit depth

---
 source/Lib/EncoderLib/IntraSearch.cpp | 61 +++------------------------
 source/Lib/EncoderLib/IntraSearch.h   |  4 --
 2 files changed, 6 insertions(+), 59 deletions(-)

diff --git a/source/Lib/EncoderLib/IntraSearch.cpp b/source/Lib/EncoderLib/IntraSearch.cpp
index c1014b851..42b61d638 100644
--- a/source/Lib/EncoderLib/IntraSearch.cpp
+++ b/source/Lib/EncoderLib/IntraSearch.cpp
@@ -68,8 +68,6 @@ IntraSearch::IntraSearch()
   {
     m_pSharedPredTransformSkip[ch] = nullptr;
   }
-  m_truncBinBits = nullptr;
-  m_escapeNumBins = nullptr;
   m_minErrorIndexMap = nullptr;
   for (unsigned i = 0; i < (MAXPLTSIZE + 1); i++)
   {
@@ -163,21 +161,6 @@ void IntraSearch::destroy()
   m_tmpStorageLCU.destroy();
   m_colorTransResiBuf.destroy();
   m_isInitialized = false;
-  if (m_truncBinBits != nullptr)
-  {
-    for (unsigned i = 0; i < m_symbolSize; i++)
-    {
-      delete[] m_truncBinBits[i];
-      m_truncBinBits[i] = nullptr;
-    }
-    delete[] m_truncBinBits;
-    m_truncBinBits = nullptr;
-  }
-  if (m_escapeNumBins != nullptr)
-  {
-    delete[] m_escapeNumBins;
-    m_escapeNumBins = nullptr;
-  }
   if (m_indexError[0] != nullptr)
   {
     for (unsigned i = 0; i < (MAXPLTSIZE + 1); i++)
@@ -310,20 +293,6 @@ void IntraSearch::init( EncCfg*        pcEncCfg,
   m_isInitialized = true;
   if (pcEncCfg->getPLTMode())
   {
-    m_symbolSize = (1 << bitDepthY); // pixel values are within [0, SymbolSize-1] with size SymbolSize
-    if (m_truncBinBits == nullptr)
-    {
-      m_truncBinBits = new uint16_t*[m_symbolSize];
-      for (unsigned i = 0; i < m_symbolSize; i++)
-      {
-        m_truncBinBits[i] = new uint16_t[m_symbolSize + 1];
-      }
-    }
-    if (m_escapeNumBins == nullptr)
-    {
-      m_escapeNumBins = new uint16_t[m_symbolSize];
-    }
-    initTBCTable(bitDepthY);
     if (m_indexError[0] == nullptr)
     {
       for (unsigned i = 0; i < (MAXPLTSIZE + 1); i++)
@@ -2207,7 +2176,7 @@ void IntraSearch::preCalcPLTIndexRD(CodingStructure& cs, Partitioner& partitione
       {
         if (lossless)
         {
-          rate += m_escapeNumBins[curPel[comp]];
+          rate += getEpExGolombNumBins(curPel[comp], 5);
         }
         else
         {
@@ -2220,7 +2189,7 @@ void IntraSearch::preCalcPLTIndexRD(CodingStructure& cs, Partitioner& partitione
           {
             error += tmpErr * tmpErr;
           }
-          rate += m_escapeNumBins[paPixelValue[comp]];   // encode quantized escape color
+          rate += getEpExGolombNumBins(paPixelValue[comp], 5);   // encode quantized escape color
         }
       }
       double rdCost = (double)error + m_pcRdCost->getLambda()*(double)rate;
@@ -2527,7 +2496,7 @@ double IntraSearch::rateDistOptPLT(
       rdCost = MAX_DOUBLE;
       return rdCost;
     }
-    rdCost += m_pcRdCost->getLambda()*(m_truncBinBits[(runIndex > refIndex) ? runIndex - 1 : runIndex][(scanPos == 0) ? (indexMaxValue + 1) : indexMaxValue] << SCALE_BITS);
+    rdCost += m_pcRdCost->getLambda()*(getTruncBinBits((runIndex > refIndex) ? runIndex - 1 : runIndex, (scanPos == 0) ? (indexMaxValue + 1) : indexMaxValue)  << SCALE_BITS);
   }
   rdCost += m_indexError[runIndex][m_scanOrder[scanPos].idx] * (1 << SCALE_BITS);
   if (scanPos > 0)
@@ -2545,6 +2514,7 @@ double IntraSearch::rateDistOptPLT(
   }
   return rdCost;
 }
+
 uint32_t IntraSearch::getEpExGolombNumBins(uint32_t symbol, uint32_t count)
 {
   uint32_t numBins = 0;
@@ -2596,26 +2566,6 @@ uint32_t IntraSearch::getTruncBinBits(uint32_t symbol, uint32_t maxSymbol)
   return idxCodeBit;
 }
 
-void IntraSearch::initTBCTable(int bitDepth)
-{
-  for (uint32_t i = 0; i < m_symbolSize; i++)
-  {
-    memset(m_truncBinBits[i], 0, sizeof(uint16_t)*(m_symbolSize + 1));
-  }
-  for (uint32_t i = 0; i < (m_symbolSize + 1); i++)
-  {
-    for (uint32_t j = 0; j < i; j++)
-    {
-      m_truncBinBits[j][i] = getTruncBinBits(j, i);
-    }
-  }
-  memset(m_escapeNumBins, 0, sizeof(uint16_t)*m_symbolSize);
-  for (uint32_t i = 0; i < m_symbolSize; i++)
-  {
-    m_escapeNumBins[i] = getEpExGolombNumBins(i, 5);
-  }
-}
-
 void IntraSearch::calcPixelPred(CodingStructure& cs, Partitioner& partitioner, uint32_t yPos, uint32_t xPos, ComponentID compBegin, uint32_t numComp)
 {
   CodingUnit    &cu = *cs.getCU(partitioner.chType);
@@ -2740,9 +2690,10 @@ void IntraSearch::derivePLTLossy(CodingStructure& cs, Partitioner& partitioner,
 
   TransformUnit &tu = *cs.getTU(partitioner.chType);
   QpParam cQP(tu, compBegin);
-  int qp = cQP.Qp(true) - 12;
+  int qp = cQP.Qp(true) - 6*(channelBitDepth_L - 8);
   qp = (qp < 0) ? 0 : ((qp > 56) ? 56 : qp);
   int errorLimit = g_paletteQuant[qp];
+
   if (lossless)
   {
     errorLimit = 0;
diff --git a/source/Lib/EncoderLib/IntraSearch.h b/source/Lib/EncoderLib/IntraSearch.h
index 3133ec800..67253de40 100644
--- a/source/Lib/EncoderLib/IntraSearch.h
+++ b/source/Lib/EncoderLib/IntraSearch.h
@@ -385,9 +385,6 @@ protected:
   CtxCache*       m_CtxCache;
 
   bool            m_isInitialized;
-  uint32_t        m_symbolSize;
-  uint16_t**      m_truncBinBits;
-  uint16_t*       m_escapeNumBins;
   bool            m_bestEscape;
   double*         m_indexError[MAXPLTSIZE + 1];
   uint8_t*        m_minErrorIndexMap; // store the best index in terms of distortion for each pixel
@@ -481,7 +478,6 @@ protected:
   void     deriveIndexMap         (CodingStructure& cs, Partitioner& partitioner, ComponentID compBegin, uint32_t numComp, PLTScanMode pltScanMode, double& dCost, bool* idxExist);
   bool     deriveSubblockIndexMap(CodingStructure& cs, Partitioner& partitioner, ComponentID compBegin, PLTScanMode pltScanMode, int minSubPos, int maxSubPos, const BinFracBits& fracBitsPltRunType, const BinFracBits* fracBitsPltIndexINDEX, const BinFracBits* fracBitsPltIndexCOPY, const double minCost, bool useRotate);
   double   rateDistOptPLT         (bool RunType, uint8_t RunIndex, bool prevRunType, uint8_t prevRunIndex, uint8_t aboveRunIndex, bool& prevCodedRunType, int& prevCodedRunPos, int scanPos, uint32_t width, int dist, int indexMaxValue, const BinFracBits* IndexfracBits, const BinFracBits& TypefracBits);
-  void     initTBCTable           (int bitDepth);
   uint32_t getTruncBinBits        (uint32_t symbol, uint32_t maxSymbol);
   uint32_t getEpExGolombNumBins   (uint32_t symbol, uint32_t count);
   void xGetNextISPMode                    ( ModeInfo& modeInfo, const ModeInfo* lastMode, const Size cuSize );
-- 
GitLab