From 59cac6da827fd2f78cd5b30d877a4ab8e94480a4 Mon Sep 17 00:00:00 2001
From: Mohammed Golam Sarwer <m.sarwer@alibaba-inc.com>
Date: Fri, 25 Oct 2019 21:54:11 +0200
Subject: [PATCH] JVET-P0983/P0391: remove sps maximum sbt size flag: add
 configurable parameter to use 64 size SBT in encoder RDO

---
 doc/software-manual.tex             | 6 ++++++
 source/App/EncoderApp/EncApp.cpp    | 3 +++
 source/App/EncoderApp/EncAppCfg.cpp | 3 +++
 source/App/EncoderApp/EncAppCfg.h   | 4 +++-
 source/Lib/CommonLib/Slice.cpp      | 2 ++
 source/Lib/CommonLib/Slice.h        | 4 ++++
 source/Lib/CommonLib/TypeDef.h      | 1 +
 source/Lib/CommonLib/Unit.cpp       | 4 ++++
 source/Lib/DecoderLib/VLCReader.cpp | 2 ++
 source/Lib/EncoderLib/EncCfg.h      | 9 +++++++++
 source/Lib/EncoderLib/EncCu.cpp     | 7 +++++++
 source/Lib/EncoderLib/EncLib.cpp    | 2 ++
 source/Lib/EncoderLib/VLCWriter.cpp | 2 ++
 13 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/doc/software-manual.tex b/doc/software-manual.tex
index db557a0b8..a1b8bafe5 100644
--- a/doc/software-manual.tex
+++ b/doc/software-manual.tex
@@ -1342,6 +1342,12 @@ candidate is not evaluated if the merge skip mode was the best merge
 mode for one of the previous candidates.
 \\
 
+\Option{SBT64RDO} &
+%\ShortOption{\None} &
+\Default{true} &
+Enables or disables the SBT check in the encoder RDO for coding units whose width or height is larger than 32.  When disabled, SBT is not tested in the RDO for such coding units.
+\\
+
 \Option{RDpenalty} &
 %\ShortOption{\None} &
 \Default{0} &
diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp
index 9bef109b9..a7ebf95d8 100644
--- a/source/App/EncoderApp/EncApp.cpp
+++ b/source/App/EncoderApp/EncApp.cpp
@@ -285,6 +285,9 @@ void EncApp::xInitLibCfg()
 #endif
   m_cEncLib.setImplicitMTS                                       ( m_MTSImplicit );
   m_cEncLib.setUseSBT                                            ( m_SBT );
+#if JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+  m_cEncLib.setUse64SBTRDOCheck                                  (m_SBT64RDOCheck);
+#endif
   m_cEncLib.setUseCompositeRef                                   ( m_compositeRefEnabled );
   m_cEncLib.setUseSMVD                                           ( m_SMVD );
   m_cEncLib.setUseGBi                                            ( m_GBi );
diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp
index 0b07ec200..ba71af2ac 100644
--- a/source/App/EncoderApp/EncAppCfg.cpp
+++ b/source/App/EncoderApp/EncAppCfg.cpp
@@ -943,6 +943,9 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] )
   ("MTSInterMaxCand",                                 m_MTSInterMaxCand,                                    4, "Number of additional candidates to test in encoder search for MTS in inter slices\n")
   ("MTSImplicit",                                     m_MTSImplicit,                                        0, "Enable implicit MTS (when explicit MTS is off)\n")
   ( "SBT",                                            m_SBT,                                            false, "Enable Sub-Block Transform for inter blocks\n" )
+#if JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+  ("SBT64RDO",                                        m_SBT64RDOCheck,             (m_iSourceWidth >= 1920 ? true : false ), "Enable more than 32 SBT in encoder RDO check \n")
+#endif
   ( "ISP",                                            m_ISP,                                            false, "Enable Intra Sub-Partitions\n" )
   ("SMVD",                                            m_SMVD,                                           false, "Enable Symmetric MVD\n")
   ("CompositeLTReference",                            m_compositeRefEnabled,                            false, "Enable Composite Long Term Reference Frame")
diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h
index 1269ebd1f..267c895f2 100644
--- a/source/App/EncoderApp/EncAppCfg.h
+++ b/source/App/EncoderApp/EncAppCfg.h
@@ -280,7 +280,9 @@ protected:
   int       m_MTSInterMaxCand;                                ///< XZ: Number of additional candidates to test
   int       m_MTSImplicit;
   bool      m_SBT;                                            ///< Sub-Block Transform for inter blocks
-
+#if JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+  bool      m_SBT64RDOCheck;                                  ///< Check SBT in encoder RDO for blocks with width or height larger than 32
+#endif
   bool      m_SMVD;
   bool      m_compositeRefEnabled;
   bool      m_GBi;
