From 6dd3f6bd0295052509e2c9a1ef6ab1478ebd5c51 Mon Sep 17 00:00:00 2001
From: jiechen <jiechen.cj@alibaba-inc.com>
Date: Thu, 30 Jan 2020 17:31:40 +0800
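Subject: [PATCH] Enable palette mode for non-4:4:4 color formats

All changes are guarded by the new JVET_Q0504_PLT_NON444 macro (TypeDef.h).
sps_palette_enabled_flag is now read and written regardless of the chroma
format, and CodingUnit::isLocalSepTree() is added to identify separate
trees that do not come from a dual I-tree.
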
---
 source/App/EncoderApp/EncAppCfg.cpp      |   2 +
 source/Lib/CommonLib/IntraPrediction.cpp |  56 +++++
 source/Lib/CommonLib/Slice.h             |   4 +
 source/Lib/CommonLib/TypeDef.h           |   2 +
 source/Lib/CommonLib/Unit.cpp            |   7 +
 source/Lib/CommonLib/Unit.h              |   3 +
 source/Lib/DecoderLib/CABACReader.cpp    |  68 ++++++
 source/Lib/DecoderLib/DecCu.cpp          |  11 +
 source/Lib/DecoderLib/VLCReader.cpp      |   4 +
 source/Lib/EncoderLib/CABACWriter.cpp    |  11 +
 source/Lib/EncoderLib/EncCu.cpp          |  43 ++++
 source/Lib/EncoderLib/EncModeCtrl.cpp    |  16 ++
 source/Lib/EncoderLib/IntraSearch.cpp    | 252 +++++++++++++++++++++++
 source/Lib/EncoderLib/IntraSearch.h      |  56 +++++
 source/Lib/EncoderLib/VLCWriter.cpp      |   4 +
 15 files changed, 539 insertions(+)

diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp
index 8baddf032..04d595d3a 100644
--- a/source/App/EncoderApp/EncAppCfg.cpp
+++ b/source/App/EncoderApp/EncAppCfg.cpp
@@ -3768,7 +3768,9 @@ void EncAppCfg::xPrintParameter()
   m_useColorTrans = (m_chromaFormatIDC == CHROMA_444 && m_costMode != COST_LOSSLESS_CODING) ? m_useColorTrans : 0u;
 #endif
   msg(VERBOSE, "ACT:%d ", m_useColorTrans);
+#if !JVET_Q0504_PLT_NON444
     m_PLTMode = ( m_chromaFormatIDC == CHROMA_444) ? m_PLTMode : 0u;
+#endif
     msg(VERBOSE, "PLT:%d ", m_PLTMode);
     msg(VERBOSE, "IBC:%d ", m_IBCMode);
   msg( VERBOSE, "HashME:%d ", m_HashME );
diff --git a/source/Lib/CommonLib/IntraPrediction.cpp b/source/Lib/CommonLib/IntraPrediction.cpp
index 776a332ba..be475677f 100644
--- a/source/Lib/CommonLib/IntraPrediction.cpp
+++ b/source/Lib/CommonLib/IntraPrediction.cpp
@@ -1927,6 +1927,10 @@ void IntraPrediction::reorderPLT(CodingStructure& cs, Partitioner& partitioner,
 
     for (curidx = 0; curidx < cu.curPLTSize[compBegin]; curidx++)
     {
+#if JVET_Q0504_PLT_NON444
+      if( curPLTpred[curidx] )
+        continue;
+#endif
       bool matchTmp = true;
       for (int comp = compBegin; comp < (compBegin + numComp); comp++)
       {
@@ -1943,10 +1947,25 @@ void IntraPrediction::reorderPLT(CodingStructure& cs, Partitioner& partitioner,
     {
       cu.reuseflag[compBegin][predidx] = true;
       curPLTpred[curidx] = true;
+#if JVET_Q0504_PLT_NON444
+      if( cu.isLocalSepTree() )
+      {
+        cu.reuseflag[COMPONENT_Y][predidx] = true;
+        for( int comp = COMPONENT_Y; comp < MAX_NUM_COMPONENT; comp++ )
+        {
+          curPLTtmp[comp][reusePLTSizetmp] = cs.prevPLT.curPLT[comp][predidx];
+        }
+      }
+      else
+      {
+#endif
       for (int comp = compBegin; comp < (compBegin + numComp); comp++)
       {
         curPLTtmp[comp][reusePLTSizetmp] = cs.prevPLT.curPLT[comp][predidx];
       }
+#if JVET_Q0504_PLT_NON444
+      }
+#endif
       reusePLTSizetmp++;
       pltSizetmp++;
     }
@@ -1956,20 +1975,57 @@ void IntraPrediction::reorderPLT(CodingStructure& cs, Partitioner& partitioner,
   {
     if (!curPLTpred[curidx])
     {
+#if JVET_Q0504_PLT_NON444
+      if( cu.isLocalSepTree() )
+      {
+        for( int comp = compBegin; comp < (compBegin + numComp); comp++ )
+        {
+          curPLTtmp[comp][pltSizetmp] = cu.curPLT[comp][curidx];
+        }
+        if( isLuma(partitioner.chType) )
+        {
+          curPLTtmp[COMPONENT_Cb][pltSizetmp] = 1 << (cs.sps->getBitDepth(CHANNEL_TYPE_CHROMA) - 1);
+          curPLTtmp[COMPONENT_Cr][pltSizetmp] = 1 << (cs.sps->getBitDepth(CHANNEL_TYPE_CHROMA) - 1);
+        }
+        else
+        {
+          curPLTtmp[COMPONENT_Y][pltSizetmp] = 1 << (cs.sps->getBitDepth(CHANNEL_TYPE_LUMA) - 1);
+        }
+      }
+      else
+      {
+#endif
       for (int comp = compBegin; comp < (compBegin + numComp); comp++)
       {
         curPLTtmp[comp][pltSizetmp] = cu.curPLT[comp][curidx];
       }
+#if JVET_Q0504_PLT_NON444
+      }
+#endif
       pltSizetmp++;
     }
   }
   assert(pltSizetmp == cu.curPLTSize[compBegin]);
   for (int curidx = 0; curidx < cu.curPLTSize[compBegin]; curidx++)
   {
+#if JVET_Q0504_PLT_NON444
+    if( cu.isLocalSepTree() )
+    {
+      for( int comp = COMPONENT_Y; comp < MAX_NUM_COMPONENT; comp++ )
+      {
+        cu.curPLT[comp][curidx] = curPLTtmp[comp][curidx];
+      }
+    }
+    else
+    {
+#endif
     for (int comp = compBegin; comp < (compBegin + numComp); comp++)
     {
       cu.curPLT[comp][curidx] = curPLTtmp[comp][curidx];
     }
+#if JVET_Q0504_PLT_NON444
+    }
+#endif
   }
 }
 //! \}
diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h
index 4fb4fb24d..0700d0ecb 100644
--- a/source/Lib/CommonLib/Slice.h
+++ b/source/Lib/CommonLib/Slice.h
@@ -2659,7 +2659,11 @@ protected:
   Picture*              xGetRefPic( PicList& rcListPic, int poc, const int layerId );
   Picture*              xGetLongTermRefPic( PicList& rcListPic, int poc, bool pocHasMsb, const int layerId );
 public:
+#if JVET_Q0504_PLT_NON444
+  std::unordered_map< bool, std::unordered_map< Position, std::unordered_map< Size, double> > > m_mapPltCost;
+#else
   std::unordered_map< Position, std::unordered_map< Size, double> > m_mapPltCost;
+#endif
 private:
 };// END CLASS DEFINITION Slice
 
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index b71b81557..0302faaf0 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,6 +50,8 @@
 #include <assert.h>
 #include <cassert>
 
+#define JVET_Q0504_PLT_NON444                             1 // JVET-Q0504: enable palette mode for non 444 color format
+
 #define JVET_Q0512_ENC_CHROMA_TS_ACT                      1 // JVET-Q0512: encoder-side improvement on enabling chroma transform-skip for ACT
 #define JVET_Q0446_MIP_CONST_SHIFT_OFFSET                 1 // JVET-Q0446: MIP with constant shift and offset
 
diff --git a/source/Lib/CommonLib/Unit.cpp b/source/Lib/CommonLib/Unit.cpp
index 8a6d7c8d6..43c3d7bb6 100644
--- a/source/Lib/CommonLib/Unit.cpp
+++ b/source/Lib/CommonLib/Unit.cpp
@@ -378,6 +378,13 @@ const bool CodingUnit::isSepTree() const
   return treeType != TREE_D || CS::isDualITree( *cs );
 }
 
+#if JVET_Q0504_PLT_NON444
+const bool CodingUnit::isLocalSepTree() const // narrower than isSepTree(): the dual I-tree case is excluded
+{
+  return treeType != TREE_D && !CS::isDualITree(*cs); // tree type is not TREE_D and the CS is not a dual I-tree
+}
+#endif
+
 const bool CodingUnit::checkCCLMAllowed() const
 {
   bool allowCCLM = false;
diff --git a/source/Lib/CommonLib/Unit.h b/source/Lib/CommonLib/Unit.h
index b8c2a30c7..5d147c31f 100644
--- a/source/Lib/CommonLib/Unit.h
+++ b/source/Lib/CommonLib/Unit.h
@@ -367,6 +367,9 @@ struct CodingUnit : public UnitArea
   const uint8_t     checkAllowedSbt() const;
   const bool        checkCCLMAllowed() const;
   const bool        isSepTree() const;
+#if JVET_Q0504_PLT_NON444
+  const bool        isLocalSepTree() const; // true when treeType != TREE_D and the CS is not a dual I-tree
+#endif
   const bool        isConsInter() const { return modeType == MODE_TYPE_INTER; }
   const bool        isConsIntra() const { return modeType == MODE_TYPE_INTRA; }
 };
diff --git a/source/Lib/DecoderLib/CABACReader.cpp b/source/Lib/DecoderLib/CABACReader.cpp
index 08e21e183..8ff4e51fc 100644
--- a/source/Lib/DecoderLib/CABACReader.cpp
+++ b/source/Lib/DecoderLib/CABACReader.cpp
@@ -671,6 +671,16 @@ void CABACReader::coding_tree( CodingStructure& cs, Partitioner& partitioner, CU
   bool jointPLT = false;
   if (cu.isSepTree())
   {
+#if JVET_Q0504_PLT_NON444
+    if( cu.isLocalSepTree() )
+    {
+      compBegin = COMPONENT_Y;
+      numComp = (cu.chromaFormat != CHROMA_400)?3: 1;
+      jointPLT = true;
+    }
+    else
+    {
+#endif
     if (isLuma(partitioner.chType))
     {
       compBegin = COMPONENT_Y;
@@ -681,11 +691,18 @@ void CABACReader::coding_tree( CodingStructure& cs, Partitioner& partitioner, CU
       compBegin = COMPONENT_Cb;
       numComp = 2;
     }
+#if JVET_Q0504_PLT_NON444
+    }
+#endif
   }
   else
   {
     compBegin = COMPONENT_Y;
+#if JVET_Q0504_PLT_NON444
+    numComp = (cu.chromaFormat != CHROMA_400) ? 3 : 1;
+#else
     numComp = 3;
+#endif
     jointPLT = true;
   }
   if (CU::isPLT(cu))
@@ -854,7 +871,18 @@ void CABACReader::coding_unit( CodingUnit &cu, Partitioner &partitioner, CUCtx&
     }
     else
     {
+#if JVET_Q0504_PLT_NON444
+      if( cu.chromaFormat != CHROMA_400 )
+      {
+        cu_palette_info(cu, COMPONENT_Y, 3, cuCtx);
+      }
+      else
+      {
+        cu_palette_info(cu, COMPONENT_Y, 1, cuCtx);
+      }
+#else
       cu_palette_info(cu, COMPONENT_Y, 3, cuCtx);
+#endif
     }
     end_of_ctu(cu, cuCtx);
     return;
@@ -1651,6 +1679,10 @@ void CABACReader::cu_palette_info(CodingUnit& cu, ComponentID compBegin, uint32_
   TransformUnit&   tu = *cu.firstTU;
   int curPLTidx = 0;
 
+#if JVET_Q0504_PLT_NON444
+  if( cu.isLocalSepTree() )
+    cu.cs->prevPLT.curPLTSize[compBegin] = cu.cs->prevPLT.curPLTSize[COMPONENT_Y];
+#endif
   cu.lastPLTSize[compBegin] = cu.cs->prevPLT.curPLTSize[compBegin];
 
   if (cu.lastPLTSize[compBegin])
@@ -1662,10 +1694,24 @@ void CABACReader::cu_palette_info(CodingUnit& cu, ComponentID compBegin, uint32_
   {
     if (cu.reuseflag[compBegin][idx])
     {
+#if JVET_Q0504_PLT_NON444
+      if( cu.isLocalSepTree() )
+      {
+        for( int comp = COMPONENT_Y; comp < MAX_NUM_COMPONENT; comp++ )
+        {
+          cu.curPLT[comp][curPLTidx] = cu.cs->prevPLT.curPLT[comp][idx];
+        }
+      }
+      else
+      {
+#endif
       for (int comp = compBegin; comp < (compBegin + numComp); comp++)
       {
         cu.curPLT[comp][curPLTidx] = cu.cs->prevPLT.curPLT[comp][idx];
       }
+#if JVET_Q0504_PLT_NON444
+      }
+#endif
       curPLTidx++;
     }
   }
@@ -1678,6 +1724,10 @@ void CABACReader::cu_palette_info(CodingUnit& cu, ComponentID compBegin, uint32_
   }
 
   cu.curPLTSize[compBegin] = curPLTidx + recievedPLTnum;
+#if JVET_Q0504_PLT_NON444
+  if( cu.isLocalSepTree() )
+    cu.curPLTSize[COMPONENT_Y] = cu.curPLTSize[compBegin];
+#endif
   for (int comp = compBegin; comp < (compBegin + numComp); comp++)
   {
     for (int idx = curPLTidx; idx < cu.curPLTSize[compBegin]; idx++)
@@ -1685,6 +1735,20 @@ void CABACReader::cu_palette_info(CodingUnit& cu, ComponentID compBegin, uint32_
       ComponentID compID = (ComponentID)comp;
       const int  channelBitDepth = sps.getBitDepth(toChannelType(compID));
       cu.curPLT[compID][idx] = m_BinDecoder.decodeBinsEP(channelBitDepth);
+#if JVET_Q0504_PLT_NON444
+      if( cu.isLocalSepTree() )
+      {
+        if( isLuma( cu.chType ) )
+        {
+          cu.curPLT[COMPONENT_Cb][idx] = 1 << (cu.cs->sps->getBitDepth(CHANNEL_TYPE_CHROMA) - 1);
+          cu.curPLT[COMPONENT_Cr][idx] = 1 << (cu.cs->sps->getBitDepth(CHANNEL_TYPE_CHROMA) - 1);
+        }
+        else
+        {
+          cu.curPLT[COMPONENT_Y][idx] = 1 << (cu.cs->sps->getBitDepth(CHANNEL_TYPE_LUMA) - 1);
+        }
+      }
+#endif
     }
   }
   cu.useEscape[compBegin] = true;
@@ -1906,6 +1970,10 @@ void CABACReader::xDecodePLTPredIndicator(CodingUnit& cu, uint32_t maxPLTSize, C
         idx += symbol - 1;
       }
       cu.reuseflag[compBegin][idx] = 1;
+#if JVET_Q0504_PLT_NON444
+      if( cu.isLocalSepTree() )
+        cu.reuseflag[COMPONENT_Y][idx] = 1;
+#endif
       numPltPredicted++;
       idx++;
     }
diff --git a/source/Lib/DecoderLib/DecCu.cpp b/source/Lib/DecoderLib/DecCu.cpp
index 9298a4cb2..23d57c514 100644
--- a/source/Lib/DecoderLib/DecCu.cpp
+++ b/source/Lib/DecoderLib/DecCu.cpp
@@ -482,7 +482,18 @@ void DecCu::xReconIntraQT( CodingUnit &cu )
     }
     else
     {
+#if JVET_Q0504_PLT_NON444
+      if( cu.chromaFormat != CHROMA_400 )
+      {
+        xReconPLT(cu, COMPONENT_Y, 3);
+      }
+      else
+      {
+        xReconPLT(cu, COMPONENT_Y, 1);
+      }
+#else
       xReconPLT(cu, COMPONENT_Y, 3);
+#endif
     }
     return;
   }
diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp
index b010fb9a1..a6c24bc3a 100644
--- a/source/Lib/DecoderLib/VLCReader.cpp
+++ b/source/Lib/DecoderLib/VLCReader.cpp
@@ -1604,6 +1604,9 @@ void HLSyntaxReader::parseSPS(SPS* pcSPS)
   {
     pcSPS->setUseColorTrans(false);
   }
+#if JVET_Q0504_PLT_NON444
+  READ_FLAG( uiCode,  "sps_palette_enabled_flag");                                pcSPS->setPLTMode                ( uiCode != 0 );
+#else
   if (pcSPS->getChromaFormatIdc() == CHROMA_444)
   {
     READ_FLAG( uiCode,  "sps_palette_enabled_flag");                                pcSPS->setPLTMode                ( uiCode != 0 );
@@ -1612,6 +1615,7 @@ void HLSyntaxReader::parseSPS(SPS* pcSPS)
   {
     pcSPS->setPLTMode(false);
   }
+#endif
   READ_FLAG( uiCode,    "sps_bcw_enabled_flag" );                   pcSPS->setUseBcw( uiCode != 0 );
   READ_FLAG(uiCode, "sps_ibc_enabled_flag");                                    pcSPS->setIBCFlag(uiCode);
   // KJS: sps_ciip_enabled_flag
diff --git a/source/Lib/EncoderLib/CABACWriter.cpp b/source/Lib/EncoderLib/CABACWriter.cpp
index 441fe0465..ce5c6e739 100644
--- a/source/Lib/EncoderLib/CABACWriter.cpp
+++ b/source/Lib/EncoderLib/CABACWriter.cpp
@@ -686,7 +686,18 @@ void CABACWriter::coding_unit( const CodingUnit& cu, Partitioner& partitioner, C
     }
     else
     {
+#if JVET_Q0504_PLT_NON444
+      if( cu.chromaFormat != CHROMA_400 )
+      {
+        cu_palette_info(cu, COMPONENT_Y, 3, cuCtx);
+      }
+      else
+      {
+        cu_palette_info(cu, COMPONENT_Y, 1, cuCtx);
+      }
+#else
       cu_palette_info(cu, COMPONENT_Y, 3, cuCtx);
+#endif
     }
     end_of_ctu(cu, cuCtx);
     return;
diff --git a/source/Lib/EncoderLib/EncCu.cpp b/source/Lib/EncoderLib/EncCu.cpp
index 011ce09d3..de1db2a74 100644
--- a/source/Lib/EncoderLib/EncCu.cpp
+++ b/source/Lib/EncoderLib/EncCu.cpp
@@ -610,6 +610,16 @@ void EncCu::xCompressCU( CodingStructure*& tempCS, CodingStructure*& bestCS, Par
   bool jointPLT = false;
   if (partitioner.isSepTree( *tempCS ))
   {
+#if JVET_Q0504_PLT_NON444
+    if( !CS::isDualITree(*tempCS) && partitioner.treeType != TREE_D )
+    {
+      compBegin = COMPONENT_Y;
+      numComp = (tempCS->area.chromaFormat != CHROMA_400)?3: 1;
+      jointPLT = true;
+    }
+    else
+    {
+#endif
     if (isLuma(partitioner.chType))
     {
       compBegin = COMPONENT_Y;
@@ -620,11 +630,18 @@ void EncCu::xCompressCU( CodingStructure*& tempCS, CodingStructure*& bestCS, Par
       compBegin = COMPONENT_Cb;
       numComp = 2;
     }
+#if JVET_Q0504_PLT_NON444
+    }
+#endif
   }
   else
   {
     compBegin = COMPONENT_Y;
+#if JVET_Q0504_PLT_NON444
+    numComp = (tempCS->area.chromaFormat != CHROMA_400) ? 3 : 1;
+#else
     numComp = 3;
+#endif
     jointPLT = true;
   }
   SplitSeries splitmode = -1;
@@ -2031,7 +2048,18 @@ void EncCu::xCheckPLT(CodingStructure *&tempCS, CodingStructure *&bestCS, Partit
   }
   else
   {
+#if JVET_Q0504_PLT_NON444
+    if( cu.chromaFormat != CHROMA_400 )
+    {
+      m_pcIntraSearch->PLTSearch(*tempCS, partitioner, COMPONENT_Y, 3);
+    }
+    else
+    {
+      m_pcIntraSearch->PLTSearch(*tempCS, partitioner, COMPONENT_Y, 1);
+    }
+#else
     m_pcIntraSearch->PLTSearch(*tempCS, partitioner, COMPONENT_Y, 3);
+#endif
   }
 
 
@@ -2061,7 +2089,18 @@ void EncCu::xCheckPLT(CodingStructure *&tempCS, CodingStructure *&bestCS, Partit
   }
   else
   {
+#if JVET_Q0504_PLT_NON444
+    if( cu.chromaFormat != CHROMA_400 )
+    {
+      m_CABACEstimator->cu_palette_info(cu, COMPONENT_Y, 3, cuCtx);
+    }
+    else
+    {
+      m_CABACEstimator->cu_palette_info(cu, COMPONENT_Y, 1, cuCtx);
+    }
+#else
     m_CABACEstimator->cu_palette_info(cu, COMPONENT_Y, 3, cuCtx);
+#endif
   }
   tempCS->fracBits = m_CABACEstimator->getEstFracBits();
   tempCS->cost = m_pcRdCost->calcRdCost(tempCS->fracBits, tempCS->dist);
@@ -2075,7 +2114,11 @@ void EncCu::xCheckPLT(CodingStructure *&tempCS, CodingStructure *&bestCS, Partit
   tempCS->useDbCost = m_pcEncCfg->getUseEncDbOpt();
 
   const Area currCuArea = cu.block(getFirstComponentOfChannel(partitioner.chType));
+#if JVET_Q0504_PLT_NON444
+  cu.slice->m_mapPltCost[isChroma(partitioner.chType)][currCuArea.pos()][currCuArea.size()] = tempCS->cost;
+#else
   cu.slice->m_mapPltCost[currCuArea.pos()][currCuArea.size()] = tempCS->cost;
+#endif
 #if WCG_EXT
   DTRACE_MODE_COST(*tempCS, m_pcRdCost->getLambda(true));
 #else
diff --git a/source/Lib/EncoderLib/EncModeCtrl.cpp b/source/Lib/EncoderLib/EncModeCtrl.cpp
index 33d3edb25..c96046af0 100644
--- a/source/Lib/EncoderLib/EncModeCtrl.cpp
+++ b/source/Lib/EncoderLib/EncModeCtrl.cpp
@@ -1324,12 +1324,20 @@ void EncModeCtrlMTnoRQT::initCULevel( Partitioner &partitioner, const CodingStru
     // add intra modes
     if( tryIntraRdo )
     {
+#if JVET_Q0504_PLT_NON444
+    if (cs.slice->getSPS()->getPLTMode() && (partitioner.treeType != TREE_D || cs.slice->isIRAP() || (cs.area.lwidth() == 4 && cs.area.lheight() == 4)) && getPltEnc())
+#else
     if (cs.slice->getSPS()->getPLTMode() && ( cs.slice->isIRAP() || (cs.area.lwidth() == 4 && cs.area.lheight() == 4) ) && getPltEnc() )
+#endif
     {
       m_ComprCUCtxList.back().testModes.push_back({ ETM_PALETTE, ETO_STANDARD, qp });
     }
     m_ComprCUCtxList.back().testModes.push_back( { ETM_INTRA, ETO_STANDARD, qp } );
+#if JVET_Q0504_PLT_NON444
+    if (cs.slice->getSPS()->getPLTMode() && partitioner.treeType == TREE_D && !cs.slice->isIRAP() && !(cs.area.lwidth() == 4 && cs.area.lheight() == 4) && getPltEnc())
+#else
     if (cs.slice->getSPS()->getPLTMode() && !cs.slice->isIRAP() && !(cs.area.lwidth() == 4 && cs.area.lheight() == 4) && getPltEnc() )
+#endif
     {
       m_ComprCUCtxList.back().testModes.push_back({ ETM_PALETTE,  ETO_STANDARD, qp });
     }
@@ -1570,7 +1578,11 @@ bool EncModeCtrlMTnoRQT::tryMode( const EncTestMode& encTestmode, const CodingSt
         }
       }
     }
+#if JVET_Q0504_PLT_NON444
+    if (bestMode.type == ETM_PALETTE && !slice.isIRAP() && partitioner.treeType == TREE_D && !(partitioner.currArea().lumaSize().width == 4 && partitioner.currArea().lumaSize().height == 4)) // inter slice
+#else
     if (bestMode.type == ETM_PALETTE && !slice.isIRAP() && !( partitioner.currArea().lumaSize().width == 4 && partitioner.currArea().lumaSize().height == 4) ) // inter slice
+#endif
     {
       return false;
     }
@@ -1593,7 +1605,11 @@ bool EncModeCtrlMTnoRQT::tryMode( const EncTestMode& encTestmode, const CodingSt
     const Area curr_cu = CS::getArea(cs, cs.area, partitioner.chType).blocks[getFirstComponentOfChannel(partitioner.chType)];
     try
     {
+#if JVET_Q0504_PLT_NON444
+      double stored_cost = slice.m_mapPltCost.at(isChroma(partitioner.chType)).at(curr_cu.pos()).at(curr_cu.size());
+#else
       double stored_cost = slice.m_mapPltCost.at(curr_cu.pos()).at(curr_cu.size());
+#endif
       if (bestMode.type != ETM_INVALID && stored_cost > cuECtx.bestCS->cost)
       {
         return false;
diff --git a/source/Lib/EncoderLib/IntraSearch.cpp b/source/Lib/EncoderLib/IntraSearch.cpp
index 02a43bae9..b21fd2580 100644
--- a/source/Lib/EncoderLib/IntraSearch.cpp
+++ b/source/Lib/EncoderLib/IntraSearch.cpp
@@ -1562,6 +1562,10 @@ void IntraSearch::PLTSearch(CodingStructure &cs, Partitioner& partitioner, Compo
     cs.getPredBuf().copyFrom(cs.getOrgBuf());
     cs.getPredBuf().Y().rspSignal(m_pcReshape->getFwdLUT());
   }
+#if JVET_Q0504_PLT_NON444
+  if( cu.isLocalSepTree() )
+    cs.prevPLT.curPLTSize[compBegin] = cs.prevPLT.curPLTSize[COMPONENT_Y];
+#endif
   cu.lastPLTSize[compBegin] = cs.prevPLT.curPLTSize[compBegin];
   //derive palette
   derivePLTLossy(cs, partitioner, compBegin, numComp);
@@ -1604,6 +1608,10 @@ void IntraSearch::PLTSearch(CodingStructure &cs, Partitioner& partitioner, Compo
     Pel curPLTtmp[MAX_NUM_COMPONENT][MAXPLTSIZE];
     int reuseFlagIdx = 0, curPLTtmpIdx = 0, reuseEntrySize = 0;
     memset(cu.reuseflag[compBegin], false, sizeof(bool) * MAXPLTPREDSIZE);
+#if JVET_Q0504_PLT_NON444
+    if( cu.isLocalSepTree() )
+      memset(cu.reuseflag[COMPONENT_Y], false, sizeof(bool) * MAXPLTPREDSIZE);
+#endif
     for (int curIdx = 0; curIdx < cu.curPLTSize[compBegin]; curIdx++)
     {
       if (idxExist[curIdx])
@@ -1631,6 +1639,10 @@ void IntraSearch::PLTSearch(CodingStructure &cs, Partitioner& partitioner, Compo
           if (match)
           {
             cu.reuseflag[compBegin][reuseFlagIdx] = true;
+#if JVET_Q0504_PLT_NON444
+            if( cu.isLocalSepTree() )
+              cu.reuseflag[COMPONENT_Y][reuseFlagIdx] = true;
+#endif
             reuseEntrySize++;
           }
         }
@@ -1640,6 +1652,10 @@ void IntraSearch::PLTSearch(CodingStructure &cs, Partitioner& partitioner, Compo
     cu.reusePLTSize[compBegin] = reuseEntrySize;
     // update palette table
     cu.curPLTSize[compBegin] = newPLTSize;
+#if JVET_Q0504_PLT_NON444
+    if( cu.isLocalSepTree() )
+      cu.curPLTSize[COMPONENT_Y] = newPLTSize;
+#endif
     for (int comp = compBegin; comp < (compBegin + numComp); comp++)
       memcpy( cu.curPLT[comp], curPLTtmp[comp], sizeof(Pel)*cu.curPLTSize[compBegin]);
   }
@@ -2369,6 +2385,87 @@ void IntraSearch::derivePLTLossy(CodingStructure& cs, Partitioner& partitioner,
 
   uint32_t scaleX = getComponentScaleX(COMPONENT_Cb, cs.sps->getChromaFormatIdc());
   uint32_t scaleY = getComponentScaleY(COMPONENT_Cb, cs.sps->getChromaFormatIdc());
+#if JVET_Q0504_PLT_NON444
+  for (uint32_t y = 0; y < height; y++)
+  {
+    for (uint32_t x = 0; x < width; x++)
+    {
+      uint32_t org[3], pX, pY;
+      for (int comp = compBegin; comp < (compBegin + numComp); comp++)
+      {
+        pX = (comp > 0 && compBegin == COMPONENT_Y) ? (x >> scaleX) : x;
+        pY = (comp > 0 && compBegin == COMPONENT_Y) ? (y >> scaleY) : y;
+        org[comp] = orgBuf[comp].at(pX, pY);
+      }
+      element.setAll(org, compBegin, numComp);
+
+      ComponentID tmpCompBegin = compBegin;
+      int tmpNumComp = numComp;
+      if( cs.sps->getChromaFormatIdc() != CHROMA_444 &&
+          numComp == 3 &&
+         (x != ((x >> scaleX) << scaleX) || (y != ((y >> scaleY) << scaleY))) ) // x or y is not a multiple of the chroma scale step
+      {
+        tmpCompBegin = COMPONENT_Y;
+        tmpNumComp   = 1;
+      }
+      int besti = last, bestSAD = (last == -1) ? MAX_UINT : pelList[last].getSAD(element, cs.sps->getBitDepths(), tmpCompBegin, tmpNumComp);
+      if( bestSAD )
+      {
+        for (int i = idx - 1; i >= 0; i--)
+        {
+          uint32_t sad = pelList[i].getSAD(element, cs.sps->getBitDepths(), tmpCompBegin, tmpNumComp);
+          if (sad < bestSAD)
+          {
+            bestSAD = sad;
+            besti = i;
+            if (!sad) break;
+          }
+        }
+      }
+      if (besti >= 0 && pelList[besti].almostEqualData(element, errorLimit, cs.sps->getBitDepths(), tmpCompBegin, tmpNumComp))
+      {
+        pelList[besti].addElement(element, tmpCompBegin, tmpNumComp);
+        last = besti;
+      }
+      else
+      {
+        pelList[idx].copyDataFrom(element, tmpCompBegin, tmpNumComp);
+        for (int comp = tmpCompBegin; comp < (tmpCompBegin + tmpNumComp); comp++)
+          pelList[idx].setCnt(1, comp);
+        last = idx;
+        idx++;
+      }
+    }
+  }
+
+  if( cs.sps->getChromaFormatIdc() != CHROMA_444 && numComp == 3 )
+  {
+    for( int i = 0; i < idx; i++ )
+    {
+      pelList[i].setCnt( pelList[i].getCnt(COMPONENT_Y) + (pelList[i].getCnt(COMPONENT_Cb) >> 2), MAX_NUM_COMPONENT);
+    }
+  }
+  else
+  {
+    if( compBegin == 0 )
+    {
+      for( int i = 0; i < idx; i++ )
+      {
+        pelList[i].setCnt(pelList[i].getCnt(COMPONENT_Y), COMPONENT_Cb);
+        pelList[i].setCnt(pelList[i].getCnt(COMPONENT_Y), COMPONENT_Cr);
+        pelList[i].setCnt(pelList[i].getCnt(COMPONENT_Y), MAX_NUM_COMPONENT);
+      }
+    }
+    else
+    {
+      for( int i = 0; i < idx; i++ )
+      {
+        pelList[i].setCnt(pelList[i].getCnt(COMPONENT_Cb), COMPONENT_Y);
+        pelList[i].setCnt(pelList[i].getCnt(COMPONENT_Cb), MAX_NUM_COMPONENT);
+      }
+    }
+  }
+#else
   for (uint32_t y = 0; y < height; y++)
   {
     for (uint32_t x = 0; x < width; x++)
@@ -2409,10 +2506,18 @@ void IntraSearch::derivePLTLossy(CodingStructure& cs, Partitioner& partitioner,
       }
     }
   }
+#endif
 
   for (int i = 0; i < dictMaxSize; i++)
   {
+#if JVET_Q0504_PLT_NON444
+    pelListSort[i].setCnt(0, COMPONENT_Y);
+    pelListSort[i].setCnt(0, COMPONENT_Cb);
+    pelListSort[i].setCnt(0, COMPONENT_Cr);
+    pelListSort[i].setCnt(0, MAX_NUM_COMPONENT);
+#else
     pelListSort[i].setCnt(0);
+#endif
     pelListSort[i].resetAll(compBegin, numComp);
   }
 
@@ -2420,12 +2525,20 @@ void IntraSearch::derivePLTLossy(CodingStructure& cs, Partitioner& partitioner,
   dictMaxSize = 1;
   for (int i = 0; i < idx; i++)
   {
+#if JVET_Q0504_PLT_NON444
+    if( pelList[i].getCnt(MAX_NUM_COMPONENT) > pelListSort[dictMaxSize - 1].getCnt(MAX_NUM_COMPONENT) )
+#else
     if (pelList[i].getCnt() > pelListSort[dictMaxSize - 1].getCnt())
+#endif
     {
       int j;
       for (j = dictMaxSize; j > 0; j--)
       {
+#if JVET_Q0504_PLT_NON444
+        if (pelList[i].getCnt(MAX_NUM_COMPONENT) > pelListSort[j - 1].getCnt(MAX_NUM_COMPONENT))
+#else
         if (pelList[i].getCnt() > pelListSort[j - 1].getCnt() )
+#endif
         {
           pelListSort[j].copyAllFrom(pelListSort[j - 1], compBegin, numComp);
           dictMaxSize = std::min(dictMaxSize + 1, (uint32_t)MAXPLTSIZE);
@@ -2452,6 +2565,140 @@ void IntraSearch::derivePLTLossy(CodingStructure& cs, Partitioner& partitioner,
   int    run;
   double reuseflagCost;
 #endif
+#if JVET_Q0504_PLT_NON444
+  for( int i = 0; i < MAXPLTSIZE; i++ )
+  {
+    if( pelListSort[i].getCnt(MAX_NUM_COMPONENT) )
+    {
+      ComponentID tmpCompBegin = compBegin;
+      int tmpNumComp = numComp;
+      if( cs.sps->getChromaFormatIdc() != CHROMA_444 && numComp == 3 && pelListSort[i].getCnt(COMPONENT_Cb) == 0 )
+      {
+        tmpCompBegin = COMPONENT_Y;
+        tmpNumComp   = 1;
+      }
+
+      for( int comp = tmpCompBegin; comp < (tmpCompBegin + tmpNumComp); comp++ )
+      {
+        int half = pelListSort[i].getCnt(comp) >> 1;
+        cu.curPLT[comp][paletteSize] = (pelListSort[i].getSumData(comp) + half) / pelListSort[i].getCnt(comp);
+      }
+
+      int best = -1;
+      if( errorLimit )
+      {
+        double pal[MAX_NUM_COMPONENT], err = 0.0, bestCost = 0.0;
+        for( int comp = tmpCompBegin; comp < (tmpCompBegin + tmpNumComp); comp++ )
+        {
+          pal[comp] = pelListSort[i].getSumData(comp) / (double)pelListSort[i].getCnt(comp);
+          err = pal[comp] - cu.curPLT[comp][paletteSize];
+          if( isChroma((ComponentID) comp) )
+          {
+            bestCost += (err * err * PLT_CHROMA_WEIGHTING) / (1 << (2 * pcmShiftRight_C)) * pelListSort[i].getCnt(comp);
+          }
+          else
+          {
+            bestCost += (err * err) / (1 << (2 * pcmShiftRight_L)) * pelListSort[i].getCnt(comp);
+          }
+
+        }
+        bestCost += bitCost;
+
+        for( int t = 0; t < cs.prevPLT.curPLTSize[compBegin]; t++ )
+        {
+          double cost = 0.0;
+          for( int comp = tmpCompBegin; comp < (tmpCompBegin + tmpNumComp); comp++ )
+          {
+            err = pal[comp] - cs.prevPLT.curPLT[comp][t];
+            if( isChroma((ComponentID) comp) )
+            {
+              cost += (err * err * PLT_CHROMA_WEIGHTING) / (1 << (2 * pcmShiftRight_C)) * pelListSort[i].getCnt(comp);
+            }
+            else
+            {
+              cost += (err * err) / (1 << (2 * pcmShiftRight_L)) * pelListSort[i].getCnt(comp);
+            }
+          }
+#if JVET_Q0503_Q0712_PLT_ENCODER_IMPROV_BUGFIX
+          run = 0;
+          for (int t2 = t; t2 >= 0; t2--)
+          {
+            if (!reuseflag[t2])
+            {
+              run++;
+            }
+            else
+            {
+              break;
+            }
+          }
+          reuseflagCost = m_pcRdCost->getLambda() / (double)(1 << (2 * plt_lambda_shift)) * getEpExGolombNumBins(run ? run + 1 : run, 0);
+          cost += reuseflagCost;
+#endif
+
+          if( cost < bestCost )
+          {
+            best = t;
+            bestCost = cost;
+          }
+        }
+        if( best != -1 )
+        {
+          for( int comp = tmpCompBegin; comp < (tmpCompBegin + tmpNumComp); comp++ )
+          {
+            cu.curPLT[comp][paletteSize] = cs.prevPLT.curPLT[comp][best];
+          }
+#if JVET_Q0503_Q0712_PLT_ENCODER_IMPROV_BUGFIX
+          reuseflag[best] = true;
+#endif
+        }
+      }
+
+      bool duplicate = false;
+      if( pelListSort[i].getCnt(MAX_NUM_COMPONENT) == 1 && best == -1 )
+      {
+        duplicate = true;
+      }
+      else
+      {
+        for( int t = 0; t < paletteSize; t++ )
+        {
+          bool duplicateTmp = true;
+          for( int comp = tmpCompBegin; comp < (tmpCompBegin + tmpNumComp); comp++ )
+          {
+            duplicateTmp = duplicateTmp && (cu.curPLT[comp][paletteSize] == cu.curPLT[comp][t]);
+          }
+          if( duplicateTmp )
+          {
+            duplicate = true;
+            break;
+          }
+        }
+      }
+      if( !duplicate )
+      {
+        if( cs.sps->getChromaFormatIdc() != CHROMA_444 && numComp == 3 && pelListSort[i].getCnt(COMPONENT_Cb) == 0 )
+        {
+          if( best != -1 )
+          {
+            cu.curPLT[COMPONENT_Cb][paletteSize] = cs.prevPLT.curPLT[COMPONENT_Cb][best];
+            cu.curPLT[COMPONENT_Cr][paletteSize] = cs.prevPLT.curPLT[COMPONENT_Cr][best];
+          }
+          else
+          {
+            cu.curPLT[COMPONENT_Cb][paletteSize] = 1 << (channelBitDepth_C - 1);
+            cu.curPLT[COMPONENT_Cr][paletteSize] = 1 << (channelBitDepth_C - 1);
+          }
+        }
+        paletteSize++;
+      }
+    }
+    else
+    {
+      break;
+    }
+  }
+#else
   for (int i = 0; i < MAXPLTSIZE; i++)
   {
     if (pelListSort[i].getCnt())
@@ -2559,7 +2806,12 @@ void IntraSearch::derivePLTLossy(CodingStructure& cs, Partitioner& partitioner,
       break;
     }
   }
+#endif
   cu.curPLTSize[compBegin] = paletteSize;
+#if JVET_Q0504_PLT_NON444
+  if( cu.isLocalSepTree() )
+    cu.curPLTSize[COMPONENT_Y] = paletteSize;
+#endif
 
   delete[] pelList;
   delete[] pelListSort;
diff --git a/source/Lib/EncoderLib/IntraSearch.h b/source/Lib/EncoderLib/IntraSearch.h
index 9370c1430..da19e57f2 100644
--- a/source/Lib/EncoderLib/IntraSearch.h
+++ b/source/Lib/EncoderLib/IntraSearch.h
@@ -72,17 +72,40 @@ public:
+#if JVET_Q0504_PLT_NON444
+    // cnt is an array under this macro, so "cnt > other.cnt" would compare array addresses;
+    // order by the aggregate count kept at index MAX_NUM_COMPONENT instead.
+    return cnt[MAX_NUM_COMPONENT] > other.cnt[MAX_NUM_COMPONENT];
+#else
     return cnt > other.cnt;
+#endif
   }
   SortingElement() {
+#if JVET_Q0504_PLT_NON444
+    cnt[0] = cnt[1] = cnt[2] = cnt[3] = 0;
+    shift[0] = shift[1] = shift[2] = 0;
+    lastCnt[0] = lastCnt[1] = lastCnt[2] = 0;
+#else
     cnt = shift = lastCnt = 0;
+#endif
     data[0] = data[1] = data[2] = 0;
     sumData[0] = sumData[1] = sumData[2] = 0;
   }
+#if JVET_Q0504_PLT_NON444
+  // idx is a component index, or MAX_NUM_COMPONENT for the aggregate count
+  uint32_t  getCnt(int idx) const         { return cnt[idx]; }
+  void      setCnt(uint32_t val, int idx) { cnt[idx] = val; }
+#else
   uint32_t  getCnt() const        { return cnt; }
   void      setCnt(uint32_t val)  { cnt = val; }
+#endif
   int       getSumData (int id) const   { return sumData[id]; }
 
   void resetAll(ComponentID compBegin, uint32_t numComp)
   {
+#if JVET_Q0504_PLT_NON444
+    shift[0] = shift[1] = shift[2] = 0;
+    lastCnt[0] = lastCnt[1] = lastCnt[2] = 0;
+#else
     shift = lastCnt = 0;
+#endif
     for (int ch = compBegin; ch < (compBegin + numComp); ch++)
     {
       data[ch] = 0;
@@ -134,19 +157,55 @@ public:
     {
       data[comp] = element.data[comp];
       sumData[comp] = data[comp];
+#if JVET_Q0504_PLT_NON444
+      shift[comp] = 0;
+      lastCnt[comp] = 1;
+#endif
     }
+#if !JVET_Q0504_PLT_NON444
     shift = 0; lastCnt = 1;
+#endif
   }
   void copyAllFrom(SortingElement element, ComponentID compBegin, uint32_t numComp)
   {
     copyDataFrom(element, compBegin, numComp);
+#if !JVET_Q0504_PLT_NON444
     cnt = element.cnt;
+#endif
     for (int comp = compBegin; comp < (compBegin + numComp); comp++)
     {
       sumData[comp] = element.sumData[comp];
+#if JVET_Q0504_PLT_NON444
+      cnt[comp]     = element.cnt[comp];
+      shift[comp]   = element.shift[comp];
+      lastCnt[comp] = element.lastCnt[comp];
+#endif
     }
+#if JVET_Q0504_PLT_NON444
+    cnt[MAX_NUM_COMPONENT] = element.cnt[MAX_NUM_COMPONENT];
+#else
     lastCnt = element.lastCnt; shift = element.shift;
+#endif
   }
+#if JVET_Q0504_PLT_NON444
+  // Merge element into this entry: per component, add the sample to sumData and bump cnt;
+  // whenever cnt reaches twice lastCnt, refresh data as (sumData + rnd) >> shift.
+  void addElement(const SortingElement& element, ComponentID compBegin, uint32_t numComp)
+  {
+    for (int i = compBegin; i<(compBegin + numComp); i++)
+    {
+      sumData[i] += element.data[i];
+      cnt[i]++;
+      if( cnt[i] > 1 && cnt[i] == 2 * lastCnt[i] )
+      {
+        uint32_t rnd = 1 << shift[i];
+        shift[i]++;
+        data[i] = (sumData[i] + rnd) >> shift[i];
+        lastCnt[i] = cnt[i];
+      }
+    }
+  }
+#else
   void addElement(const SortingElement& element, ComponentID compBegin, uint32_t numComp)
   {
     cnt++;
@@ -165,9 +224,16 @@ public:
       lastCnt = cnt;
     }
   }
+#endif
 private:
+#if JVET_Q0504_PLT_NON444
+  // one count per component, plus an aggregate count at index MAX_NUM_COMPONENT
+  uint32_t cnt[MAX_NUM_COMPONENT+1];
+  int shift[3], lastCnt[3], data[3], sumData[3];
+#else
   uint32_t cnt;
   int shift, lastCnt, data[3], sumData[3];
+#endif
 };
 /// encoder search class
 class IntraSearch : public IntraPrediction, CrossComponentPrediction
diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp
index 056819f01..d3ffb3ad2 100644
--- a/source/Lib/EncoderLib/VLCWriter.cpp
+++ b/source/Lib/EncoderLib/VLCWriter.cpp
@@ -1045,10 +1045,14 @@ void HLSWriter::codeSPS( const SPS* pcSPS )
     WRITE_FLAG(pcSPS->getUseColorTrans() ? 1 : 0, "sps_act_enabled_flag");
 #endif
   }
+#if JVET_Q0504_PLT_NON444
+  WRITE_FLAG(pcSPS->getPLTMode() ? 1 : 0,                                                    "sps_palette_enabled_flag" );
+#else
   if (pcSPS->getChromaFormatIdc() == CHROMA_444)
   {
     WRITE_FLAG(pcSPS->getPLTMode() ? 1 : 0,                                                    "sps_palette_enabled_flag" );
   }
+#endif
   WRITE_FLAG( pcSPS->getUseBcw() ? 1 : 0,                                                      "sps_bcw_enabled_flag" );
   WRITE_FLAG(pcSPS->getIBCFlag() ? 1 : 0,                                                      "sps_ibc_enabled_flag");
 
-- 
GitLab