From f05b549a82532ece4257e8f998ac94d2494aeddd Mon Sep 17 00:00:00 2001
From: Mischa Siekmann <mischa.siekmann@hhi.fraunhofer.de>
Date: Thu, 18 Jul 2019 16:53:09 +0200
Subject: [PATCH] JVET-O0094 integration

---
 source/Lib/CommonLib/ContextModelling.h | 23 ++++++++-
 source/Lib/CommonLib/DepQuant.cpp       | 61 ++++++++++++++++++++++-
 source/Lib/CommonLib/TrQuant.cpp        | 64 +++++++++++++++++++++++--
 source/Lib/CommonLib/TypeDef.h          |  3 +-
 source/Lib/CommonLib/UnitTools.cpp      |  5 +-
 source/Lib/DecoderLib/CABACReader.cpp   | 39 +++++++++++++--
 source/Lib/DecoderLib/CABACReader.h     |  8 ++++
 source/Lib/EncoderLib/CABACWriter.cpp   | 29 ++++++++++-
 source/Lib/EncoderLib/CABACWriter.h     |  4 ++
 9 files changed, 222 insertions(+), 14 deletions(-)

diff --git a/source/Lib/CommonLib/ContextModelling.h b/source/Lib/CommonLib/ContextModelling.h
index 18e2654ac3..35fef1447e 100644
--- a/source/Lib/CommonLib/ContextModelling.h
+++ b/source/Lib/CommonLib/ContextModelling.h
@@ -289,17 +289,36 @@ class CUCtx
 public:
   CUCtx()              : isDQPCoded(false), isChromaQpAdjCoded(false),
                          qgStart(false),
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+                         numNonZeroCoeffNonTs(0) // NOTE(review): qp is not set by this constructor
+                         { // clear the per-channel-type LFNST violation flags
+                           violatesLfnstConstrained[CHANNEL_TYPE_LUMA  ] = false;
+                           violatesLfnstConstrained[CHANNEL_TYPE_CHROMA] = false;
+                         }
+#else
                          numNonZeroCoeffNonTs(0) {}
+#endif
   CUCtx(int _qp)       : isDQPCoded(false), isChromaQpAdjCoded(false),
                          qgStart(false),
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+                         numNonZeroCoeffNonTs(0), qp(_qp) // int _qp is stored in the int8_t member qp
+                         { // clear the per-channel-type LFNST violation flags
+                           violatesLfnstConstrained[CHANNEL_TYPE_LUMA  ] = false;
+                           violatesLfnstConstrained[CHANNEL_TYPE_CHROMA] = false;
+                         }
+#else
                          numNonZeroCoeffNonTs(0), qp(_qp) {}
+#endif
   ~CUCtx() {}
 public:
   bool      isDQPCoded;
   bool      isChromaQpAdjCoded;
   bool      qgStart;
-  uint32_t      numNonZeroCoeffNonTs;
-  int8_t     qp;                   // used as a previous(last) QP and for QP prediction
+  uint32_t  numNonZeroCoeffNonTs;
+  int8_t    qp;                   // used as a previous(last) QP and for QP prediction
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  bool      violatesLfnstConstrained[MAX_NUM_CHANNEL_TYPE]; // per channel type: set once a non-transform-skip TU (>= 4x4) has its last scan position beyond the LFNST limit (7 or 15); reset in cu_residual
+#endif
 };
 
 class MergeCtx
diff --git a/source/Lib/CommonLib/DepQuant.cpp b/source/Lib/CommonLib/DepQuant.cpp
index a682bad02c..e7e1428753 100644
--- a/source/Lib/CommonLib/DepQuant.cpp
+++ b/source/Lib/CommonLib/DepQuant.cpp
@@ -869,13 +869,17 @@ namespace DQIntern
       m_goRicePar     = 0;
       m_goRiceZero    = 0;
     }
-
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+    void checkRdCosts( const ScanPosType spt, const PQData& pqDataA, const PQData& pqDataB, Decision& decisionA, Decision& decisionB ) const
+#else
     void checkRdCosts( const ScanPosType spt, const PQData& pqDataA, const PQData& pqDataB, Decision& decisionA, Decision& decisionB, bool zeroOut ) const
