From bdaaaa0419253a05eb8990897c2d6abe59c5fe95 Mon Sep 17 00:00:00 2001
From: Charles Bonnineau <charles.bonnineau@interdigital.com>
Date: Fri, 15 Nov 2024 19:51:37 +0000
Subject: [PATCH] JVET-AJ0257: Improved Implicit MTS (Test 4.2)

---
 source/Lib/CommonLib/Rom.cpp     | 33 +++++++++++++++
 source/Lib/CommonLib/Rom.h       |  6 +++
 source/Lib/CommonLib/TrQuant.cpp | 69 ++++++++++++++++++++++++++++++++
 source/Lib/CommonLib/TypeDef.h   |  1 +
 4 files changed, 109 insertions(+)

diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp
index b4aba4478..cf824a2d1 100644
--- a/source/Lib/CommonLib/Rom.cpp
+++ b/source/Lib/CommonLib/Rom.cpp
@@ -4498,6 +4498,39 @@ const uint8_t g_aucIpmToTrSet[16][36] =
   {70,70,71,71,71,71,71,71,71,71,71,71,71,72,72,72,72,72,72,72,72,72,72,72,73,73,73,73,73,73,73,73,73,73,73,74 }, //32x16
   {75,75,76,76,76,76,76,76,76,76,76,76,76,77,77,77,77,77,77,77,77,77,77,77,78,78,78,78,78,78,78,78,78,78,78,79 }, //32x32
 };