diff --git a/source/Lib/CommonLib/Slice.cpp b/source/Lib/CommonLib/Slice.cpp
index 2890f7848..ffd2e9127 100644
--- a/source/Lib/CommonLib/Slice.cpp
+++ b/source/Lib/CommonLib/Slice.cpp
@@ -1602,7 +1602,9 @@ SPS::SPS()
 , m_DMVR                      ( false )
 , m_MMVD                      ( false )
 , m_SBT                       ( false )
+#if !JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
 , m_MaxSbtSize                ( 32 )
+#endif
 , m_ISP                       ( false )
 , m_chromaFormatIdc           (CHROMA_420)
 , m_uiMaxTLayers              (  1)
diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h
index 646822421..0e6ebc9e3 100644
--- a/source/Lib/CommonLib/Slice.h
+++ b/source/Lib/CommonLib/Slice.h
@@ -775,7 +775,9 @@ private:
   bool              m_DMVR;
   bool              m_MMVD;
   bool              m_SBT;
+#if !JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
   uint8_t           m_MaxSbtSize;
+#endif
   bool              m_ISP;
   ChromaFormat      m_chromaFormatIdc;
 
@@ -1069,8 +1071,10 @@ public:
   bool                    getUseSBT() const                                                               { return m_SBT; }
   void                    setUseISP( bool b )                                                             { m_ISP = b; }
   bool                    getUseISP() const                                                               { return m_ISP; }
+#if !JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
   void                    setMaxSbtSize( uint8_t val )                                                    { m_MaxSbtSize = val; }
   uint8_t                 getMaxSbtSize() const                                                           { return m_MaxSbtSize; }
+#endif
 
   void      setAMVREnabledFlag    ( bool b )                                        { m_AMVREnabledFlag = b; }
   bool      getAMVREnabledFlag    ()                                      const     { return m_AMVREnabledFlag; }
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index da5f7bfa4..cb17fd390 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -49,6 +49,7 @@
 #include <cstring>
 #include <assert.h>
 #include <cassert>
+#define JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG           1 // JVET-P0983/JVET-P0391: Remove sps_sbt_max_size_64_flag
 
 #define JVET_P0530_TPM_WEIGHT_ALIGN                       1 // JVET-P0530: align chroma weights with luma weights for TPM blending
 
diff --git a/source/Lib/CommonLib/Unit.cpp b/source/Lib/CommonLib/Unit.cpp
index a3742c476..05d0035b9 100644
--- a/source/Lib/CommonLib/Unit.cpp
+++ b/source/Lib/CommonLib/Unit.cpp
@@ -459,7 +459,11 @@ const uint8_t CodingUnit::checkAllowedSbt() const
   memset( allow_type, false, NUMBER_SBT_IDX * sizeof( bool ) );
 
   //parameter
+#if JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+  int maxSbtCUSize = cs->sps->getMaxTbSize();
+#else
   int maxSbtCUSize = cs->sps->getMaxSbtSize();
+#endif
   int minSbtCUSize = 1 << ( MIN_CU_LOG2 + 1 );
 
   //check on size
diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp
index d11de7ca1..57a95bd07 100644
--- a/source/Lib/DecoderLib/VLCReader.cpp
+++ b/source/Lib/DecoderLib/VLCReader.cpp
@@ -1469,10 +1469,12 @@ void HLSyntaxReader::parseSPS(SPS* pcSPS)
   READ_FLAG( uiCode,    "sps_mip_flag");                            pcSPS->setUseMIP                 ( uiCode != 0 );
   // KJS: not in draft yet
   READ_FLAG(uiCode, "sbt_enable_flag");                             pcSPS->setUseSBT(uiCode != 0);
+#if !JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
   if( pcSPS->getUseSBT() )
   {
     READ_FLAG(uiCode, "max_sbt_size_64_flag");                      pcSPS->setMaxSbtSize(std::min((int)(1 << pcSPS->getLog2MaxTbSize()), uiCode != 0 ? 64 : 32));
   }
+#endif
   // KJS: not in draft yet
   READ_FLAG(uiCode, "sps_reshaper_enable_flag");                   pcSPS->setUseReshaper(uiCode == 1);
   READ_FLAG(uiCode, "isp_enable_flag");                            pcSPS->setUseISP(uiCode != 0);
diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h
index ce58afcab..7d899ae5b 100644
--- a/source/Lib/EncoderLib/EncCfg.h
+++ b/source/Lib/EncoderLib/EncCfg.h
@@ -285,6 +285,10 @@ protected:
 #endif
   int       m_ImplicitMTS;
   bool      m_SBT;                                ///< Sub-Block Transform for inter blocks
+#if JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+  bool     m_SBT64RDOCheck; ///< Check SBT in encoder RDO for blocks with width or height larger than 32
+#endif
+
   bool      m_LFNST;
   bool      m_useFastLFNST;
   int       m_SubPuMvpMode;
@@ -891,6 +895,11 @@ public:
   void      setUseSBT                       ( bool b )       { m_SBT = b; }
   bool      getUseSBT                       ()         const { return m_SBT; }
 