+#endif
     {
       const int32_t*  goRiceTab = g_goRiceBits[m_goRicePar];
       int64_t         rdCostA   = m_rdCost + pqDataA.deltaDist;
       int64_t         rdCostB   = m_rdCost + pqDataB.deltaDist;
       int64_t         rdCostZ   = m_rdCost;
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       if( zeroOut )
       {
         rdCostZ = m_rdCost;
@@ -911,6 +915,7 @@ namespace DQIntern
       }
       else
       {
+#endif
         if( m_remRegBins >= 4 )
         {
           if( pqDataA.absLevel < 4 )
@@ -975,7 +980,9 @@ namespace DQIntern
           decisionB.prevId = m_stateId;
         }
       }
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
     }
+#endif
 
     inline void checkRdCostStart(int32_t lastOffset, const PQData &pqData, Decision &decision) const
     {
@@ -1370,15 +1377,36 @@ namespace DQIntern
   {
     ::memcpy( decisions, startDec, 8*sizeof(Decision) );
 
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+    if( zeroOut )
+    {
+      if( spt==SCAN_EOCSBB )
+      {
+        m_skipStates[0].checkRdCostSkipSbbZeroOut( decisions[0] );
+        m_skipStates[1].checkRdCostSkipSbbZeroOut( decisions[1] );
+        m_skipStates[2].checkRdCostSkipSbbZeroOut( decisions[2] );
+        m_skipStates[3].checkRdCostSkipSbbZeroOut( decisions[3] );
+      }
+      return;
+    }
+#endif
 
     PQData  pqData[4];
     m_quant.preQuantCoeff( absCoeff, pqData, quanCoeff );
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+    m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2]);
+    m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0]);
+    m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3]);
+    m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1]);
+#else
     m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2], zeroOut );
     m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0], zeroOut );
     m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3], zeroOut );
     m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1], zeroOut );
+#endif
     if( spt==SCAN_EOCSBB )
     {
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       if( zeroOut )
       {
         m_skipStates[0].checkRdCostSkipSbbZeroOut( decisions[0] );
@@ -1388,17 +1416,25 @@ namespace DQIntern
       }
       else
       {
+#endif
         m_skipStates[0].checkRdCostSkipSbb( decisions[0] );
         m_skipStates[1].checkRdCostSkipSbb( decisions[1] );
         m_skipStates[2].checkRdCostSkipSbb( decisions[2] );
         m_skipStates[3].checkRdCostSkipSbb( decisions[3] );
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       }
+#endif
     }
+
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
     if( !zeroOut )
     {
+#endif
     m_startState.checkRdCostStart( lastOffset, pqData[0], decisions[0] );
     m_startState.checkRdCostStart( lastOffset, pqData[2], decisions[2] );
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
     }
+#endif
   }
 
   void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo, bool zeroOut, int quantCoeff )
@@ -1420,7 +1456,11 @@ namespace DQIntern
         m_currStates[3].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[3] );
         ::memcpy( decisions+4, decisions, 4*sizeof(Decision) );
       }
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+      else if( !zeroOut )
+#else
       else