+
+#if JVET_AJ0257_IMPLICIT_MTS_LUT
+const uint8_t g_aucImplicitToTrSet[16][35] =
+{
+  { 14,15,14,14,14,14,14,14,14,14,20,20,14,20,14,20,20,20,20,20,20,20,20,20,20,20,14,20,20,20,20,14,14,14,14}, 
+  { 14,15,14,14,14,14,14,14,14,14,14,14,14,20,14,20,20,20,20,20,21,21,14,20,15,14,14,20,14,14,14,14,14,14,14}, 
+  { 14,15,20,15,14,15,14,15,15,15,15,15,15,15,15,15,15,21,15,21,21,21,21,21,15,21,15,15,21,15,14,15,14,21,15}, 
+  { 21,15,21,15,15,21,15,18,12,15,15,15,21,18,21,21,18,21,21,18,12,18,18,12,18,21,18,18,18,21,21,15,15,18,15}, 
+  { 14,14,14,14,14,20,20, 2,14,14,14,14,14,14,14,14,14,14,14,20,15,20,14,14,14,14,15,14,14,14,14,14,14,14,14}, 
+  { 14,14,14,15,14,14,14,14,14,14,14,15,15,15,15,15,15,21,20,21,21,20,15,15,14,15,14,14,14,14,14,14,14,14,14}, 
+  { 14,21,14,15,14,15,14,15,15,15,14,15,14,15,15,21,15,21,21,21,21,21,21,21,21,15,15,21,15,21,14,15,14,20,14}, 
+  { 20,21,15,15,15,15,15,15,21,15,21,15,15,15,21,21,18,21,18,18,18,15,21,21,21,21,21,15,21,15,15,15,21,21,18}, 
+  { 14,20,15,15,15,15,15,15,21,15,21,15,14,14,15,15,14,20,20,18,18,18,14,14,14,14,14,14,15,21,15,14,20,21,14}, 
+  { 14,21,14,12,14,15,14,15,14,15,14,15,14,14,15,15,21,20,21,21,21,20,15,20,14,14,14,20,14,20,14,20,14,14,14}, 
+  { 14,15,14,12,14,15,14,15,14,21,14,21,15,21,14,21,14,21,21,18,21,21,15,21,15,21,14,21,14,14,14,21,14,21,14}, 
+  { 14,12,14,15,14,15,14,15,15,15,15,15,18,21,21,18,21,18,21,18,18,15,15,21, 3,12,21, 0,12, 3,18,12,14, 3,18}, 
+  { 12, 2,20,21,21, 3,20,20,20,21,20, 3,20, 3,21, 2,21,15,21,21,21,20,21,21,21,21,20, 2,14, 2,20,21, 3,21,20}, 
+  { 12,21, 3,20,20,21,14,20,20,21,20, 2,21,21, 3,21, 3,21,21, 0,21,18, 3, 3,21,20,18,20,21,20, 2,14, 2,15,20}, 
+  { 12,21,20,20, 2, 3,20,21, 3, 0,18,21,15,15,21,21, 3,21, 0, 0, 3,21,15, 3, 3, 0,14, 0, 0, 3, 3, 3,21,18,20}, 
+  { 12,18,14,18, 0,15,14, 0, 0,21,15,18,21,18,12, 3,21, 0, 3, 0, 0,18, 0, 3,12, 0, 0, 0, 3, 0,12, 0, 3, 3,18}, 
+};
+
+const uint8_t g_aucImplicitTrIdxToTr[36][2] =
+{
+    { DCT2, DCT2 }, { DCT2, DCT8 },{ DCT2, DST7 },{ DCT2, DCT5 },{ DCT2, DST4 }, { DCT2, DST1 },
+    { DCT8, DCT2 }, { DCT8, DCT8 },{ DCT8, DST7 },{ DCT8, DCT5 },{ DCT8, DST4 }, { DCT8, DST1 }, 
+    { DST7, DCT2 }, { DST7, DCT8 },{ DST7, DST7 },{ DST7, DCT5 },{ DST7, DST4 }, { DST7, DST1 },
+    { DCT5, DCT2 }, { DCT5, DCT8 },{ DCT5, DST7 },{ DCT5, DCT5 },{ DCT5, DST4 }, { DCT5, DST1 },
+    { DST4, DCT2 }, { DST4, DCT8 },{ DST4, DST7 },{ DST4, DCT5 },{ DST4, DST4 }, { DST4, DST1 },
+    { DST1, DCT2 }, { DST1, DCT8 },{ DST1, DST7 },{ DST1, DCT5 },{ DST1, DST4 }, { DST1, DST1 },
+};
+#endif
+
 const int8_t g_aiIdLut[3][3] =
 {
   { 8, 6, 4 },{ 8, 8, 6 },{ 4, 2, -1 }
diff --git a/source/Lib/CommonLib/Rom.h b/source/Lib/CommonLib/Rom.h
index 6f832c7a0..9783b2769 100644
--- a/source/Lib/CommonLib/Rom.h
+++ b/source/Lib/CommonLib/Rom.h
@@ -235,6 +235,12 @@ extern TMatrixCoeff g_aiTr128[NUM_TRANS_TYPE][128][128];
 extern TMatrixCoeff g_aiTr256[NUM_TRANS_TYPE][256][256];
 
 extern const uint8_t g_aucIpmToTrSet[16][36];
+
+#if JVET_AJ0257_IMPLICIT_MTS_LUT
+extern const uint8_t g_aucImplicitToTrSet[16][35];
+extern const uint8_t g_aucImplicitTrIdxToTr[36][2];
+#endif
+
 #if JVET_Y0142_ADAPT_INTRA_MTS
 extern const uint8_t g_aucTrSet[80][6];
 #else
diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp
index 4782440b5..1192b02bb 100644
--- a/source/Lib/CommonLib/TrQuant.cpp
+++ b/source/Lib/CommonLib/TrQuant.cpp
@@ -1275,7 +1275,11 @@ std::vector<int> TrQuant::selectICTCandidates( const TransformUnit &tu, CompStor
 void TrQuant::getTrTypes(const TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer)
 {
   const bool isExplicitMTS = (CU::isIntra(*tu.cu) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter(*tu.cu)) && isLuma(compID);
+#if JVET_AJ0257_IMPLICIT_MTS_LUT
+  const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID) && !tu.cu->lfnstIdx && !tu.cu->mipFlag && !tu.cu->eipFlag && !tu.cu->tmpFlag && !tu.cu->sgpm;
+#else
   const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID) && tu.cu->lfnstIdx == 0 && tu.cu->mipFlag == 0;
+#endif
   const bool isISP = CU::isIntra(*tu.cu) && tu.cu->ispMode && isLuma(compID);
   const bool isSBT = CU::isInter(*tu.cu) && tu.cu->sbtInfo && isLuma(compID);
 
@@ -1310,6 +1314,70 @@ void TrQuant::getTrTypes(const TransformUnit tu, const ComponentID compID, int &
   {
     int  width = tu.blocks[compID].width;
     int  height = tu.blocks[compID].height;
+  
+#if JVET_AJ0257_IMPLICIT_MTS_LUT    
+    const CompArea& area = tu.blocks[compID];
+    int predMode = PU::getFinalIntraMode(*tu.cs->getPU(area.pos(), toChannelType(compID)), toChannelType(compID));
+
+#if JVET_AJ0249_NEURAL_NETWORK_BASED
+    if (predMode == PNN_IDX)
+    {
+      return;
+    }
+#endif
+
+    if (isISP || width < 4 || height < 4  || tu.cu->dimd || tu.cu->timd)
+    {
+      bool widthDstOk = width >= 4 && width <= 16;
+      bool heightDstOk = height >= 4 && height <= 16;
+
+      if (widthDstOk)
+      {
+        trTypeHor = DST7;
+      }
+      if (heightDstOk)
+      {
+        trTypeVer = DST7;
+      }     
+      return;
+    }
+
+#if JVET_AD0085_TMRL_EXTENSION
+    if (tu.cu->tmrlFlag)
+    {
+      predMode = MAP131TO67(predMode);
+    }
+#endif
+
+    predMode = PU::getWideAngle(tu, (uint32_t)predMode, compID);
+    CHECK(predMode < -(NUM_EXT_LUMA_MODE >> 1) || predMode >= NUM_LUMA_MODE + (NUM_EXT_LUMA_MODE >> 1), "luma mode out of range");
+
+#if JVET_AC0105_DIRECTIONAL_PLANAR
+    if (predMode == PLANAR_IDX)
+    {
+      if (tu.cu->plIdx == 2)
+      {
+        predMode = HOR_IDX;
+      }
+      else if (tu.cu->plIdx == 1)
+      {
+        predMode = VER_IDX;
+      }
+    }
+#endif
+    int modeImplicit = predMode < 0 ? predMode + NUM_LUMA_MODE : predMode >= NUM_LUMA_MODE ? predMode - NUM_LUMA_MODE + 2 : predMode;
+    int modeIdx = modeImplicit > DIA_IDX ? (NUM_LUMA_MODE + 1 - modeImplicit) : modeImplicit;
+    bool isTrTransposed = modeImplicit > DIA_IDX ? true :  false;      
+    uint8_t nSzIdxW = std::min(3, (floorLog2(width) - 2));
+    uint8_t nSzIdxH = std::min(3, (floorLog2(height) - 2));
+    uint8_t nSzIdx = isTrTransposed ? (nSzIdxH * 4 + nSzIdxW) : (nSzIdxW * 4 + nSzIdxH);
+    int nTrType = g_aucImplicitToTrSet[nSzIdx][modeIdx];
+      
+    trTypeHor = g_aucImplicitTrIdxToTr[nTrType][isTrTransposed ? 1 : 0];      
+    trTypeVer = g_aucImplicitTrIdxToTr[nTrType][isTrTransposed ? 0 : 1];
+
+    return;    
+#else
     bool widthDstOk = width >= 4 && width <= 16;
     bool heightDstOk = height >= 4 && height <= 16;
 
@@ -1322,6 +1390,7 @@ void TrQuant::getTrTypes(const TransformUnit tu, const ComponentID compID, int &
       trTypeVer = DST7;
     }
     return;
+#endif
   }
 
 
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 468481b67..30e8b933f 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -429,6 +429,7 @@
 #define JVET_AD0105_ASPECT1_NUM_SIGN_PRED_BY_QP           1 // JVET-AD0105 Aspect1: NumSignPred based on QP
 #define JVET_AI0096_SIGN_PRED_BIT_DEPTH_FIX               1 // JVET-AI0096: Fix to sign prediction for handling bit depths other than 10
 #endif
+#define JVET_AJ0257_IMPLICIT_MTS_LUT                      1 // JVET-AJ0257: Improved Implicit MTS
 #define JVET_W0103_INTRA_MTS                              1 // JVET-W0103: Extended Intra MTS
 #if JVET_W0103_INTRA_MTS
 #define JVET_Y0142_ADAPT_INTRA_MTS                        1 // JVET-Y0142: Adaptive Intra MTS
-- 
GitLab