From 8ac636d5c7addeeae86756e93e594d651e095cc9 Mon Sep 17 00:00:00 2001
From: "moonmo.koo" <moonmo.koo@lge.com>
Date: Fri, 1 Feb 2019 21:42:54 +0900
Subject: [PATCH] Migration of JVET-M0297 Test 2

---
 source/Lib/CommonLib/ContextModelling.h |   4 +
 source/Lib/CommonLib/DepQuant.cpp       | 141 ++++++++++++++++++++++++
 source/Lib/CommonLib/TrQuant.cpp        |  14 +++
 source/Lib/CommonLib/TypeDef.h          |   2 +
 source/Lib/DecoderLib/CABACReader.cpp   |  62 +++++++++++
 source/Lib/DecoderLib/CABACReader.h     |   5 +
 source/Lib/EncoderLib/CABACWriter.cpp   |  54 +++++++++
 source/Lib/EncoderLib/CABACWriter.h     |   7 +-
 8 files changed, 288 insertions(+), 1 deletion(-)

diff --git a/source/Lib/CommonLib/ContextModelling.h b/source/Lib/CommonLib/ContextModelling.h
index 4f2c0d90..930177fc 100644
--- a/source/Lib/CommonLib/ContextModelling.h
+++ b/source/Lib/CommonLib/ContextModelling.h
@@ -70,6 +70,10 @@ public:
   int             cgPosX          ()                        const { return m_subSetPosX; }
   unsigned        width           ()                        const { return m_width; }
   unsigned        height          ()                        const { return m_height; }
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  unsigned        log2CGWidth     ()                        const { return m_log2CGWidth; }   // value of m_log2CGWidth; used as a right-shift amount when a position is compared against cgPosX()
+  unsigned        log2CGHeight    ()                        const { return m_log2CGHeight; }  // value of m_log2CGHeight; used as a right-shift amount when a position is compared against cgPosY()
+#endif
   unsigned        log2CGSize      ()                        const { return m_log2CGSize; }
   unsigned        log2BlockWidth  ()                        const { return m_log2BlockWidth; }
   unsigned        log2BlockHeight ()                        const { return m_log2BlockHeight; }
diff --git a/source/Lib/CommonLib/DepQuant.cpp b/source/Lib/CommonLib/DepQuant.cpp
index 5373fe7c..e9fe2439 100644
--- a/source/Lib/CommonLib/DepQuant.cpp
+++ b/source/Lib/CommonLib/DepQuant.cpp
@@ -87,6 +87,10 @@ namespace DQIntern
     NbInfoSbb     nextNbInfoSbb;
     int           nextSbbRight;
     int           nextSbbBelow;
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    int           posX;   // x position belonging to the scan index (filled from m_scanId2PosX)
+    int           posY;   // y position belonging to the scan index (filled from m_scanId2PosY)
+#endif
   };
 
   class Rom;
@@ -382,6 +386,10 @@ namespace DQIntern
       scanInfo.spt      = SCAN_SOCSBB;
     else if( scanInfo.eosbb && scanIdx > 0 && scanIdx < m_numCoeff - m_sbbSize )
       scanInfo.spt      = SCAN_EOCSBB;
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    scanInfo.posX       = m_scanId2PosX[ scanIdx ];
+    scanInfo.posY       = m_scanId2PosY[ scanIdx ];
+#endif
     if( scanIdx )
     {
       const int nextScanIdx = scanIdx - 1;
@@ -898,12 +906,66 @@ namespace DQIntern
       m_goRiceZero    = 0;
     }
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    void checkRdCosts( const ScanPosType spt, const PQData& pqDataA, const PQData& pqDataB, Decision& decisionA, Decision& decisionB, bool bZeroFix) const
+#else
     void checkRdCosts( const ScanPosType spt, const PQData& pqDataA, const PQData& pqDataB, Decision& decisionA, Decision& decisionB) const