+#endif
       {
         switch( scanInfo.nextNbInfoSbb.num )
         {
@@ -1505,10 +1545,17 @@ namespace DQIntern
     zeroOutforThres = zeroOut || (32 < tuPars.m_height || 32 < tuPars.m_width);
     //===== find first test position =====
     int firstTestPos = numCoeff - 1;
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+    if( lfnstIdx > 0 && tu.mtsIdx != MTS_SKIP && width >= 4 && height >= 4 )
+    {
+      firstTestPos = ( ( width == 4 && height == 4 ) || ( width == 8 && height == 8 ) )  ? 7 : 15 ;
+    }
+#else
     if( lfnstIdx > 0 && tu.mtsIdx != MTS_SKIP && ( ( width == 4 && height == 4 ) || ( width == 8 && height == 8 ) ) )
     {
       firstTestPos = 7;
     }
+#endif
     const TCoeff defaultQuantisationCoefficient = (TCoeff)m_quant.getQScale();
     const TCoeff thres = m_quant.getLastThreshold();
     for( ; firstTestPos >= 0; firstTestPos-- )
@@ -1543,15 +1590,25 @@ namespace DQIntern
     for( int scanIdx = firstTestPos; scanIdx >= 0; scanIdx-- )
     {
       const ScanInfo& scanInfo = tuPars.m_scanInfo[ scanIdx ];
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       bool lfnstZeroOut = lfnstIdx > 0 && tu.mtsIdx != MTS_SKIP && width >= 4 && height >= 4 &&
         ( ( ( ( width >= 8 && height >= 8 ) && scanIdx >= 16 ) || ( ( ( width == 4 && height == 4 ) || ( width == 8 && height == 8 ) ) && scanIdx >= 8 ) ) && scanIdx < 48 );
       if (enableScalingLists)
       {
         m_quant.initQuantBlock(tu, compID, cQP, lambda, quantCoeff[scanInfo.rasterPos]);
         xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)) || lfnstZeroOut, quantCoeff[scanInfo.rasterPos] );
-	    }
+      }
       else
         xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)) || lfnstZeroOut, defaultQuantisationCoefficient );
+#else
+      if (enableScalingLists)
+      {
+        m_quant.initQuantBlock(tu, compID, cQP, lambda, quantCoeff[scanInfo.rasterPos]);
+        xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)), quantCoeff[scanInfo.rasterPos] );
+      }
+      else
+        xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)), defaultQuantisationCoefficient );
+#endif
     }
 
     //===== find best path =====
diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index 51e628011f..e79da6fc1e 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -272,13 +272,15 @@ void TrQuant::xInvLfnst( const TransformUnit &tu, const ComponentID compID )
 #endif
       bool          transposeFlag   = getTransposeFlag( intraMode );
       const int     sbSize          = whge3 ? 8 : 4;
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       const int     subGrpXMax      = ( height == 4 && width  > 8 ) ? 2 : 1;
       const int     subGrpYMax      = ( width  == 4 && height > 8 ) ? 2 : 1;
+#endif
       bool          tu4x4Flag       = ( width == 4 && height == 4 );
       bool          tu8x8Flag       = ( width == 8 && height == 8 );
       TCoeff*       lfnstTemp;
       TCoeff*       coeffTemp;
-
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       for( int subGroupX = 0; subGroupX < subGrpXMax; subGroupX++ )
       {
         for( int subGroupY = 0; subGroupY < subGrpYMax; subGroupY++ )
@@ -288,7 +290,11 @@ void TrQuant::xInvLfnst( const TransformUnit &tu, const ComponentID compID )
           int y;
           lfnstTemp = m_tempInMatrix; // inverse spectral rearrangement
           coeffTemp = m_plTempCoeff + offsetX + offsetY;
-
+#else
+          int y;
+          lfnstTemp = m_tempInMatrix; // inverse spectral rearrangement
+          coeffTemp = m_plTempCoeff;
+#endif
           TCoeff * dst = lfnstTemp;
           const ScanElement * scanPtr = scan;
           for( y = 0; y < 16; y++ )
@@ -339,8 +345,10 @@ void TrQuant::xInvLfnst( const TransformUnit &tu, const ComponentID compID )
               coeffTemp += width;
             }
           }
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
         }
       } // subGroupX
+#endif
     }
   }
 }
