From 4288e3c963feea102fdd8917f7df2e59fe9fd8cc Mon Sep 17 00:00:00 2001
From: Jani Lainema <jani.lainema@nokia.com>
Date: Fri, 1 Feb 2019 13:57:56 +0200
Subject: [PATCH] JVET-M0303: implicit MTS

---
 source/App/EncoderApp/EncApp.cpp    |  3 +++
 source/App/EncoderApp/EncAppCfg.cpp |  9 +++++++++
 source/App/EncoderApp/EncAppCfg.h   |  3 +++
 source/Lib/CommonLib/Slice.h        | 12 ++++++++++++
 source/Lib/CommonLib/TrQuant.cpp    | 21 ++++++++++++++++++++-
 source/Lib/CommonLib/TypeDef.h      |  2 ++
 source/Lib/DecoderLib/VLCReader.cpp | 18 ++++++++++++++----
 source/Lib/EncoderLib/EncCfg.h      | 10 +++++++---
 source/Lib/EncoderLib/EncLib.cpp    | 14 ++++++++++++++
 source/Lib/EncoderLib/VLCWriter.cpp | 18 ++++++++++++++----
 10 files changed, 98 insertions(+), 12 deletions(-)

diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp
index 54cfe178..022278ec 100644
--- a/source/App/EncoderApp/EncApp.cpp
+++ b/source/App/EncoderApp/EncApp.cpp
@@ -239,6 +239,9 @@ void EncApp::xInitLibCfg()
   m_cEncLib.setFastIntraEMT                                      ( m_FastEMT & m_EMT & 1 );
   m_cEncLib.setInterEMT                                          ( ( m_EMT >> 1 ) & 1 );
   m_cEncLib.setFastInterEMT                                      ( ( m_FastEMT >> 1 ) & ( m_EMT >> 1 ) & 1 );
+#endif
+#if JVET_M0303_IMPLICIT_MTS
+  m_cEncLib.setImplicitMTS                                       ( m_MTSImplicit );
 #endif
   m_cEncLib.setUseCompositeRef                                   ( m_compositeRefEnabled );
   m_cEncLib.setUseGBi                                            ( m_GBi );
diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp
index f31fad86..f817d6ec 100644
--- a/source/App/EncoderApp/EncAppCfg.cpp
+++ b/source/App/EncoderApp/EncAppCfg.cpp
@@ -852,6 +852,9 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] )
     "\t1:  Enable fast methods only for Intra EMT\n"
     "\t2:  Enable fast methods only for Inter EMT\n"
     "\t3:  Enable fast methods for both Intra & Inter EMT\n")
+#endif
+#if JVET_M0303_IMPLICIT_MTS
+  ("MTSImplicit",                                     m_MTSImplicit,                                        0, "Enable implicit MTS (when explicit MTS is off)\n")
 #endif
   ("CompositeLTReference",                            m_compositeRefEnabled,                            false, "Enable Composite Long Term Reference Frame")
   ("GBi",                                             m_GBi,                                            false, "Enable Generalized Bi-prediction(GBi)")
@@ -2299,9 +2302,15 @@ bool EncAppCfg::xCheckParameter()
   xConfirmPara( m_MTS < 0 || m_MTS > 3, "MTS must be greater than 0 smaller than 4" );
   xConfirmPara( m_MTSIntraMaxCand < 0 || m_MTSIntraMaxCand > 5, "m_MTSIntraMaxCand must be greater than 0 and smaller than 6" );
   xConfirmPara( m_MTSInterMaxCand < 0 || m_MTSInterMaxCand > 5, "m_MTSInterMaxCand must be greater than 0 and smaller than 6" );
+#if JVET_M0303_IMPLICIT_MTS
+  xConfirmPara( m_MTS != 0 && m_MTSImplicit != 0, "Both explicit and implicit MTS cannot be enabled at the same time" );
+#endif
 #else
   xConfirmPara( m_EMT < 0 || m_EMT >3, "EMT must be 0, 1, 2 or 3" );
   xConfirmPara( m_FastEMT < 0 || m_FastEMT >3, "FEMT must be 0, 1, 2 or 3" );
