From 3ef354706ad13ab558e0b2bd34125d598c1d6abf Mon Sep 17 00:00:00 2001 From: hegilmez <hegilmez@qti.qualcomm.com> Date: Tue, 9 Apr 2019 13:21:51 -0700 Subject: [PATCH] JVET-N0866: unified transform derivation for ISP and implicit MTS - Implicit transform derivation - DST-7 is applied horizontally/vertically as long as the number of (luma) samples are less than or equal to 16 in a row/column, otherwise DCT-2 is applied. --- source/Lib/CommonLib/TrQuant.cpp | 10 ++++ source/Lib/CommonLib/TrQuant.h | 3 +- source/Lib/CommonLib/TypeDef.h | 2 + source/Lib/CommonLib/UnitTools.cpp | 74 ++++++++++++++++++++++++++++++ source/Lib/CommonLib/UnitTools.h | 6 +++ 5 files changed, 94 insertions(+), 1 deletion(-) diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp index 8c701503e..b542d274a 100644 --- a/source/Lib/CommonLib/TrQuant.cpp +++ b/source/Lib/CommonLib/TrQuant.cpp @@ -281,6 +281,7 @@ void TrQuant::invRdpcmNxN(TransformUnit& tu, const ComponentID &compID, PelBuf & // Logical transform // ------------------------------------------------------------------------------------------------ +#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer ) { bool mtsActivated = CU::isIntra( *tu.cu ) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter( *tu.cu ); @@ -358,6 +359,7 @@ void TrQuant::getTrTypes ( TransformUnit tu, const ComponentID compID, int &trTy trTypeHor = trTypeVer = DST7; } } +#endif void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPelBuf &resi, CoeffBuf &dstCoeff, const int width, const int height ) { @@ -371,7 +373,11 @@ void TrQuant::xT( const TransformUnit &tu, const ComponentID &compID, const CPel int trTypeHor = DCT2; int trTypeVer = DCT2; +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer ); +#else getTrTypes ( tu, compID, trTypeHor, trTypeVer ); +#endif const int skipWidth = ( trTypeHor != DCT2 && width == 32 ) ? 16 : width > JVET_C0024_ZERO_OUT_TH ? width - JVET_C0024_ZERO_OUT_TH : 0; const int skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0; @@ -439,7 +445,11 @@ void TrQuant::xIT( const TransformUnit &tu, const ComponentID &compID, const CCo int trTypeHor = DCT2; int trTypeVer = DCT2; +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + TU::getTrTypes ( tu, compID, trTypeHor, trTypeVer ); +#else getTrTypes ( tu, compID, trTypeHor, trTypeVer ); +#endif const int skipWidth = ( trTypeHor != DCT2 && width == 32 ) ? 16 : width > JVET_C0024_ZERO_OUT_TH ? width - JVET_C0024_ZERO_OUT_TH : 0; const int skipHeight = ( trTypeVer != DCT2 && height == 32 ) ? 16 : height > JVET_C0024_ZERO_OUT_TH ? height - JVET_C0024_ZERO_OUT_TH : 0; diff --git a/source/Lib/CommonLib/TrQuant.h b/source/Lib/CommonLib/TrQuant.h index 85964c1c8..e023c8c62 100644 --- a/source/Lib/CommonLib/TrQuant.h +++ b/source/Lib/CommonLib/TrQuant.h @@ -78,8 +78,9 @@ public: const bool useTransformSkipFast = false ); +#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP void getTrTypes( TransformUnit tu, const ComponentID compID, int &trTypeHor, int &trTypeVer ); - +#endif protected: diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 67b78ab04..d26c385b9 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -50,6 +50,8 @@ #include <assert.h> #include <cassert> +#define JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP 1 // JVET-N0866: unified transform derivation for ISP and implicit MTS (combining JVET-N0172, JVET-N0375, JVET-N0419 and JVET-N0420) + #define JVET_N0103_CGSIZE_HARMONIZATION 1 // Chroma CG sizes aligned to luma CG sizes #define JVET_N0146_DMVR_BDOF_CONDITION 1 // JVET-N146/N0162/N0442/N0153/N0262/N0440/N0086 applicable condition of DMVR and BDOF diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp index 43bfb9bf6..04b41e437 100644 --- a/source/Lib/CommonLib/UnitTools.cpp +++ b/source/Lib/CommonLib/UnitTools.cpp @@ -5568,6 +5568,7 @@ bool TU::getPrevTuCbfAtDepth( const TransformUnit ¤tTu, const ComponentID return ( prevTU != nullptr ) ? TU::getCbfAtDepth( *prevTU, compID, trDepth ) : false; } +#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV ) { typeH = DCT2, typeV = DCT2; @@ -5600,7 +5601,80 @@ void TU::getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, typeV = tuArea.height <= 2 || tuArea.height >= 32 ? DCT2 : typeV; } +#endif + +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP +void TU::getTrTypes ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer ) +{ + const bool isExplicitMTS = (CU::isIntra(*tu.cu) ? tu.cs->sps->getUseIntraMTS() : tu.cs->sps->getUseInterMTS() && CU::isInter(*tu.cu)) && isLuma(compID); + const bool isImplicitMTS = CU::isIntra(*tu.cu) && tu.cs->sps->getUseImplicitMTS() && isLuma(compID); + const bool isISP = CU::isIntra(*tu.cu) && tu.cu->ispMode && isLuma(compID); + const bool isSBT = CU::isInter(*tu.cu) && tu.cu->sbtInfo && isLuma(compID); + + trTypeHor = DCT2; + trTypeVer = DCT2; + + if (isImplicitMTS || isISP) + { + int width = tu.blocks[compID].width; + int height = tu.blocks[compID].height; + bool widthDstOk = width >= 4 && width <= 16; + bool heightDstOk = height >= 4 && height <= 16; + + if (widthDstOk) + trTypeHor = DST7; + if (heightDstOk) + trTypeVer = DST7; + return; + } + + if( isSBT ) + { + uint8_t sbtIdx = tu.cu->getSbtIdx(); + uint8_t sbtPos = tu.cu->getSbtPos(); + + if( sbtIdx == SBT_VER_HALF || sbtIdx == SBT_VER_QUAD ) + { + assert( tu.lwidth() <= MTS_INTER_MAX_CU_SIZE ); + if( tu.lheight() > MTS_INTER_MAX_CU_SIZE ) + { + trTypeHor = trTypeVer = DCT2; + } + else + { + if( sbtPos == SBT_POS0 ) { trTypeHor = DCT8; trTypeVer = DST7; } + else { trTypeHor = DST7; trTypeVer = DST7; } + } + } + else + { + assert( tu.lheight() <= MTS_INTER_MAX_CU_SIZE ); + if( tu.lwidth() > MTS_INTER_MAX_CU_SIZE ) + { + trTypeHor = trTypeVer = DCT2; + } + else + { + if( sbtPos == SBT_POS0 ) { trTypeHor = DST7; trTypeVer = DCT8; } + else { trTypeHor = DST7; trTypeVer = DST7; } + } + } + return; + } + + if ( isExplicitMTS ) + { + if ( tu.mtsIdx > 1 ) + { + int indHor = ( tu.mtsIdx - 2 ) & 1; + int indVer = ( tu.mtsIdx - 2 ) >> 1; + trTypeHor = indHor ? DCT8 : DST7; + trTypeVer = indVer ? DCT8 : DST7; + } + } +} +#endif // other tools uint32_t getCtuAddr( const Position& pos, const PreCalcValues& pcv ) diff --git a/source/Lib/CommonLib/UnitTools.h b/source/Lib/CommonLib/UnitTools.h index 45b749c7a..df6ddbe66 100644 --- a/source/Lib/CommonLib/UnitTools.h +++ b/source/Lib/CommonLib/UnitTools.h @@ -220,7 +220,13 @@ namespace TU #endif TransformUnit* getPrevTU ( const TransformUnit &tu, const ComponentID compID ); bool getPrevTuCbfAtDepth( const TransformUnit &tu, const ComponentID compID, const int trDepth ); +#if !JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP void getTransformTypeISP( const TransformUnit &tu, const ComponentID compID, int &typeH, int &typeV ); +#endif +#if JVET_N0866_UNIF_TRFM_SEL_IMPL_MTS_ISP + void getTrTypes ( const TransformUnit &tu, const ComponentID compID, int &trTypeHor, int &trTypeVer); +#endif + } uint32_t getCtuAddr (const Position& pos, const PreCalcValues &pcv); -- GitLab