From 3ef354706ad13ab558e0b2bd34125d598c1d6abf Mon Sep 17 00:00:00 2001
From: hegilmez <hegilmez@qti.qualcomm.com>
Date: Tue, 9 Apr 2019 13:21:51 -0700
Subject: [PATCH] JVET-N0866: unified transform derivation for ISP and implicit
 MTS

- Implicit transform derivation
- DST-7 is applied horizontally/vertically as long as the number of (luma) samples are less than or equal to 16 in a row/column, otherwise DCT-2 is applied.
---
 source/Lib/CommonLib/TrQuant.cpp   | 10 ++++
 source/Lib/CommonLib/TrQuant.h     |  3 +-
 source/Lib/CommonLib/TypeDef.h     |  2 +
 source/Lib/CommonLib/UnitTools.cpp | 74 ++++++++++++++++++++++++++++++
 source/Lib/CommonLib/UnitTools.h   |  6 +++
 5 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index 8c701503e..b542d274a 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -281,6 +281,7 @@ void TrQuant::invRdpcmNxN(TransformUnit& tu, const ComponentID &compID, PelBuf &
 // Logical transform
 // ------------------------------------------------------------------------------------------------
 
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
 void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer )
 {
   bool mtsActivated = CU::isIntra( *tu.cu ) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter( *tu.cu );
@@ -358,6 +359,7 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy
       trTypeHor = trTypeVer = DST7;
   }
 }
+#endif
 
 void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPelBuf &resi, CoeffBuf &dstCoeff, const int width, const int height )
 {
@@ -371,7 +373,11 @@ void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPel
   int trTypeHor = DCT2;
   int trTypeVer = DCT2;
 
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#else
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#endif
 
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
@@ -439,7 +445,11 @@ void TrQuant::xIT( const TransformUnit &tu, const ComponentID &compID, const CCo
   int trTypeHor = DCT2;
   int trTypeVer = DCT2;
 
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#else
   getTrTypes ( tu, compID, trTypeHor, trTypeVer );
+#endif
 
   const int      skipWidth  = ( trTypeHor != DCT2 && width  == 32 ) ? 16 : width  > JVET_C0024_ZERO_OUT_TH ? width  - JVET_C0024_ZERO_OUT_TH : 0;
   const int      skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0;
diff --git a/source/Lib/CommonLib/TrQuant.h b/source/Lib/CommonLib/TrQuant.h
index 85964c1c8..e023c8c62 100644
--- a/source/Lib/CommonLib/TrQuant.h
+++ b/source/Lib/CommonLib/TrQuant.h
@@ -78,8 +78,9 @@ public:
                     const bool useTransformSkipFast = false
   );
 
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
   void getTrTypes( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer );
-
+#endif
 
 protected:
 
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 67b78ab04..d26c385b9 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,6 +50,8 @@
 #include <assert.h>
 #include <cassert>
 
+#define JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP             1 // JVET-N0866: unified transform derivation for ISP and implicit MTS (combining JVET-N0172, JVET-N0375, JVET-N0419 and JVET-N0420)
+
 #define JVET_N0103_CGSIZE_HARMONIZATION                   1 // Chroma CG sizes aligned to luma CG sizes
 
 #define JVET_N0146_DMVR_BDOF_CONDITION                    1 // JVET-N146/N0162/N0442/N0153/N0262/N0440/N0086 applicable condition of DMVR and BDOF
diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp
index 43bfb9bf6..04b41e437 100644
--- a/source/Lib/CommonLib/UnitTools.cpp
+++ b/source/Lib/CommonLib/UnitTools.cpp
@@ -5568,6 +5568,7 @@ bool TU::getPrevTuCbfAtDepth( const TransformUnit &currentTu, const ComponentID
   return ( prevTU != nullptr ) ? TU::getCbfAtDepth( *prevTU, compID, trDepth ) : false;
 }
 
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
 void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV )
 {
   typeH = DCT2, typeV = DCT2;
@@ -5600,7 +5601,80 @@ void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID,
   typeV = tuArea.height <= 2 || tuArea.height >= 32 ? DCT2 : typeV;
 }
 