+#endif
     {
       const int32_t*  goRiceTab = g_goRiceBits[m_goRicePar];
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+      int64_t         rdCostA;
+      int64_t         rdCostB;
+      int64_t         rdCostZ;
+#else
       int64_t         rdCostA   = m_rdCost + pqDataA.deltaDist;
       int64_t         rdCostB   = m_rdCost + pqDataB.deltaDist;
       int64_t         rdCostZ   = m_rdCost;
+#endif
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+      if( bZeroFix )                                   // zero-forced position: only the zero level is costed here
+      {
+        rdCostZ = m_rdCost;
+#if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
+        if( m_remRegBins >= 4 )
+#else
+        if( m_remRegBins >= 3 )
+#endif
+        {
+          if( spt == SCAN_ISCSBB )
+          {
+            rdCostZ += m_sigFracBits.intBits[0];
+          }
+          else if( spt == SCAN_SOCSBB )
+          {
+            rdCostZ += m_sbbFracBits.intBits[1] + m_sigFracBits.intBits[0];
+          }
+          else if( m_numSigSbb )
+          {
+            rdCostZ += m_sigFracBits.intBits[0];
+          }
+          else
+          {
+            rdCostZ = decisionA.rdCost;                // equal cost: the strict '<' test below then leaves decisionA untouched
+          }
+        }
+        else
+        {
+          rdCostZ += goRiceTab[m_goRiceZero];
+        }
+        if( rdCostZ < decisionA.rdCost )               // decisionB, pqDataA and pqDataB are not used in this branch
+        {
+          decisionA.rdCost = rdCostZ;
+          decisionA.absLevel = 0;
+          decisionA.prevId = m_stateId;
+        }
+      }
+      else
+      {
+        rdCostA = m_rdCost + pqDataA.deltaDist;
+        rdCostB = m_rdCost + pqDataB.deltaDist;
+        rdCostZ = m_rdCost;
+#endif
 #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
       if( m_remRegBins >= 4 )
 #else
@@ -971,6 +1033,9 @@ namespace DQIntern
         decisionB.absLevel = pqDataB.absLevel;
         decisionB.prevId   = m_stateId;
       }
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+      }
+#endif
     }
 
     inline void checkRdCostStart(int32_t lastOffset, const PQData &pqData, Decision &decision) const
@@ -1004,6 +1069,16 @@ namespace DQIntern
       }
     }
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    inline void checkRdCostSkipSbbZeroFix(Decision &decision) const
+    {
+      // Unconditional overwrite: the incoming decision.rdCost is not compared against the new cost.
+      decision.rdCost   = m_rdCost + m_sbbFracBits.intBits[0];
+      decision.absLevel = 0;
+      decision.prevId   = 4 + m_stateId;
+    }
+#endif
+
   private:
     int64_t                   m_rdCost;
     uint16_t                  m_absLevelsAndCtxInit[24];  // 16x8bit for abs levels + 16x16bit for ctx init id
@@ -1314,8 +1389,13 @@ namespace DQIntern
     void    dequant ( const TransformUnit& tu,  CoeffBuf& recCoeff, const ComponentID compID, const QpParam& cQP )  const;
 
   private:
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    void    xDecideAndUpdate  ( const TCoeff absCoeff, const ScanInfo& scanInfo, bool bZeroFix );
+    void    xDecide           ( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions, bool bZeroFix );
+#else
     void    xDecideAndUpdate  ( const TCoeff absCoeff, const ScanInfo& scanInfo );
     void    xDecide           ( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions );
+#endif
 
   private:
     CommonCtx   m_commonCtx;
@@ -1353,34 +1433,73 @@ namespace DQIntern
 #undef  DINIT
 
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  void DepQuant::xDecide( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions, bool bZeroFix)
+#else
   void DepQuant::xDecide( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions)
+#endif
   {
     ::memcpy( decisions, startDec, 8*sizeof(Decision) );
 
     PQData  pqData[4];
     m_quant.preQuantCoeff( absCoeff, pqData );
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2], bZeroFix);
+    m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0], bZeroFix);
+    m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3], bZeroFix);
+    m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1], bZeroFix);
+#else
     m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2]);
     m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0]);
     m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3]);
     m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1]);