@@ -370,14 +378,17 @@ void TrQuant::xFwdLfnst( const TransformUnit &tu, const ComponentID compID, cons
 
       bool            transposeFlag   = getTransposeFlag( intraMode );
       const int       sbSize          = whge3 ? 8 : 4;
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       const int       subGrpXMax      = ( height == 4 && width  > 8 ) ? 2 : 1;
       const int       subGrpYMax      = ( width  == 4 && height > 8 ) ? 2 : 1;
+#endif
       bool            tu4x4Flag       = ( width == 4 && height == 4 );
       bool            tu8x8Flag       = ( width == 8 && height == 8 );
       TCoeff*         lfnstTemp;
       TCoeff*         coeffTemp;
       TCoeff*         tempCoeff       = loadTr ? m_mtsCoeffs[ tu.mtsIdx ] : m_plTempCoeff;
 
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
       for( int subGroupX = 0; subGroupX < subGrpXMax; subGroupX++ )
       {
         for( int subGroupY = 0; subGroupY < subGrpYMax; subGroupY++ )
@@ -387,6 +398,11 @@ void TrQuant::xFwdLfnst( const TransformUnit &tu, const ComponentID compID, cons
           int y;
           lfnstTemp = m_tempInMatrix; // forward low frequency non-separable transform
           coeffTemp = tempCoeff + offsetX + offsetY;
+#else
+          int y;
+          lfnstTemp = m_tempInMatrix; // forward low frequency non-separable transform
+          coeffTemp = tempCoeff;
+#endif
 
           if( transposeFlag )
           {
@@ -430,8 +446,11 @@ void TrQuant::xFwdLfnst( const TransformUnit &tu, const ComponentID compID, cons
           fwdLfnstNxN( m_tempInMatrix, m_tempOutMatrix, g_lfnstLut[ intraMode ], lfnstIdx - 1, sbSize, ( tu4x4Flag || tu8x8Flag ) ? 8 : 16 );
 
           lfnstTemp = m_tempOutMatrix; // forward spectral rearrangement
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
           coeffTemp = tempCoeff + offsetX + offsetY;
-
+#else
+          coeffTemp = tempCoeff;
+#endif
           const ScanElement * scanPtr = scan;
           int lfnstCoeffNum = ( sbSize == 4 ) ? sbSize * sbSize : 48;
           for( y = 0; y < lfnstCoeffNum; y++ )
@@ -439,8 +458,10 @@ void TrQuant::xFwdLfnst( const TransformUnit &tu, const ComponentID compID, cons
             coeffTemp[ scanPtr->idx ] = *lfnstTemp++;
             scanPtr++;
           }
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
         }
       } // subGroupX
+#endif
     }
   }
 }
@@ -652,8 +673,26 @@ void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPel
 
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
 
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
+#else
+  int  skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
+  int  skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
+  if( tu.cs->sps->getUseLFNST() && tu.cu->lfnstIdx )
+  {
+    if( (width == 4 && height > 4) || (width > 4 && height == 4) )
+    {
+      skipWidth  = width  - 4;
+      skipHeight = height - 4;
+    }
+    else if( (width >= 8 && height >= 8) )
+    {
+      skipWidth  = width  - 8;
+      skipHeight = height - 8;
+    }
+  }
+#endif
 
 #if RExt__DECODER_DEBUG_TOOL_STATISTICS
   if ( trTypeHor != DCT2 )