+#endif
+
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+void TU::getTrTypes ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer )
+{
+  const bool isExplicitMTS = (CU::isIntra(*tu.cu) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter(*tu.cu)) && isLuma(compID);
+  const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID);
+  const bool isISP         = CU::isIntra(*tu.cu) && tu.cu->ispMode && isLuma(compID);
+  const bool isSBT         = CU::isInter(*tu.cu) && tu.cu->sbtInfo && isLuma(compID);
+
+  trTypeHor = DCT2;
+  trTypeVer = DCT2;
+
+  if (isImplicitMTS || isISP)
+  {
+    int  width       = tu.blocks[compID].width;
+    int  height      = tu.blocks[compID].height;
+    bool widthDstOk  = width  >= 4 && width  <= 16;
+    bool heightDstOk = height >= 4 && height <= 16;
+
+    if (widthDstOk)
+      trTypeHor = DST7;
+    if (heightDstOk)
+      trTypeVer = DST7;
+    return;
+  }
+
+  if( isSBT )
+  {
+    uint8_t sbtIdx = tu.cu->getSbtIdx();
+    uint8_t sbtPos = tu.cu->getSbtPos();
+
+    if( sbtIdx == SBT_VER_HALF || sbtIdx == SBT_VER_QUAD )
+    {
+      assert( tu.lwidth() <= MTS_INTER_MAX_CU_SIZE );
+      if( tu.lheight() > MTS_INTER_MAX_CU_SIZE )
+      {
+        trTypeHor = trTypeVer = DCT2;
+      }
+      else
+      {
+        if( sbtPos == SBT_POS0 )  { trTypeHor = DCT8;  trTypeVer = DST7; }
+        else                      { trTypeHor = DST7;  trTypeVer = DST7; }
+      }
+    }
+    else
+    {
+      assert( tu.lheight() <= MTS_INTER_MAX_CU_SIZE );
+      if( tu.lwidth() > MTS_INTER_MAX_CU_SIZE )
+      {
+        trTypeHor = trTypeVer = DCT2;
+      }
+      else
+      {
+        if( sbtPos == SBT_POS0 )  { trTypeHor = DST7;  trTypeVer = DCT8; }
+        else                      { trTypeHor = DST7;  trTypeVer = DST7; }
+      }
+    }
+    return;
+  }
+
+  if ( isExplicitMTS )
+  {
+    if ( tu.mtsIdx > 1 )
+    {
+      int indHor = ( tu.mtsIdx - 2 ) &  1;
+      int indVer = ( tu.mtsIdx - 2 ) >> 1;
 
+      trTypeHor = indHor ? DCT8 : DST7;
+      trTypeVer = indVer ? DCT8 : DST7;
+    }
+  }
+}
+#endif
 // other tools
 
 uint32_t getCtuAddr( const Position& pos, const PreCalcValues& pcv )
diff --git a/source/Lib/CommonLib/UnitTools.h b/source/Lib/CommonLib/UnitTools.h
index 45b749c7a..df6ddbe66 100644
--- a/source/Lib/CommonLib/UnitTools.h
+++ b/source/Lib/CommonLib/UnitTools.h
@@ -220,7 +220,13 @@ namespace TU
 #endif
   TransformUnit* getPrevTU          ( const TransformUnit &tu, const ComponentID compID );
   bool           getPrevTuCbfAtDepth( const TransformUnit &tu, const ComponentID compID, const int trDepth );
+#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
   void           getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV );
+#endif
+#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP
+  void          getTrTypes         ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer);
+#endif
+
 }
 
 uint32_t getCtuAddr        (const Position& pos, const PreCalcValues &pcv);
-- 
GitLab