+#if JVET_M0303_IMPLICIT_MTS
+  xConfirmPara( m_EMT != 0 && m_MTSImplicit != 0, "Both explicit and implicit MTS cannot be enabled at the same time" );
+#endif
 #endif
   if( m_usePCM)
   {
diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h
index c91c7ec5..1c5c5aa5 100644
--- a/source/App/EncoderApp/EncAppCfg.h
+++ b/source/App/EncoderApp/EncAppCfg.h
@@ -224,6 +224,9 @@ protected:
   int       m_EMT;                                            ///< XZ: Enhanced Multiple Transform
   int       m_FastEMT;                                        ///< XZ: Fast Methods of Enhanced Multiple Transform
 #endif
+#if JVET_M0303_IMPLICIT_MTS
+  int       m_MTSImplicit;
+#endif
 
   bool      m_compositeRefEnabled;
   bool      m_GBi;
diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h
index a4e2a209..5c1fe943 100644
--- a/source/Lib/CommonLib/Slice.h
+++ b/source/Lib/CommonLib/Slice.h
@@ -807,6 +807,9 @@ private:
 #if JVET_M0142_CCLM_COLLOCATED_CHROMA
   bool              m_cclmCollocatedChromaFlag;
 #endif
+#if JVET_M0303_IMPLICIT_MTS
+  bool              m_MTS;
+#endif
 #if JVET_M0464_UNI_MTS
   bool              m_IntraMTS;                   // 18
   bool              m_InterMTS;                   // 19
@@ -874,6 +877,15 @@ public:
   void      setCclmCollocatedChromaFlag( bool b )                                   { m_cclmCollocatedChromaFlag = b; }
   bool      getCclmCollocatedChromaFlag()                                 const     { return m_cclmCollocatedChromaFlag; }
 #endif
+#if JVET_M0303_IMPLICIT_MTS
+  void      setUseMTS             ( bool b )                                        { m_MTS = b; }
+  bool      getUseMTS             ()                                      const     { return m_MTS; }
+#if JVET_M0464_UNI_MTS
+  bool      getUseImplicitMTS     ()                                      const     { return m_MTS && !m_IntraMTS && !m_InterMTS; }
+#else
+  bool      getUseImplicitMTS     ()                                      const     { return m_MTS && !m_IntraEMT && !m_InterEMT; }
+#endif
+#endif
 #if JVET_M0464_UNI_MTS
   void      setUseIntraMTS        ( bool b )                                        { m_IntraMTS = b; }
   bool      getUseIntraMTS        ()                                      const     { return m_IntraMTS; }
diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index 6396fe5c..a148cb5a 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -296,7 +296,10 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy
 #else
   bool emtActivated = CU::isIntra( *tu.cu ) ? tu.cs->sps->getSpsNext().getUseIntraEMT() : tu.cs->sps->getSpsNext().getUseInterEMT();
 #endif
-  
+#if JVET_M0303_IMPLICIT_MTS
+  bool mtsImplicit  = CU::isIntra( *tu.cu ) && tu.cs->sps->getSpsNext().getUseImplicitMTS() && compID == COMPONENT_Y;
+#endif
+
   trTypeHor = DCT2;
   trTypeVer = DCT2;
   
@@ -325,6 +328,22 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy
       }
     }
   }
+#if JVET_M0303_IMPLICIT_MTS
+  else if ( mtsImplicit )
+  {
+    int  width       = tu.blocks[compID].width;
+    int  height      = tu.blocks[compID].height;
+    bool widthDstOk  = width  >= 4 && width  <= 16;
+    bool heightDstOk = height >= 4 && height <= 16;
+    
+    if ( width < height && widthDstOk )
+      trTypeHor = DST7;
+    else if ( height < width && heightDstOk )
+      trTypeVer = DST7;
+    else if ( width == height && widthDstOk )
+      trTypeHor = trTypeVer = DST7;
+  }
+#endif
 }
 
 void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPelBuf &resi, CoeffBuf &dstCoeff, const int width, const int height )
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index cb1dc567..f70df382 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,6 +50,8 @@
 #include <assert.h>
 #include <cassert>
 