@@ -719,9 +758,26 @@ void TrQuant::xIT( const TransformUnit &tu, const ComponentID &compID, const CCo
   int trTypeVer = DCT2;
 
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
-
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
+#else
+  int skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
+  int skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
+  if( tu.cs->sps->getUseLFNST() && tu.cu->lfnstIdx )
+  {
+    if( (width == 4 && height > 4) || (width > 4 && height == 4) )
+    {
+      skipWidth  = width  - 4;
+      skipHeight = height - 4;
+    }
+    else if( (width >= 8 && height >= 8) )
+    {
+      skipWidth  = width  - 8;
+      skipHeight = height - 8;
+    }
+  }
+#endif
 
   TCoeff *block = ( TCoeff * ) alloca( width * height * sizeof( TCoeff ) );
 
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 9c54b0e6f0..dedfda1c8c 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -71,7 +71,6 @@
 #define JVET_O0429_CRS_LAMBDA_FIX                         1 // JVET-O0429: fix encoder lambda rounding used in CRS
 
 #define JVET_O0428_LMCS_CLEANUP                           1 // JVET-O0428: LMCS cleanups
-
 #define JVET_O0164_REMOVE_AMVP_SPATIAL_SCALING            1 // JVET-O0164/JVET-O0587: remove spatial AMVP candidate scaling
 
 #define JVET_O0162_IBC_MVP_FLAG                           1 // JVET-O0162/O0331/O0480/O0574: IBC mvp flag conditioned on MaxNumMergeCand>1
@@ -85,6 +84,8 @@
 #define JVET_O0364_PDPC_DC                                1 // JVET-O0364 Part 4: align PDPC process for DC with the one for Planar
 #define JVET_O0364_PDPC_ANGULAR                           1 // JVET-O0364 Part 5: simplify PDPC process for angular modes
 
+#define JVET_O0049_LFNST_ZERO_PRIM_COEFFS                 1 // JVET-O0049: CE6-2.1a, LFNST zeroes out the primary-only coefficient positions
+
 #define FIX_DB_MAX_TRANSFORM_SIZE                         1
 
 #define MRG_SHARELIST_SHARSIZE                            32
diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp
index ee6c152191..f011cb94ff 100644
--- a/source/Lib/CommonLib/UnitTools.cpp
+++ b/source/Lib/CommonLib/UnitTools.cpp
@@ -261,6 +261,7 @@ uint32_t CU::getNumNonZeroCoeffNonTs( const CodingUnit& cu, const bool lumaFlag,
   return count;
 }
 
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
 uint32_t CU::getNumNonZeroCoeffNonTsCorner8x8( const CodingUnit& cu, const bool lumaFlag, const bool chromaFlag )
 {
   uint32_t count = 0;
@@ -271,6 +272,7 @@ uint32_t CU::getNumNonZeroCoeffNonTsCorner8x8( const CodingUnit& cu, const bool
 
   return count;
 }
+#endif
 
 bool CU::divideTuInRows( const CodingUnit &cu )
 {
@@ -4533,7 +4535,7 @@ uint32_t TU::getNumNonZeroCoeffsNonTS( const TransformUnit& tu, const bool bLuma
   }
   return count;
 }
-
+#if !JVET_O0049_LFNST_ZERO_PRIM_COEFFS
 uint32_t TU::getNumNonZeroCoeffsNonTSCorner8x8( const TransformUnit& tu, const bool lumaFlag, const bool chromaFlag )
 {
   const uint32_t lumaWidth       = tu.blocks[ 0 ].width,  chromaWidth  = tu.blocks[ 1 ].width;
@@ -4575,6 +4577,7 @@ uint32_t TU::getNumNonZeroCoeffsNonTSCorner8x8( const TransformUnit& tu, const b
   }
   return count;
 }
+#endif
 
 bool TU::needsSqrt2Scale( const TransformUnit &tu, const ComponentID &compID )
 {
diff --git a/source/Lib/DecoderLib/CABACReader.cpp b/source/Lib/DecoderLib/CABACReader.cpp
index 3f017485f4..c3f49d8841 100644
--- a/source/Lib/DecoderLib/CABACReader.cpp
+++ b/source/Lib/DecoderLib/CABACReader.cpp
@@ -1355,6 +1355,10 @@ void CABACReader::cu_residual( CodingUnit& cu, Partitioner &partitioner, CUCtx&
       return;
     }
   }
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_LUMA]   = false;
+  cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_CHROMA] = false;
+#endif
 
   ChromaCbfs chromaCbfs;
   if( cu.ispMode && isLuma( partitioner.chType ) )
@@ -1366,8 +1370,11 @@ void CABACReader::cu_residual( CodingUnit& cu, Partitioner &partitioner, CUCtx&
   {
     transform_tree( *cu.cs, partitioner, cuCtx, chromaCbfs );
   }
-
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  residual_lfnst_mode( cu, cuCtx );
+#else
   residual_lfnst_mode( cu );
+#endif
 }
 
 void CABACReader::rqt_root_cbf( CodingUnit& cu )
