From ea7c8afe5bff50a7f05d104bf5d0f9289f515298 Mon Sep 17 00:00:00 2001
From: Hang Huang <huanghang@oppo.com>
Date: Sat, 18 Jan 2025 15:53:57 +0000
Subject: [PATCH] JVET-AK0064: LFNST/NSPT set derivation for CCP coded block
 (Test 3.2)

---
 source/Lib/CommonLib/IntraPrediction.cpp | 132 +++++++++++++++++++++++
 source/Lib/CommonLib/IntraPrediction.h   |   3 +
 source/Lib/CommonLib/TrQuant.cpp         |  12 +++
 source/Lib/CommonLib/TypeDef.h           |   1 +
 source/Lib/CommonLib/Unit.cpp            |  12 +++
 source/Lib/CommonLib/Unit.h              |   3 +
 source/Lib/CommonLib/UnitTools.cpp       |   4 +
 source/Lib/DecoderLib/DecCu.cpp          |  27 +++++
 source/Lib/EncoderLib/IntraSearch.cpp    |   7 ++
 9 files changed, 201 insertions(+)

diff --git a/source/Lib/CommonLib/IntraPrediction.cpp b/source/Lib/CommonLib/IntraPrediction.cpp
index 381e6ec25..79aea99c0 100644
--- a/source/Lib/CommonLib/IntraPrediction.cpp
+++ b/source/Lib/CommonLib/IntraPrediction.cpp
@@ -15885,6 +15885,138 @@ int IntraPrediction::deriveIpmForTransform(CPelBuf predBuf, CodingUnit& cu
   return firstMode;
 }
 #endif
+
+#if JVET_AK0064_CCP_LFNST_NSPT && ENABLE_DIMD
+int IntraPrediction::deriveChromaIpmForTransform(CPelBuf predBufCb, CPelBuf predBufCr, CodingUnit& cu)
+{
+  if (!cu.slice->getSPS()->getUseDimd())
+  {
+    return PLANAR_IDX;
+  }
+  const Pel* pPredCb = predBufCb.buf;
+  const Pel* pPredCr = predBufCr.buf;
+  const int iStride = predBufCb.stride;
+  const int width = predBufCb.width;
+  const int height = predBufCb.height;
+
+  int histogram[NUM_LUMA_MODE] = { 0 };
+  int histogramCb[NUM_LUMA_MODE] = { 0 };
+  int histogramCr[NUM_LUMA_MODE] = { 0 };
+  int firstAmp = 0, curAmp = 0;
+  int firstMode = 0, curMode = 0;
+  if (!cu.firstPU->cs->pcv->isEncoder && cu.firstTU->jointCbCr == 1)
+  {
+    pPredCb = pPredCb + iStride + 1;
+    buildHistogram(pPredCb, iStride, height - 2, width - 2, histogramCb, 0, width - 2, height - 2);
+    for (int i = 0; i < NUM_LUMA_MODE; i++)
+    {
+      curAmp = histogramCb[i];
+      curMode = i;
+      if (curAmp > firstAmp)
+      {
+        firstAmp = curAmp;
+        firstMode = curMode;
+      }
+    }
+    cu.ccpChromaDimdMode[1] = firstMode;
+  }
+  else if (!cu.firstPU->cs->pcv->isEncoder && cu.firstTU->jointCbCr == 2)
+  {
+    //Cr
+    pPredCr = pPredCr + iStride + 1;
+    buildHistogram(pPredCr, iStride, height - 2, width - 2, histogramCr, 0, width - 2, height - 2);
+    firstAmp = 0;
+    curAmp = 0;
+    firstMode = 0;
+    curMode = 0;
+    for (int i = 0; i < NUM_LUMA_MODE; i++)
+    {
+      curAmp = histogramCr[i];
+      curMode = i;
+      if (curAmp > firstAmp)
+      {
+        firstAmp = curAmp;
+        firstMode = curMode;
+      }
+    }
+    cu.ccpChromaDimdMode[2] = firstMode;
+  }
+  else if (!cu.firstPU->cs->pcv->isEncoder)
+  {
+    pPredCb = pPredCb + iStride + 1;
+    buildHistogram(pPredCb, iStride, height - 2, width - 2, histogram, 0, width - 2, height - 2);
+    pPredCr = pPredCr + iStride + 1;
+    buildHistogram(pPredCr, iStride, height - 2, width - 2, histogram, 0, width - 2, height - 2);
+
+    int firstAmp = 0, curAmp = 0;
+    int firstMode = 0, curMode = 0;
+    for (int i = 0; i < NUM_LUMA_MODE; i++)
+    {
+      curAmp = histogram[i];
+      curMode = i;
+      if (curAmp > firstAmp)
+      {
+        firstAmp = curAmp;
+        firstMode = curMode;
+      }
+    }
+    cu.ccpChromaDimdMode[0] = cu.ccpChromaDimdMode[3] = firstMode;
+  }
+  else
+  {
+    pPredCb = pPredCb + iStride + 1;
+    buildHistogram(pPredCb, iStride, height - 2, width - 2, histogramCb, 0, width - 2, height - 2);
+    for (int i = 0; i < NUM_LUMA_MODE; i++)
+    {
+      curAmp = histogramCb[i];
+      curMode = i;
+      if (curAmp > firstAmp)
+      {
+        firstAmp = curAmp;
+        firstMode = curMode;
+      }
+    }
+    cu.ccpChromaDimdMode[1] = firstMode;
+
+    pPredCr = pPredCr + iStride + 1;
+    buildHistogram(pPredCr, iStride, height - 2, width - 2, histogramCr, 0, width - 2, height - 2);
+
+    firstAmp = 0;
+    curAmp = 0;
+    firstMode = 0;
+    curMode = 0;
+    for (int i = 0; i < NUM_LUMA_MODE; i++)
+    {
+      curAmp = histogramCr[i];
+      curMode = i;
+      if (curAmp > firstAmp)
+      {
+        firstAmp = curAmp;
+        firstMode = curMode;
+      }
+    }
+    cu.ccpChromaDimdMode[2] = firstMode;
+
+    firstAmp = 0;
+    curAmp = 0;
+    firstMode = 0;
+    curMode = 0;
+    for (int i = 0; i < NUM_LUMA_MODE; i++)
+    {
+      curAmp = histogramCb[i] + histogramCr[i];
+      curMode = i;
+      if (curAmp > firstAmp)
+      {
+        firstAmp = curAmp;
+        firstMode = curMode;
+      }
+    }
+    cu.ccpChromaDimdMode[0] = cu.ccpChromaDimdMode[3] = firstMode;
+  }
+  return firstMode;
+}
+#endif
+
 #if !JVET_AG0061_INTER_LFNST_NSPT
 int IntraPrediction::buildHistogram(const Pel *pReco, int iStride, uint32_t uiHeight, uint32_t uiWidth, int* piHistogram, int direction, int bw, int bh)
 {
diff --git a/source/Lib/CommonLib/IntraPrediction.h b/source/Lib/CommonLib/IntraPrediction.h
index b61ff3041..b127787ce 100644
--- a/source/Lib/CommonLib/IntraPrediction.h
+++ b/source/Lib/CommonLib/IntraPrediction.h
@@ -936,6 +936,9 @@ public:
 #endif 
   );
 #endif