+#endif
     if( spt==SCAN_EOCSBB )
     {
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+      if( bZeroFix )  // zero-forced position: use the skip variant that overwrites each decision without a cost comparison
+      {
+        m_skipStates[0].checkRdCostSkipSbbZeroFix( decisions[0] );
+        m_skipStates[1].checkRdCostSkipSbbZeroFix( decisions[1] );
+        m_skipStates[2].checkRdCostSkipSbbZeroFix( decisions[2] );
+        m_skipStates[3].checkRdCostSkipSbbZeroFix( decisions[3] );
+      }
+      else
+      {
+#endif
       m_skipStates[0].checkRdCostSkipSbb( decisions[0] );
       m_skipStates[1].checkRdCostSkipSbb( decisions[1] );
       m_skipStates[2].checkRdCostSkipSbb( decisions[2] );
       m_skipStates[3].checkRdCostSkipSbb( decisions[3] );
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+      }
+#endif
     }
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    if (!bZeroFix) {  // the two start-state candidates are evaluated only when the position is not zero-forced
+#endif
     m_startState.checkRdCostStart( lastOffset, pqData[0], decisions[0] );
     m_startState.checkRdCostStart( lastOffset, pqData[2], decisions[2] );
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    }
+#endif
   }
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo, bool bZeroFix )
+#else
   void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo )
+#endif
   {
     Decision* decisions = m_trellis[ scanInfo.scanIdx ];
 
     std::swap( m_prevStates, m_currStates );
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    xDecide( scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx), decisions, bZeroFix );
+#else
     xDecide( scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx), decisions);