@@ -2336,7 +2343,11 @@ void CABACReader::transform_unit( TransformUnit& tu, CUCtx& cuCtx, ChromaCbfs& c
     }
     if( cbfLuma )
     {
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+      residual_coding( tu, COMPONENT_Y, cuCtx );
+#else
       residual_coding( tu, COMPONENT_Y );
+#endif
     }
     if( !lumaOnly )
     {
@@ -2348,7 +2359,11 @@ void CABACReader::transform_unit( TransformUnit& tu, CUCtx& cuCtx, ChromaCbfs& c
         }
         if( tu.cbf[ compID ] )
         {
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+          residual_coding( tu, compID, cuCtx );
+#else
           residual_coding( tu, compID );
+#endif
         }
       }
     }
@@ -2415,7 +2430,11 @@ void CABACReader::joint_cb_cr( TransformUnit& tu )
   tu.jointCbCr = m_BinDecoder.decodeBin( Ctx::JointCbCrFlag( 0 ) );
 }
 
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+void CABACReader::residual_coding( TransformUnit& tu, ComponentID compID, CUCtx& cuCtx )
+#else
 void CABACReader::residual_coding( TransformUnit& tu, ComponentID compID )
+#endif
 {
   const CodingUnit& cu = *tu.cu;
   DTRACE( g_trace_ctx, D_SYNTAX, "residual_coding() etype=%d pos=(%d,%d) size=%dx%d predMode=%d\n", tu.blocks[compID].compID, tu.blocks[compID].x, tu.blocks[compID].y, tu.blocks[compID].width, tu.blocks[compID].height, cu.predMode );
@@ -2458,7 +2477,13 @@ void CABACReader::residual_coding( TransformUnit& tu, ComponentID compID )
 
   // parse last coeff position
   cctx.setScanPosLast( last_sig_coeff( cctx, tu, compID ) );
-
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  if( tu.mtsIdx != MTS_SKIP && tu.blocks[ compID ].height >= 4 && tu.blocks[ compID ].width >= 4 )
+  {
+    const int maxLfnstPos = ((tu.blocks[compID].height == 4 && tu.blocks[compID].width == 4) || (tu.blocks[compID].height == 8 && tu.blocks[compID].width == 8)) ? 7 : 15;
+    cuCtx.violatesLfnstConstrained[ toChannelType(compID) ] |= cctx.scanPosLast() > maxLfnstPos;
+  }
+#endif
   // parse subblocks
   const int stateTransTab = ( tu.cs->slice->getDepQuantEnabledFlag() ? 32040 : 0 );
   int       state         = 0;
@@ -2594,7 +2619,11 @@ void CABACReader::explicit_rdpcm_mode( TransformUnit& tu, ComponentID compID )
   }
 }
 
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+void CABACReader::residual_lfnst_mode( CodingUnit& cu,  CUCtx& cuCtx  )
+#else
 void CABACReader::residual_lfnst_mode( CodingUnit& cu )
+#endif
 {
   if( cu.ispMode != NOT_INTRA_SUBPARTITIONS || cu.mipFlag == true ||
     ( CS::isDualITree( *cu.cs ) && cu.chType == CHANNEL_TYPE_CHROMA && std::min( cu.blocks[ 1 ].width, cu.blocks[ 1 ].height ) < 4 ) )
@@ -2609,7 +2638,11 @@ void CABACReader::residual_lfnst_mode( CodingUnit& cu )
     const bool lumaFlag              = CS::isDualITree( *cu.cs ) ? (   isLuma( cu.chType ) ? true : false ) : true;
     const bool chromaFlag            = CS::isDualITree( *cu.cs ) ? ( isChroma( cu.chType ) ? true : false ) : true;
     bool nonZeroCoeffNonTs;
-    bool nonZeroCoeffNonTsCorner8x8  = CU::getNumNonZeroCoeffNonTsCorner8x8( cu, lumaFlag, chromaFlag ) > 0;
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+    bool nonZeroCoeffNonTsCorner8x8 = ( lumaFlag && cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_LUMA] ) || (chromaFlag && cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_CHROMA] );
+#else
+    bool nonZeroCoeffNonTsCorner8x8 = CU::getNumNonZeroCoeffNonTsCorner8x8( cu, lumaFlag, chromaFlag ) > 0;
+#endif
     const int  nonZeroCoeffThr       = CS::isDualITree( *cu.cs ) ? ( isLuma( cu.chType ) ? LFNST_SIG_NZ_LUMA : LFNST_SIG_NZ_CHROMA ) : LFNST_SIG_NZ_LUMA + LFNST_SIG_NZ_CHROMA;
     nonZeroCoeffNonTs = CU::getNumNonZeroCoeffNonTs( cu, lumaFlag, chromaFlag ) > nonZeroCoeffThr;
 