+#if JVET_AK0064_CCP_LFNST_NSPT
+  static int deriveChromaIpmForTransform   (CPelBuf predBufCb, CPelBuf predBufCr, CodingUnit& cu);
+#endif
 #if !JVET_AG0061_INTER_LFNST_NSPT
   static int  buildHistogram               ( const Pel *pReco, int iStride, uint32_t uiHeight, uint32_t uiWidth, int* piHistogram, int direction, int bw, int bh );
 #endif
diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index da9689945..b36cfb158 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -511,7 +511,11 @@ void TrQuant::xInvLfnst( const TransformUnit &tu, const ComponentID compID )
     if( PU::isLMCMode( tu.cs->getPU( area.pos(), toChannelType( compID ) )->intraDir[ toChannelType( compID ) ] ) )
 #endif
     {
+#if JVET_AK0064_CCP_LFNST_NSPT
+      intraMode = tu.cu->ccpChromaDimdMode[tu.jointCbCr];
+#else
       intraMode = PU::getCoLocatedIntraLumaMode( *tu.cs->getPU( area.pos(), toChannelType( compID ) ) );
+#endif
 #if JVET_AJ0249_NEURAL_NETWORK_BASED
       if (intraMode == PNN_IDX)
       {
@@ -861,7 +865,11 @@ void TrQuant::xFwdLfnst( const TransformUnit &tu, const ComponentID compID, cons
     if( PU::isLMCMode( tu.cs->getPU( area.pos(), toChannelType( compID ) )->intraDir[ toChannelType( compID ) ] ) )
 #endif
     {
+#if JVET_AK0064_CCP_LFNST_NSPT
+      intraMode = tu.cu->ccpChromaDimdMode[tu.jointCbCr];
+#else
       intraMode = PU::getCoLocatedIntraLumaMode( *tu.cs->getPU( area.pos(), toChannelType( compID ) ) );
+#endif
 #if JVET_AJ0249_NEURAL_NETWORK_BASED
       if (intraMode == PNN_IDX)
       {
@@ -3288,7 +3296,11 @@ int TrQuant::getLfnstIdx(const TransformUnit &tu, ComponentID compID)
   if (PU::isLMCMode(tu.cs->getPU(area.pos(), toChannelType(compID))->intraDir[toChannelType(compID)]))
 #endif
   {
+#if JVET_AK0064_CCP_LFNST_NSPT
+    intraMode = tu.cu->ccpChromaDimdMode[tu.jointCbCr];
+#else
     intraMode = PU::getCoLocatedIntraLumaMode(*tu.cs->getPU(area.pos(), toChannelType(compID)));
+#endif
 #if JVET_AJ0249_NEURAL_NETWORK_BASED
     if (intraMode == PNN_IDX)
     {
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 7bb087277..195cdb995 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -456,6 +456,7 @@
 #define JVET_AI0050_INTER_MTSS                            1 // JVET-AI0050: Multiple LFNST/NSPT kernel set selection for GPM coded block
 #define JVET_AI0050_SBT_LFNST                             1 // JVET-AI0050: Enable LFNST/NSPT for SBT coded block
 #define JVET_AJ0260_SBT_CORNER_MODE                       1 // JVET-AJ0260: Corner mode for SBT
+#define JVET_AK0064_CCP_LFNST_NSPT                        1 // JVET-AK0064: LFNST/NSPT set derivation for CCP coded block
 
 // Entropy Coding
 #define EC_HIGH_PRECISION                                 1 // CABAC high precision
diff --git a/source/Lib/CommonLib/Unit.cpp b/source/Lib/CommonLib/Unit.cpp
index efed9e153..fbb7fa3e1 100644
--- a/source/Lib/CommonLib/Unit.cpp
+++ b/source/Lib/CommonLib/Unit.cpp
@@ -333,6 +333,12 @@ CodingUnit& CodingUnit::operator=( const CodingUnit& other )
   dimdChromaModeSecond = other.dimdChromaModeSecond;
 #endif
 #endif
+#if JVET_AK0064_CCP_LFNST_NSPT
+  for (int i = 0; i < 4; i++)
+  {
+    ccpChromaDimdMode[i] = other.ccpChromaDimdMode[i];
+  }
+#endif
 #if JVET_AB0157_INTRA_FUSION
   for( int i = 0; i < DIMD_FUSION_NUM-1; i++ )
   {
@@ -684,6 +690,12 @@ void CodingUnit::initData()
   dimdChromaModeSecond = -1;
 #endif
 #endif
+#if JVET_AK0064_CCP_LFNST_NSPT
+  for (int i = 0; i < 4; i++)
+  {
+    ccpChromaDimdMode[i] = -1;
+  }
+#endif
 #if JVET_AB0157_INTRA_FUSION
   for( int i = 0; i < DIMD_FUSION_NUM-1; i++ )
   {
diff --git a/source/Lib/CommonLib/Unit.h b/source/Lib/CommonLib/Unit.h
index a7f8293cc..349e61b3d 100644
--- a/source/Lib/CommonLib/Unit.h
+++ b/source/Lib/CommonLib/Unit.h
@@ -360,6 +360,9 @@ struct CodingUnit : public UnitArea
   int8_t         dimdChromaModeSecond;
 #endif
 #endif
+#if JVET_AK0064_CCP_LFNST_NSPT
+  int8_t         ccpChromaDimdMode[4]; // jointCbCr: 0,1,2,3
+#endif
 #if JVET_AB0157_INTRA_FUSION
   int8_t         dimdBlendMode[DIMD_FUSION_NUM-1]; // max number of blend modes (the main mode is not counter) --> incoherent with dimdRelWeight
   int8_t         dimdRelWeight[DIMD_FUSION_NUM]; // max number of predictions to blend
diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp
index 43c220383..cb31cc2ed 100644
--- a/source/Lib/CommonLib/UnitTools.cpp
+++ b/source/Lib/CommonLib/UnitTools.cpp
@@ -34116,7 +34116,11 @@ uint32_t PU::getFinalIntraModeForTransform( const TransformUnit &tu, const Compo
   if( PU::isLMCMode( tu.cs->getPU( area.pos(), toChannelType( compID ) )->intraDir[ toChannelType( compID ) ] ) )
 #endif
   {
+#if JVET_AK0064_CCP_LFNST_NSPT
+    intraMode = tu.cu->ccpChromaDimdMode[tu.jointCbCr];
+#else
     intraMode = PU::getCoLocatedIntraLumaMode( *tu.cs->getPU( area.pos(), toChannelType( compID ) ) );
+#endif
 #if JVET_AJ0249_NEURAL_NETWORK_BASED
     if (intraMode == PNN_IDX)
     {
diff --git a/source/Lib/DecoderLib/DecCu.cpp b/source/Lib/DecoderLib/DecCu.cpp
index 0321ce555..92355ef55 100644
--- a/source/Lib/DecoderLib/DecCu.cpp
+++ b/source/Lib/DecoderLib/DecCu.cpp
@@ -1143,6 +1143,10 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID )
 #endif
   if( compID != COMPONENT_Y && PU::isLMCMode( uiChFinalMode ) )
   {
+#if JVET_AK0064_CCP_LFNST_NSPT
+    if (compID == COMPONENT_Cb)
+    {
+#endif
 #if JVET_AD0188_CCP_MERGE
     PredictionUnit& pu = *tu.cu->firstPU;
 #else
@@ -1150,6 +1154,16 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID )
 #endif
     m_pcIntraPred->xGetLumaRecPixels( pu, area );
     m_pcIntraPred->predIntraChromaLM( compID, piPred, pu, area, uiChFinalMode );
+#if JVET_AK0064_CCP_LFNST_NSPT
+      {
+        CompArea areaCr = pu.Cr();
+        m_pcIntraPred->initIntraPatternChType(*tu.cu, areaCr);
+        PelBuf predCr = cs.getPredBuf(tu.blocks[COMPONENT_Cr]);
+        m_pcIntraPred->xGetLumaRecPixels(pu, areaCr);
+        m_pcIntraPred->predIntraChromaLM(COMPONENT_Cr, predCr, pu, areaCr, uiChFinalMode);
+      }
+    }
+#endif
   }
   else
   {
@@ -1419,6 +1433,8 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID )
     }
   }
 #if SIGN_PREDICTION
+#if JVET_AK0064_CCP_LFNST_NSPT
+  if (isJCCR && compID == COMPONENT_Cb && !PU::isLMCMode(uiChFinalMode))
 #if JVET_AA0057_CCCM
 #if JVET_AD0188_CCP_MERGE
   if (isJCCR && compID == COMPONENT_Cb && !pu.cccmFlag && !pu.idxNonLocalCCP
@@ -1439,6 +1455,7 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID )
     && !pu.decoderDerivedCcpMode
 #endif
     )
+#endif
 #endif
   {
     m_pcIntraPred->initIntraPatternChType(*tu.cu, areaCr);
@@ -1492,6 +1509,16 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID )
     }
   }
   }
+#endif
+#if JVET_AK0064_CCP_LFNST_NSPT
+  if (compID != COMPONENT_Y)
+  {
+    if (compID == COMPONENT_Cb && PU::isLMCMode(uiChFinalMode) && tu.cu->lfnstIdx)
+    {
+      PelBuf predCr = cs.getPredBuf(tu.blocks[COMPONENT_Cr]);
+      IntraPrediction::deriveChromaIpmForTransform(piPred, predCr, *pu.cu);
+    }
+  }
 #endif
   const Slice           &slice = *cs.slice;
   bool flag = slice.getLmcsEnabledFlag() && (slice.isIntra() || (!slice.isIntra() && m_pcReshape->getCTUFlag()));
diff --git a/source/Lib/EncoderLib/IntraSearch.cpp b/source/Lib/EncoderLib/IntraSearch.cpp
index 71f2249e2..ef107c7bc 100644
--- a/source/Lib/EncoderLib/IntraSearch.cpp
+++ b/source/Lib/EncoderLib/IntraSearch.cpp
@@ -14516,6 +14516,13 @@ ChromaCbfs IntraSearch::xRecurIntraChromaCodingQT( CodingStructure &cs, Partitio
 #endif
       }
 
+#if JVET_AK0064_CCP_LFNST_NSPT
+      if (PU::isLMCMode(predMode) && currTU.cu->lfnstIdx)
+      {
+        IntraPrediction::deriveChromaIpmForTransform(piPredCb, piPredCr, *pu.cu);
+      }
+#endif
+
     // determination of chroma residuals including reshaping and cross-component prediction
     //----- get chroma residuals -----
     PelBuf resiCb  = cs.getResiBuf(cbArea);
-- 
GitLab