+#define JVET_M0303_IMPLICIT_MTS                           1 // Implicit transform selection (can be enabled with MTSImplicit encoder config parameter)
+
 #define JVET_M0471_LONG_DEBLOCKING_FILTERS                1 
 #define JVET_M0470                                        1 // Fixed GR/TU+EG-k transition point, use limited prefix length for escape codes
 
diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp
index 94c056f8..ca13c059 100644
--- a/source/Lib/DecoderLib/VLCReader.cpp
+++ b/source/Lib/DecoderLib/VLCReader.cpp
@@ -798,13 +798,23 @@ void HLSyntaxReader::parseSPSNext( SPSNext& spsNext, const bool usePCM )
     READ_FLAG( symbol,  "sps_cclm_collocated_chroma_flag" );        spsNext.setCclmCollocatedChromaFlag( symbol != 0 );
   }
 #endif
+
+#if JVET_M0303_IMPLICIT_MTS
+  READ_FLAG( symbol,    "mts_enabled_flag" );                       spsNext.setUseMTS                 ( symbol != 0 );
+  if ( spsNext.getUseMTS() )
+  {
+#endif
 #if JVET_M0464_UNI_MTS
-  READ_FLAG( symbol,    "mts_intra_enabled_flag" );                 spsNext.setUseIntraMTS            ( symbol != 0 );
-  READ_FLAG( symbol,    "mts_inter_enabled_flag" );                 spsNext.setUseInterMTS            ( symbol != 0 );
+    READ_FLAG( symbol,    "mts_intra_enabled_flag" );               spsNext.setUseIntraMTS            ( symbol != 0 );
+    READ_FLAG( symbol,    "mts_inter_enabled_flag" );               spsNext.setUseInterMTS            ( symbol != 0 );
 #else
-  READ_FLAG( symbol,    "emt_intra_enabled_flag" );                 spsNext.setUseIntraEMT            ( symbol != 0 );
-  READ_FLAG( symbol,    "emt_inter_enabled_flag" );                 spsNext.setUseInterEMT            ( symbol != 0 );
+    READ_FLAG( symbol,    "emt_intra_enabled_flag" );               spsNext.setUseIntraEMT            ( symbol != 0 );
+    READ_FLAG( symbol,    "emt_inter_enabled_flag" );               spsNext.setUseInterEMT            ( symbol != 0 );
 #endif
+#if JVET_M0303_IMPLICIT_MTS
+  }
+#endif
+
   READ_FLAG( symbol,    "affine_flag" );                            spsNext.setUseAffine              ( symbol != 0 );
   if ( spsNext.getUseAffine() )
   {
diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h
index bbfd7482..e9bf8e04 100644
--- a/source/Lib/EncoderLib/EncCfg.h
+++ b/source/Lib/EncoderLib/EncCfg.h
@@ -214,6 +214,9 @@ protected:
   int       m_InterEMT;
   int       m_FastIntraEMT;
   int       m_FastInterEMT;
+#endif
+#if JVET_M0303_IMPLICIT_MTS
+  int       m_ImplicitMTS;
 #endif
   bool      m_LargeCTU;
   int       m_SubPuMvpMode;
@@ -735,9 +738,10 @@ public:
   void      setInterEMT                     ( bool b )       { m_InterEMT = b; }
   bool      getInterEMT                     ()         const { return m_InterEMT; }
 #endif
-
-
-
+#if JVET_M0303_IMPLICIT_MTS
+  void      setImplicitMTS                  ( bool b )       { m_ImplicitMTS = b; }
+  bool      getImplicitMTS                  ()         const { return m_ImplicitMTS; }
+#endif
 
   void      setUseCompositeRef              (bool b)         { m_compositeRefEnabled = b; }
   bool      getUseCompositeRef              ()         const { return m_compositeRefEnabled; }
diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp
index 4d8506f0..a871a8cc 100644
--- a/source/Lib/EncoderLib/EncLib.cpp
+++ b/source/Lib/EncoderLib/EncLib.cpp
@@ -798,9 +798,17 @@ void EncLib::xInitSPS(SPS &sps)
   sps.setNoAmvrConstraintFlag(!m_bNoAmvrConstraintFlag);
   sps.setNoAffineMotionConstraintFlag(!m_Affine);
 #if JVET_M0464_UNI_MTS
+#if JVET_M0303_IMPLICIT_MTS
+  sps.setNoMtsConstraintFlag((m_IntraMTS || m_InterMTS || m_ImplicitMTS) ? false : true);
+#else
   sps.setNoMtsConstraintFlag((m_IntraMTS || m_InterMTS) ? false : true);
+#endif
+#else
+#if JVET_M0303_IMPLICIT_MTS
+  sps.setNoMtsConstraintFlag((m_IntraEMT || m_InterEMT || m_ImplicitMTS) ? false : true);
 #else
   sps.setNoMtsConstraintFlag((m_IntraEMT || m_InterEMT) ? false : true);
+#endif
 #endif
   sps.setNoLadfConstraintFlag(!m_LadfEnabled);
   sps.setNoDepQuantConstraintFlag(!m_DepQuantEnabledFlag);
@@ -867,9 +875,15 @@ void EncLib::xInitSPS(SPS &sps)
   sps.getSpsNext().setUseNextDQP            ( m_AltDQPCoding );
 #endif
 #if JVET_M0464_UNI_MTS
+#if JVET_M0303_IMPLICIT_MTS
+  sps.getSpsNext().setUseMTS                ( m_IntraMTS || m_InterMTS || m_ImplicitMTS );
+#endif
   sps.getSpsNext().setUseIntraMTS           ( m_IntraMTS );
   sps.getSpsNext().setUseInterMTS           ( m_InterMTS );
 #else
+#if JVET_M0303_IMPLICIT_MTS
+  sps.getSpsNext().setUseMTS                ( m_IntraEMT || m_InterEMT || m_ImplicitMTS );
+#endif
   sps.getSpsNext().setUseIntraEMT           ( m_IntraEMT );
   sps.getSpsNext().setUseInterEMT           ( m_InterEMT );
 #endif
diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp
index 82300aae..6697970f 100644
--- a/source/Lib/EncoderLib/VLCWriter.cpp
+++ b/source/Lib/EncoderLib/VLCWriter.cpp
@@ -538,13 +538,23 @@ void HLSWriter::codeSPSNext( const SPSNext& spsNext, const bool usePCM )
     WRITE_FLAG( spsNext.getCclmCollocatedChromaFlag() ? 1 : 0,                                  "sps_cclm_collocated_chroma_flag" );
   }
 #endif