diff --git a/source/Lib/DecoderLib/CABACReader.h b/source/Lib/DecoderLib/CABACReader.h
index 72868ea7f0..68fa2c3bf2 100644
--- a/source/Lib/DecoderLib/CABACReader.h
+++ b/source/Lib/DecoderLib/CABACReader.h
@@ -129,9 +129,17 @@ public:
   void        cu_chroma_qp_offset       ( CodingUnit&                   cu );
 
   // residual coding (clause 7.3.8.11)
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  void        residual_coding           ( TransformUnit&                tu,     ComponentID     compID, CUCtx& cuCtx ); // may set cuCtx.violatesLfnstConstrained from the parsed last scan position
+#else
   void        residual_coding           ( TransformUnit&                tu,     ComponentID     compID );
+#endif
   void        mts_coding                ( TransformUnit&                tu,     ComponentID     compID );
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  void        residual_lfnst_mode       ( CodingUnit&                   cu,     CUCtx&          cuCtx  ); // reads cuCtx.violatesLfnstConstrained for the luma/chroma channel types in use
+#else
   void        residual_lfnst_mode       ( CodingUnit&                   cu );
+#endif
   void        isp_mode                  ( CodingUnit&                   cu );
   void        explicit_rdpcm_mode       ( TransformUnit&                tu,     ComponentID     compID );
   int         last_sig_coeff            ( CoeffCodingContext&           cctx,   TransformUnit& tu, ComponentID   compID );
diff --git a/source/Lib/EncoderLib/CABACWriter.cpp b/source/Lib/EncoderLib/CABACWriter.cpp
index 16bf915584..101cdb7f3e 100644
--- a/source/Lib/EncoderLib/CABACWriter.cpp
+++ b/source/Lib/EncoderLib/CABACWriter.cpp
@@ -1249,6 +1249,10 @@ void CABACWriter::cu_residual( const CodingUnit& cu, Partitioner& partitioner, C
     }
   }
 
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_LUMA]   = false;
+  cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_CHROMA] = false;
+#endif
 
   ChromaCbfs chromaCbfs;
   if( cu.ispMode && isLuma( partitioner.chType ) )
@@ -2215,7 +2219,11 @@ void CABACWriter::transform_unit( const TransformUnit& tu, CUCtx& cuCtx, ChromaC
     }
     if( cbfLuma )
     {
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+      residual_coding( tu, COMPONENT_Y, &cuCtx );
+#else
       residual_coding( tu, COMPONENT_Y );
+#endif
     }
     if( !lumaOnly )
     {
@@ -2227,7 +2235,11 @@ void CABACWriter::transform_unit( const TransformUnit& tu, CUCtx& cuCtx, ChromaC
         }
         if( cbf[ compID ] )
         {
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+          residual_coding( tu, compID, &cuCtx );
+#else
           residual_coding( tu, compID );
+#endif
         }
       }
     }
@@ -2295,7 +2307,11 @@ void CABACWriter::joint_cb_cr( const TransformUnit& tu )
   m_BinEncoder.encodeBin( tu.jointCbCr ? 1 : 0, Ctx::JointCbCrFlag( 0 ) );
 }
 