+#if JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+  void      setUse64SBTRDOCheck(bool b)                     { m_SBT64RDOCheck = b; }
+  bool      getUse64SBTRDOCheck             ()        const { return m_SBT64RDOCheck; }
+#endif
+
   void      setUseCompositeRef              (bool b)         { m_compositeRefEnabled = b; }
   bool      getUseCompositeRef              ()         const { return m_compositeRefEnabled; }
   void      setUseSMVD                      ( bool b )       { m_SMVD = b; }
diff --git a/source/Lib/EncoderLib/EncCu.cpp b/source/Lib/EncoderLib/EncCu.cpp
index 46a8db6ff..39f28c445 100644
--- a/source/Lib/EncoderLib/EncCu.cpp
+++ b/source/Lib/EncoderLib/EncCu.cpp
@@ -665,7 +665,11 @@ void EncCu::xCompressCU( CodingStructure*& tempCS, CodingStructure*& bestCS, Par
   if( partitioner.currQtDepth == 0 && partitioner.currMtDepth == 0 && !tempCS->slice->isIntra() && ( sps.getUseSBT() || sps.getUseInterMTS() ) )
   {
     auto slsSbt = dynamic_cast<SaveLoadEncInfoSbt*>( m_modeCtrl );
+#if  JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+    int maxSLSize = sps.getUseSBT() ? tempCS->slice->getSPS()->getMaxTbSize() : MTS_INTER_MAX_CU_SIZE;
+#else
     int maxSLSize = sps.getUseSBT() ? tempCS->slice->getSPS()->getMaxSbtSize() : MTS_INTER_MAX_CU_SIZE;
+#endif
     slsSbt->resetSaveloadSbt( maxSLSize );
 #if ENABLE_SPLIT_PARALLELISM
     CHECK( tempCS->picture->scheduler.getSplitJobId() != 0, "The SBT search reset need to happen in sequential region." );
@@ -4232,6 +4236,9 @@ void EncCu::xEncodeInterResidual(   CodingStructure *&tempCS
   }
   const bool mtsAllowed = tempCS->sps->getUseInterMTS() && CU::isInter( *cu ) && partitioner.currArea().lwidth() <= MTS_INTER_MAX_CU_SIZE && partitioner.currArea().lheight() <= MTS_INTER_MAX_CU_SIZE;
   uint8_t sbtAllowed = cu->checkAllowedSbt();
+#if JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
+  sbtAllowed = ((cu->lwidth() > 32 || cu->lheight() > 32) && !(m_pcEncCfg->getUse64SBTRDOCheck())) ? 0 : sbtAllowed;
+#endif
   uint8_t numRDOTried = 0;
   Distortion sbtOffDist = 0;
   bool    sbtOffRootCbf = 0;
diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp
index 27c9b6790..63718467c 100644
--- a/source/Lib/EncoderLib/EncLib.cpp
+++ b/source/Lib/EncoderLib/EncLib.cpp
@@ -996,10 +996,12 @@ void EncLib::xInitSPS(SPS &sps)
   sps.setUseIntraMTS           ( m_IntraMTS );
   sps.setUseInterMTS           ( m_InterMTS );
   sps.setUseSBT                             ( m_SBT );
+#if !JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
   if( sps.getUseSBT() )
   {
     sps.setMaxSbtSize                       ( std::min((int)(1 << m_log2MaxTbSize), m_iSourceWidth >= 1920 ? 64 : 32) );
   }
+#endif
   sps.setUseSMVD                ( m_SMVD );
   sps.setUseGBi                ( m_GBi );
 #if LUMA_ADAPTIVE_DEBLOCKING_FILTER_QP_OFFSET
diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp
index 05025487d..00b0513d6 100644
--- a/source/Lib/EncoderLib/VLCWriter.cpp
+++ b/source/Lib/EncoderLib/VLCWriter.cpp
@@ -910,10 +910,12 @@ void HLSWriter::codeSPS( const SPS* pcSPS )
   WRITE_FLAG( pcSPS->getUseMIP() ? 1: 0,                                                       "sps_mip_flag" );
   // KJS: not in draft yet
   WRITE_FLAG( pcSPS->getUseSBT() ? 1 : 0,                                             "sbt_enable_flag");
+#if !JVET_P0983_REMOVE_SPS_SBT_MAX_SIZE_FLAG
   if( pcSPS->getUseSBT() )
   {
     WRITE_FLAG(pcSPS->getMaxSbtSize() == 64 ? 1 : 0,                                  "max_sbt_size_64_flag");
   }
+#endif
   // KJS: not in draft yet
   WRITE_FLAG(pcSPS->getUseReshaper() ? 1 : 0, "sps_reshaper_enable_flag");
   WRITE_FLAG( pcSPS->getUseISP() ? 1 : 0,                                             "isp_enable_flag");
-- 
GitLab