+
+#if JVET_M0303_IMPLICIT_MTS
+  WRITE_FLAG( spsNext.getUseMTS() ? 1 : 0,                                                      "mts_enabled_flag" );
+  if ( spsNext.getUseMTS() )
+  {
+#endif
 #if JVET_M0464_UNI_MTS
-  WRITE_FLAG( spsNext.getUseIntraMTS() ? 1 : 0,                                                 "mts_intra_enabled_flag" );
-  WRITE_FLAG( spsNext.getUseInterMTS() ? 1 : 0,                                                 "mts_inter_enabled_flag" );
+    WRITE_FLAG( spsNext.getUseIntraMTS() ? 1 : 0,                                               "mts_intra_enabled_flag" );
+    WRITE_FLAG( spsNext.getUseInterMTS() ? 1 : 0,                                               "mts_inter_enabled_flag" );
 #else
-  WRITE_FLAG( spsNext.getUseIntraEMT() ? 1 : 0,                                                 "emt_intra_enabled_flag" );
-  WRITE_FLAG( spsNext.getUseInterEMT() ? 1 : 0,                                                 "emt_inter_enabled_flag" );
+    WRITE_FLAG( spsNext.getUseIntraEMT() ? 1 : 0,                                               "emt_intra_enabled_flag" );
+    WRITE_FLAG( spsNext.getUseInterEMT() ? 1 : 0,                                               "emt_inter_enabled_flag" );
 #endif
+#if JVET_M0303_IMPLICIT_MTS
+  }
+#endif
+
   WRITE_FLAG( spsNext.getUseAffine() ? 1 : 0,                                                   "affine_flag" );
   if ( spsNext.getUseAffine() )
   {
-- 
GitLab