-void CABACWriter::residual_coding( const TransformUnit& tu, ComponentID compID )
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+void CABACWriter::residual_coding( const TransformUnit& tu, ComponentID compID, CUCtx* cuCtx )
+#else
+void CABACWriter::residual_coding( const TransformUnit& tu, ComponentID compID)
+#endif
 {
   const CodingUnit& cu = *tu.cu;
   DTRACE( g_trace_ctx, D_SYNTAX, "residual_coding() etype=%d pos=(%d,%d) size=%dx%d predMode=%d\n", tu.blocks[compID].compID, tu.blocks[compID].x, tu.blocks[compID].y, tu.blocks[compID].width, tu.blocks[compID].height, cu.predMode );
@@ -2351,6 +2367,13 @@ void CABACWriter::residual_coding( const TransformUnit& tu, ComponentID compID )
   CHECK( scanPosLast < 0, "Coefficient coding called for empty TU" );
   cctx.setScanPosLast(scanPosLast);
 
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  if( cuCtx && tu.mtsIdx != MTS_SKIP && tu.blocks[ compID ].height >= 4 && tu.blocks[ compID ].width >= 4 )
+  {
+    const int maxLfnstPos = ((tu.blocks[compID].height == 4 && tu.blocks[compID].width == 4) || (tu.blocks[compID].height == 8 && tu.blocks[compID].width == 8)) ? 7 : 15;
+    cuCtx->violatesLfnstConstrained[ toChannelType(compID) ] |= cctx.scanPosLast() > maxLfnstPos;
+  }
+#endif
   // code last coeff position
   last_sig_coeff( cctx, tu, compID );
 
@@ -2487,7 +2510,11 @@ void CABACWriter::residual_lfnst_mode( const CodingUnit& cu, CUCtx& cuCtx )
     const bool lumaFlag                   = CS::isDualITree( *cu.cs ) ? (   isLuma( cu.chType ) ? true : false ) : true;
     const bool chromaFlag                 = CS::isDualITree( *cu.cs ) ? ( isChroma( cu.chType ) ? true : false ) : true;
           bool nonZeroCoeffNonTs;
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+          bool nonZeroCoeffNonTsCorner8x8 = ( lumaFlag && cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_LUMA] ) || (chromaFlag && cuCtx.violatesLfnstConstrained[CHANNEL_TYPE_CHROMA] );
+#else
           bool nonZeroCoeffNonTsCorner8x8 = CU::getNumNonZeroCoeffNonTsCorner8x8( cu, lumaFlag, chromaFlag ) > 0;
+#endif
     const int  nonZeroCoeffThr            = CS::isDualITree( *cu.cs ) ? ( isLuma( cu.chType ) ? LFNST_SIG_NZ_LUMA : LFNST_SIG_NZ_CHROMA ) : LFNST_SIG_NZ_LUMA + LFNST_SIG_NZ_CHROMA;
     cuCtx.numNonZeroCoeffNonTs            = CU::getNumNonZeroCoeffNonTs( cu, lumaFlag, chromaFlag );
     nonZeroCoeffNonTs                     = cuCtx.numNonZeroCoeffNonTs > nonZeroCoeffThr;
diff --git a/source/Lib/EncoderLib/CABACWriter.h b/source/Lib/EncoderLib/CABACWriter.h
index f67bb0e65c..f9ef204559 100644
--- a/source/Lib/EncoderLib/CABACWriter.h
+++ b/source/Lib/EncoderLib/CABACWriter.h
@@ -141,7 +141,11 @@ public:
   void        cu_chroma_qp_offset       ( const CodingUnit&             cu );
 
   // residual coding (clause 7.3.8.11)
+#if JVET_O0049_LFNST_ZERO_PRIM_COEFFS
+  void        residual_coding           ( const TransformUnit&          tu,       ComponentID       compID, CUCtx* cuCtx = nullptr ); // cuCtx may be nullptr; the LFNST violation flags are then left untouched
+#else
   void        residual_coding           ( const TransformUnit&          tu,       ComponentID       compID );
+#endif
   void        mts_coding                ( const TransformUnit&          tu,       ComponentID       compID );
   void        residual_lfnst_mode       ( const CodingUnit&             cu,       CUCtx&            cuCtx );
   void        isp_mode                  ( const CodingUnit&             cu );
-- 
GitLab