+#endif
 
     if( scanInfo.scanIdx )
     {
@@ -1456,6 +1575,12 @@ namespace DQIntern
     ::memset( tu.getCoeffs( compID ).buf, 0x00, numCoeff*sizeof(TCoeff) );
     absSum          = 0;
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    const CompArea& area = tu.blocks[compID];
+    const uint32_t iWidth = area.width;
+    const uint32_t iHeight = area.height;
+#endif
+
     //===== find first test position =====
     int   firstTestPos = numCoeff - 1;
     const TCoeff thres = m_quant.getLastThreshold();
@@ -1480,12 +1605,28 @@ namespace DQIntern
     }
     m_startState.init();
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    int iEffWidth = iWidth, iEffHeight = iHeight;  // effective size: equal to the block size unless reduced below
+#if JVET_M0464_UNI_MTS
+    if( tu.mtsIdx > 1 && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#else
+    if( tu.cu->emtFlag && !tu.transformSkip[compID] && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#endif
+    {
+      iEffHeight = ( iHeight == 32 ) ? 16 : iHeight;  // a dimension of exactly 32 is reduced to 16; any other size is kept
+      iEffWidth = ( iWidth == 32 ) ? 16 : iWidth;
+    }
+#endif
 
     //===== populate trellis =====
     for( int scanIdx = firstTestPos; scanIdx >= 0; scanIdx-- )
     {
       const ScanInfo& scanInfo = tuPars.m_scanInfo[ scanIdx ];
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+      xDecideAndUpdate( abs( tCoeff[ scanInfo.rasterPos ] ), scanInfo, ( iEffWidth < iWidth || iEffHeight < iHeight ) && ( scanInfo.posX >= iEffWidth || scanInfo.posY >= iEffHeight ) );
+#else
       xDecideAndUpdate( abs( tCoeff[ scanInfo.rasterPos ] ), scanInfo );
+#endif
     }
 
     //===== find best path =====
diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index 6396fe5c..942c0edc 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -336,8 +336,10 @@ void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPel
   const int      shift_2nd              =  (g_aucLog2[height])            + TRANSFORM_MATRIX_SHIFT                          + COM16_C806_TRANS_PREC;
   const uint32_t transformWidthIndex    = g_aucLog2[width ] - 1;  // nLog2WidthMinus1, since transform start from 2-point
   const uint32_t transformHeightIndex   = g_aucLog2[height] - 1;  // nLog2HeightMinus1, since transform start from 2-point
+#if !JVET_M0297_32PT_MTS_ZERO_OUT
   const int      skipWidth              = width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight             = height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
+#endif
   
   CHECK( shift_1st < 0, "Negative shift" );
   CHECK( shift_2nd < 0, "Negative shift" );
@@ -346,6 +348,11 @@ void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPel
   int trTypeVer = DCT2;
   
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  const int      skipWidth  = ( width  == 32 && trTypeHor != DCT2 ) ? 16 : ( ( width  > JVET_C0024_ZERO_OUT_TH ) ? ( width  - JVET_C0024_ZERO_OUT_TH ) : 0 );
+  const int      skipHeight = ( height == 32 && trTypeVer != DCT2 ) ? 16 : ( ( height > JVET_C0024_ZERO_OUT_TH ) ? ( height - JVET_C0024_ZERO_OUT_TH ) : 0 );
+#endif
   
 #if RExt__DECODER_DEBUG_TOOL_STATISTICS
   if ( trTypeHor != DCT2 )
@@ -386,8 +393,10 @@ void TrQuant::xIT( const TransformUnit &tu, const ComponentID &compID, const CCo
   const int      shift_2nd              = ( TRANSFORM_MATRIX_SHIFT + maxLog2TrDynamicRange - 1 ) - bitDepth + COM16_C806_TRANS_PREC;
   const uint32_t transformWidthIndex    = g_aucLog2[width ] - 1;                                // nLog2WidthMinus1, since transform start from 2-point
   const uint32_t transformHeightIndex   = g_aucLog2[height] - 1;                                // nLog2HeightMinus1, since transform start from 2-point
+#if !JVET_M0297_32PT_MTS_ZERO_OUT
   const int      skipWidth              = width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight             = height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
+#endif
   
   CHECK( shift_1st < 0, "Negative shift" );
   CHECK( shift_2nd < 0, "Negative shift" );
@@ -396,6 +405,11 @@ void TrQuant::xIT( const TransformUnit &tu, const ComponentID &compID, const CCo
   int trTypeVer = DCT2;
   
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  const int      skipWidth  = ( width  == 32 && trTypeHor != DCT2 ) ? 16 : ( ( width  > JVET_C0024_ZERO_OUT_TH ) ? ( width  - JVET_C0024_ZERO_OUT_TH ) : 0 );
+  const int      skipHeight = ( height == 32 && trTypeVer != DCT2 ) ? 16 : ( ( height > JVET_C0024_ZERO_OUT_TH ) ? ( height - JVET_C0024_ZERO_OUT_TH ) : 0 );
+#endif
   
   TCoeff *tmp   = ( TCoeff * ) alloca( width * height * sizeof( TCoeff ) );
   TCoeff *block = ( TCoeff * ) alloca( width * height * sizeof( TCoeff ) );
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index cb1dc567..8b8796ce 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,6 +50,8 @@
 #include <assert.h>
 #include <cassert>
 
+#define JVET_M0297_32PT_MTS_ZERO_OUT                      1 // 32 point MTS based on skipping high frequency coefficients
+
 #define JVET_M0471_LONG_DEBLOCKING_FILTERS                1 
 #define JVET_M0470                                        1 // Fixed GR/TU+EG-k transition point, use limited prefix length for escape codes
 
diff --git a/source/Lib/DecoderLib/CABACReader.cpp b/source/Lib/DecoderLib/CABACReader.cpp
index bd9ff0f1..0f192634 100644
--- a/source/Lib/DecoderLib/CABACReader.cpp
+++ b/source/Lib/DecoderLib/CABACReader.cpp
@@ -2455,7 +2455,11 @@ void CABACReader::residual_coding( TransformUnit& tu, ComponentID compID )
 #endif
 
   // parse last coeff position
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  cctx.setScanPosLast( last_sig_coeff( cctx, tu, compID ) );
+#else
   cctx.setScanPosLast( last_sig_coeff( cctx ) );
+#endif
 
   // parse subblocks
   const int stateTransTab = ( tu.cs->slice->getDepQuantEnabledFlag() ? 32040 : 0 );
@@ -2469,7 +2473,11 @@ void CABACReader::residual_coding( TransformUnit& tu, ComponentID compID )
     for( int subSetId = ( cctx.scanPosLast() >> cctx.log2CGSize() ); subSetId >= 0; subSetId--)
     {
       cctx.initSubblock       ( subSetId );
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+      residual_coding_subblock( cctx, coeff, stateTransTab, state, tu, compID );
+#else
       residual_coding_subblock( cctx, coeff, stateTransTab, state );
+#endif
 #if !JVET_M0464_UNI_MTS
       if (useEmt)
       {
@@ -2650,11 +2658,44 @@ void CABACReader::explicit_rdpcm_mode( TransformUnit& tu, ComponentID compID )
 }
 
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+int CABACReader::last_sig_coeff( CoeffCodingContext& cctx, TransformUnit& tu, ComponentID compID )
+#else
 int CABACReader::last_sig_coeff( CoeffCodingContext& cctx )
+#endif
 {
   RExt__DECODER_DEBUG_BIT_STATISTICS_CREATE_SET_SIZE2( STATS__CABAC_BITS__LAST_SIG_X_Y, Size( cctx.width(), cctx.height() ), cctx.compID() );
 
   unsigned PosLastX = 0, PosLastY = 0;
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  unsigned uiMaxLastPosX = cctx.maxLastPosX();
+  unsigned uiMaxLastPosY = cctx.maxLastPosY();
+
+#if JVET_M0464_UNI_MTS
+  if( tu.mtsIdx > 1 && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#else
+  if( tu.cu->emtFlag && !tu.transformSkip[ compID ] && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#endif
+  {  // for a dimension of exactly 32, bound the prefix loop below by the group index of position 15
+    uiMaxLastPosX = ( tu.blocks[ compID ].width  == 32 ) ? g_uiGroupIdx[ 15 ] : uiMaxLastPosX;
+    uiMaxLastPosY = ( tu.blocks[ compID ].height == 32 ) ? g_uiGroupIdx[ 15 ] : uiMaxLastPosY;
+  }
+
+  for( ; PosLastX < uiMaxLastPosX; PosLastX++ )
+  {
+    if( !m_BinDecoder.decodeBin( cctx.lastXCtxId( PosLastX ) ) )
+    {
+      break;
+    }
+  }
+  for( ; PosLastY < uiMaxLastPosY; PosLastY++ )
+  {
+    if( !m_BinDecoder.decodeBin( cctx.lastYCtxId( PosLastY ) ) )
+    {
+      break;
+    }
+  }
+#else
   for( ; PosLastX < cctx.maxLastPosX(); PosLastX++ )
   {
     if( ! m_BinDecoder.decodeBin( cctx.lastXCtxId( PosLastX ) ) )
@@ -2669,6 +2710,7 @@ int CABACReader::last_sig_coeff( CoeffCodingContext& cctx )
       break;
     }
   }
+#endif
   if( PosLastX > 3 )
   {
     uint32_t uiTemp  = 0;
@@ -2715,7 +2757,11 @@ int CABACReader::last_sig_coeff( CoeffCodingContext& cctx )
 
 
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+void CABACReader::residual_coding_subblock( CoeffCodingContext& cctx, TCoeff* coeff, const int stateTransTable, int& state, TransformUnit& tu, ComponentID compID )
+#else
 void CABACReader::residual_coding_subblock( CoeffCodingContext& cctx, TCoeff* coeff, const int stateTransTable, int& state )
+#endif
 {
   // NOTE: All coefficients of the subblock must be set to zero before calling this function
 #if RExt__DECODER_DEBUG_BIT_STATISTICS
@@ -2738,7 +2784,23 @@ void CABACReader::residual_coding_subblock( CoeffCodingContext& cctx, TCoeff* co
   bool sigGroup = ( isLast || !minSubPos );
   if( !sigGroup )
   {
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+#if JVET_M0464_UNI_MTS
+    if( tu.mtsIdx > 1 && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#else
+    if( tu.cu->emtFlag && !tu.transformSkip[ compID ] && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#endif
+    {  // in a 32-long dimension, a CG at or beyond ( 16 >> log2 CG size ) gets sigGroup = 0 and no bin is decoded for it
+      sigGroup = ( ( tu.blocks[compID].height == 32 && cctx.cgPosY() >= ( 16 >> cctx.log2CGHeight() ) )
+                || ( tu.blocks[compID].width  == 32 && cctx.cgPosX() >= ( 16 >> cctx.log2CGWidth()  ) ) ) ? 0 : m_BinDecoder.decodeBin( cctx.sigGroupCtxId() );
+    }
+    else
+    {
+      sigGroup = m_BinDecoder.decodeBin(cctx.sigGroupCtxId());
+    }
+#else
     sigGroup = m_BinDecoder.decodeBin( cctx.sigGroupCtxId() );
+#endif
   }
   if( sigGroup )
   {
diff --git a/source/Lib/DecoderLib/CABACReader.h b/source/Lib/DecoderLib/CABACReader.h
index 29e463ca..c4812fc7 100644
--- a/source/Lib/DecoderLib/CABACReader.h
+++ b/source/Lib/DecoderLib/CABACReader.h
@@ -145,8 +145,13 @@ public:
   void        emt_cu_flag               ( CodingUnit&                   cu );
 #endif
   void        explicit_rdpcm_mode       ( TransformUnit&                tu,     ComponentID     compID );
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  int         last_sig_coeff            ( CoeffCodingContext&           cctx,   TransformUnit& tu, ComponentID   compID );
+  void        residual_coding_subblock  ( CoeffCodingContext&           cctx,   TCoeff*         coeff, const int stateTransTable, int& state, TransformUnit& tu, ComponentID compID );
+#else
   int         last_sig_coeff            ( CoeffCodingContext&           cctx );
   void        residual_coding_subblock  ( CoeffCodingContext&           cctx,   TCoeff*         coeff, const int stateTransTable, int& state );
+#endif
 
   // cross component prediction (clause 7.3.8.12)
   void        cross_comp_pred           ( TransformUnit&                tu,     ComponentID     compID );
diff --git a/source/Lib/EncoderLib/CABACWriter.cpp b/source/Lib/EncoderLib/CABACWriter.cpp
index 0c5ca538..a9436dc8 100644
--- a/source/Lib/EncoderLib/CABACWriter.cpp
+++ b/source/Lib/EncoderLib/CABACWriter.cpp
@@ -2298,7 +2298,11 @@ void CABACWriter::residual_coding( const TransformUnit& tu, ComponentID compID )
   cctx.setScanPosLast(scanPosLast);
 
   // code last coeff position
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  last_sig_coeff( cctx, tu, compID );
+#else
   last_sig_coeff( cctx );
+#endif
 
   // code subblocks
   const int stateTab  = ( tu.cs->slice->getDepQuantEnabledFlag() ? 32040 : 0 );
@@ -2311,7 +2315,11 @@ void CABACWriter::residual_coding( const TransformUnit& tu, ComponentID compID )
   for( int subSetId = ( cctx.scanPosLast() >> cctx.log2CGSize() ); subSetId >= 0; subSetId--)
   {
     cctx.initSubblock       ( subSetId, sigGroupFlags[subSetId] );
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+    residual_coding_subblock( cctx, coeff, stateTab, state, tu, compID );
+#else
     residual_coding_subblock( cctx, coeff, stateTab, state );
+#endif
 
 #if !JVET_M0464_UNI_MTS
     if (useEmt)
@@ -2476,7 +2484,11 @@ void CABACWriter::explicit_rdpcm_mode( const TransformUnit& tu, ComponentID comp
 }
 
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+void CABACWriter::last_sig_coeff( CoeffCodingContext& cctx, const TransformUnit& tu, ComponentID compID )
+#else
 void CABACWriter::last_sig_coeff( CoeffCodingContext& cctx )
+#endif
 {
   unsigned blkPos = cctx.blockPos( cctx.scanPosLast() );
   unsigned posX, posY;
@@ -2497,11 +2509,30 @@ void CABACWriter::last_sig_coeff( CoeffCodingContext& cctx )
   unsigned GroupIdxX = g_uiGroupIdx[ posX ];
   unsigned GroupIdxY = g_uiGroupIdx[ posY ];
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  unsigned uiMaxLastPosX = cctx.maxLastPosX();
+  unsigned uiMaxLastPosY = cctx.maxLastPosY();
+
+#if JVET_M0464_UNI_MTS
+  if( tu.mtsIdx > 1 && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#else
+  if( tu.cu->emtFlag && !tu.transformSkip[ compID ] && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#endif
+  {  // for a dimension of exactly 32, the terminating 0 bin is written only while the group index is below that of position 15
+    uiMaxLastPosX = ( tu.blocks[compID].width  == 32 ) ? g_uiGroupIdx[ 15 ] : uiMaxLastPosX;
+    uiMaxLastPosY = ( tu.blocks[compID].height == 32 ) ? g_uiGroupIdx[ 15 ] : uiMaxLastPosY;
+  }
+#endif
+
   for( CtxLast = 0; CtxLast < GroupIdxX; CtxLast++ )
   {
     m_BinEncoder.encodeBin( 1, cctx.lastXCtxId( CtxLast ) );
   }
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  if( GroupIdxX < uiMaxLastPosX )
+#else
   if( GroupIdxX < cctx.maxLastPosX() )
+#endif
   {
     m_BinEncoder.encodeBin( 0, cctx.lastXCtxId( CtxLast ) );
   }
@@ -2509,7 +2540,11 @@ void CABACWriter::last_sig_coeff( CoeffCodingContext& cctx )
   {
     m_BinEncoder.encodeBin( 1, cctx.lastYCtxId( CtxLast ) );
   }
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  if( GroupIdxY < uiMaxLastPosY )
+#else
   if( GroupIdxY < cctx.maxLastPosY() )
+#endif
   {
     m_BinEncoder.encodeBin( 0, cctx.lastYCtxId( CtxLast ) );
   }
@@ -2533,7 +2568,11 @@ void CABACWriter::last_sig_coeff( CoeffCodingContext& cctx )
 
 
 
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+void CABACWriter::residual_coding_subblock( CoeffCodingContext& cctx, const TCoeff* coeff, const int stateTransTable, int& state, const TransformUnit& tu, ComponentID compID )
+#else
 void CABACWriter::residual_coding_subblock( CoeffCodingContext& cctx, const TCoeff* coeff, const int stateTransTable, int& state )
+#endif
 {
   //===== init =====
   const int   minSubPos   = cctx.minSubPos();
@@ -2544,6 +2583,21 @@ void CABACWriter::residual_coding_subblock( CoeffCodingContext& cctx, const TCoe
   //===== encode significant_coeffgroup_flag =====
   if( !isLast && cctx.isNotFirst() )
   {
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+#if JVET_M0464_UNI_MTS
+    if( tu.mtsIdx > 1 && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#else
+    if( tu.cu->emtFlag && !tu.transformSkip[ compID ] && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
+#endif
+    {
+      if( ( tu.blocks[compID].height == 32 && cctx.cgPosY() >= ( 16 >> cctx.log2CGHeight() ) )
+       || ( tu.blocks[compID].width  == 32 && cctx.cgPosX() >= ( 16 >> cctx.log2CGWidth()  ) ) )
+      {
+        return;  // leaves before the group flag (or anything after it) is written for this subblock
+      }
+    }
+#endif
+
     if( cctx.isSigGroup() )
     {
       m_BinEncoder.encodeBin( 1, cctx.sigGroupCtxId() );
diff --git a/source/Lib/EncoderLib/CABACWriter.h b/source/Lib/EncoderLib/CABACWriter.h
index 51fbd8dd..19d2827f 100644
--- a/source/Lib/EncoderLib/CABACWriter.h
+++ b/source/Lib/EncoderLib/CABACWriter.h
@@ -159,8 +159,13 @@ public:
   void        emt_cu_flag               ( const CodingUnit&             cu );
 #endif
   void        explicit_rdpcm_mode       ( const TransformUnit&          tu,       ComponentID       compID );
+#if JVET_M0297_32PT_MTS_ZERO_OUT
+  void        last_sig_coeff            ( CoeffCodingContext&           cctx,     const TransformUnit& tu, ComponentID       compID );
+  void        residual_coding_subblock  ( CoeffCodingContext&           cctx,     const TCoeff*     coeff, const int stateTransTable, int& state, const TransformUnit& tu, ComponentID compID);
+#else
   void        last_sig_coeff            ( CoeffCodingContext&           cctx );
-  void        residual_coding_subblock  ( CoeffCodingContext&           cctx,     const TCoeff*     coeff, const int stateTransTable, int& state   );
+  void        residual_coding_subblock  ( CoeffCodingContext&           cctx,     const TCoeff*     coeff, const int stateTransTable, int& state );
+#endif
 
   // cross component prediction (clause 7.3.8.12)
   void        cross_comp_pred           ( const TransformUnit&          tu,       ComponentID       compID );
-- 
GitLab