diff --git a/cfg/encoder_lowdelay_P_ecm.cfg b/cfg/encoder_lowdelay_P_ecm.cfg index 54500d5bd1f512bed1ecf1fbb86024f80add8a2e..575047c4edd26a14737f1d5c3b35b925e487be91 100644 --- a/cfg/encoder_lowdelay_P_ecm.cfg +++ b/cfg/encoder_lowdelay_P_ecm.cfg @@ -100,7 +100,7 @@ MaxMTTHierarchyDepth : 3 MaxMTTHierarchyDepthISliceL : 3 MaxMTTHierarchyDepthISliceC : 3 -MTS : 1 +MTS : 3 MTSIntraMaxCand : 3 MTSInterMaxCand : 4 SBT : 1 diff --git a/cfg/encoder_lowdelay_ecm.cfg b/cfg/encoder_lowdelay_ecm.cfg index c82ae77afcfd935a3496762c90dd20e392a45ae0..f949e38353877347f3fd3779b113dc49e4ccd9b4 100644 --- a/cfg/encoder_lowdelay_ecm.cfg +++ b/cfg/encoder_lowdelay_ecm.cfg @@ -100,7 +100,7 @@ MaxMTTHierarchyDepth : 3 MaxMTTHierarchyDepthISliceL : 3 MaxMTTHierarchyDepthISliceC : 3 -MTS : 1 +MTS : 3 MTSIntraMaxCand : 3 MTSInterMaxCand : 4 SBT : 1 diff --git a/cfg/encoder_randomaccess_ecm.cfg b/cfg/encoder_randomaccess_ecm.cfg index 58c6de3442c0a8e98f6673765e1decde798be227..e20ad326cded71c4868a671917bc5fc2fe788c3d 100644 --- a/cfg/encoder_randomaccess_ecm.cfg +++ b/cfg/encoder_randomaccess_ecm.cfg @@ -127,7 +127,7 @@ MaxMTTHierarchyDepth : 3 MaxMTTHierarchyDepthISliceL : 3 MaxMTTHierarchyDepthISliceC : 3 -MTS : 1 +MTS : 3 MTSIntraMaxCand : 4 MTSInterMaxCand : 4 SBT : 1 diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp index a815a0cc86168613c83a93758251f6540328104a..76fd1f61365192974d7c829703e6f738073e58d8 100644 --- a/source/App/EncoderApp/EncApp.cpp +++ b/source/App/EncoderApp/EncApp.cpp @@ -814,6 +814,9 @@ void EncApp::xInitLibCfg() } } #endif +#if JVET_AA0133_INTER_MTS_OPT + m_cEncLib.setInterMTSMaxSize(m_interMTSMaxSize); +#endif #if ENABLE_DIMD m_cEncLib.setUseDimd ( m_dimd ); #endif diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp index a78e951a6614e3dfede64c4cc91b913aae244df6..3bd83877dd86dda4d02e6ae9873355cc7a2288c4 100644 --- a/source/App/EncoderApp/EncAppCfg.cpp +++ b/source/App/EncoderApp/EncAppCfg.cpp @@ -1059,6 +1059,9 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) ("LadfQpOffset", cfg_LadfQpOffset, cfg_LadfQpOffset, "LADF QP offset") ("LadfIntervalLowerBound", cfg_LadfIntervalLowerBound, cfg_LadfIntervalLowerBound, "LADF lower bound for 2nd lowest interval") #endif +#if JVET_AA0133_INTER_MTS_OPT + ("InterMTSMaxSize", m_interMTSMaxSize, 32, "InterMTSMaxSize") +#endif #if ENABLE_DIMD ( "DIMD", m_dimd, true, "Enable decoder side intra mode derivation\n" ) #endif @@ -2800,7 +2803,13 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) } } #endif - +#if JVET_AA0133_INTER_MTS_OPT +#if JVET_AA0146_WRAP_AROUND_FIX + m_interMTSMaxSize = (m_sourceHeight > 1080) ? 32 : 16; +#else + m_interMTSMaxSize = (m_iSourceHeight > 1080)? 32 : 16; +#endif +#endif if (m_chromaFormatIDC != CHROMA_420) { if (!m_horCollocatedChromaFlag) @@ -4879,6 +4888,12 @@ void EncAppCfg::xPrintParameter() msg( VERBOSE, "HorCollocatedChroma:%d ", m_horCollocatedChromaFlag ); msg( VERBOSE, "VerCollocatedChroma:%d ", m_verCollocatedChromaFlag ); msg( VERBOSE, "MTS: %1d(intra) %1d(inter) ", m_MTS & 1, ( m_MTS >> 1 ) & 1 ); +#if JVET_AA0133_INTER_MTS_OPT + if ((m_MTS >> 1) & 1) + { + msg(VERBOSE, "InterMTSMaxSize: %d ", m_interMTSMaxSize); + } +#endif msg( VERBOSE, "SBT:%d ", m_SBT ); msg( VERBOSE, "ISP:%d ", m_ISP ); msg( VERBOSE, "SMVD:%d ", m_SMVD ); diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h index 7b0d539e19974f16303466bb4a5a53f2dc52d600..4170416e34b9a67ab0e1cda53f45e7f7a351fe0a 100644 --- a/source/App/EncoderApp/EncAppCfg.h +++ b/source/App/EncoderApp/EncAppCfg.h @@ -424,6 +424,9 @@ protected: std::vector<int> m_LadfQpOffset; int m_LadfIntervalLowerBound[MAX_LADF_INTERVALS]; #endif +#if JVET_AA0133_INTER_MTS_OPT + int m_interMTSMaxSize; +#endif #if ENABLE_DIMD bool m_dimd; #endif diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp index 475a5bc33ee5764455010dd29ba56fd99bc55952..b69de50e4d8427cc1767dd8d28027949b04cbd86 100644 --- a/source/Lib/CommonLib/Rom.cpp +++ b/source/Lib/CommonLib/Rom.cpp @@ -4572,6 +4572,110 @@ void initROM() } c <<= 1; } + +//////////////////////////////////////////////////////////////////////////////////////////////// +#if JVET_AA0133_INTER_MTS_OPT + TMatrixCoeff KLT4[2][4][4] = { + { + { -79, -88, -46, -17}, + { 74, -10, -86, -58}, + { -59, 69, 4, -91}, + { -35, 62, -82, 67}, + }, + { + { 3, 19, 71, 105}, + { 27, 90, 64, -60}, + { -91, -45, 69, -36}, + { 86, -77, 51, -23}, + }, + }; + TMatrixCoeff KLT8[2][8][8] = { + { + { 55, 81, 92, 87, 68, 43, 25, 15}, + { -84, -87, -28, 46, 82, 72, 49, 31}, + { 75, 32, -63, -79, -2, 71, 84, 63}, + { 80, -15, -88, 15, 88, 11, -68, -75}, + { 65, -52, -37, 83, -13, -92, 20, 92}, + { 63, -83, 33, 39, -82, 53, 53, -83}, + { -45, 76, -75, 53, -12, -46, 96, -71}, + { 22, -44, 61, -75, 83, -85, 75, -38}, + }, + { + { 8, 16, 31, 53, 74, 88, 93, 83}, + { 23, 43, 70, 90, 71, 7, -64, -91}, + { -53, -83, -83, -19, 67, 77, -6, -72}, + { -73, -79, -4, 86, 31, -82, -36, 70}, + { -93, -33, 89, 22, -83, 27, 62, -55}, + { -88, 45, 58, -82, 35, 41, -86, 52}, + { -77, 95, -43, -13, 54, -77, 75, -37}, + { -42, 75, -83, 82, -75, 62, -43, 19}, + }, + }; + TMatrixCoeff KLT16[2][16][16] = { + { + { 42, 57, 71, 82, 90, 94, 93, 88, 79, 64, 50, 38, 28, 20, 15, 11}, + { -72, -91, -95, -83, -55, -16, 26, 61, 80, 83, 75, 63, 50, 39, 31, 23}, + { 71, 78, 56, 12, -39, -77, -87, -65, -20, 31, 68, 84, 84, 75, 63, 49}, + { 87, 77, 22, -47, -89, -76, -15, 57, 88, 67, 17, -31, -62, -75, -74, -61}, + { 80, 51, -24, -78, -62, 12, 77, 73, 3, -73, -90, -49, 12, 62, 88, 84}, + { 89, 29, -68, -85, -1, 84, 64, -37, -87, -31, 53, 84, 48, -16, -68, -82}, + { 80, -1, -84, -36, 68, 62, -44, -82, 10, 90, 41, -58, -89, -35, 51, 93}, + { 83, -31, -91, 24, 91, -21, -90, 20, 85, -15, -83, -15, 73, 67, -19, -82}, + { 66, -50, -57, 69, 34, -82, -5, 88, -25, -87, 41, 88, -29, -96, -10, 87}, + { 67, -72, -27, 91, -35, -67, 83, 7, -85, 52, 56, -78, -37, 82, 38, -74}, + { 57, -80, 16, 62, -77, 17, 59, -83, 27, 63, -89, 5, 93, -54, -71, 76}, + { 55, -92, 57, 17, -77, 82, -32, -38, 81, -68, 4, 72, -81, -3, 91, -66}, + { 43, -82, 77, -37, -19, 68, -87, 73, -28, -32, 79, -81, 25, 58, -100, 58}, + { -28, 60, -71, 64, -37, 0, 43, -79, 98, -96, 71, -21, -40, 84, -88, 43}, + { 30, -67, 92, -106, 108, -101, 88, -71, 55, -40, 25, -11, 0, 7, -8, 4}, + { 6, -13, 17, -19, 17, -10, 0, 16, -36, 63, -91, 115, -125, 115, -84, 36}, + }, + { + { 6, 10, 15, 22, 30, 39, 51, 63, 73, 81, 86, 90, 91, 89, 83, 74}, + { -19, -28, -39, -50, -63, -75, -84, -87, -77, -53, -20, 18, 53, 80, 95, 93}, + { 43, 60, 75, 85, 86, 73, 41, -5, -49, -80, -86, -64, -21, 28, 68, 81}, + { 58, 75, 78, 66, 36, -12, -61, -88, -69, -9, 57, 92, 75, 15, -54, -89}, + { 83, 92, 66, 14, -48, -89, -75, -7, 64, 83, 34, -42, -81, -50, 26, 78}, + { 83, 72, 17, -49, -84, -52, 32, 88, 47, -51, -89, -25, 70, 85, -2, -83}, + { 92, 54, -35, -90, -54, 45, 90, 14, -78, -54, 49, 80, -13, -88, -24, 73}, + { 83, 23, -68, -74, 22, 87, 12, -87, -25, 87, 36, -83, -44, 80, 50, -71}, + { 87, -10, -95, -22, 91, 34, -89, -23, 89, 1, -85, 21, 78, -43, -62, 58}, + { 76, -37, -80, 42, 74, -61, -47, 86, -5, -86, 60, 48, -89, 10, 82, -58}, + { 79, -71, -52, 94, -3, -86, 68, 24, -85, 57, 26, -81, 53, 29, -81, 48}, + { 68, -90, 7, 78, -78, 5, 66, -84, 42, 34, -84, 71, -6, -64, 89, -45}, + { 59, -98, 60, 19, -79, 84, -41, -21, 70, -88, 68, -15, -43, 80, -78, 36}, + { 44, -87, 86, -47, -9, 61, -91, 97, -80, 43, 3, -44, 70, -75, 59, -25}, + { 5, -8, 6, 2, -13, 28, -43, 59, -76, 92, -105, 110, -103, 86, -59, 23}, + { 36, -81, 111, -122, 114, -94, 67, -41, 18, 0, -14, 22, -25, 22, -16, 7}, + }, + }; + + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + g_aiTr4[KLT0][i][j] = KLT4[0][i][j]; + g_aiTr4[KLT1][i][j] = KLT4[1][i][j]; + } + } + for (int i = 0; i < 8; i++) + { + for (int j = 0; j < 8; j++) + { + g_aiTr8[KLT0][i][j] = KLT8[0][i][j]; + g_aiTr8[KLT1][i][j] = KLT8[1][i][j]; + } + } + for (int i = 0; i < 16; i++) + { + for (int j = 0; j < 16; j++) + { + g_aiTr16[KLT0][i][j] = KLT16[0][i][j]; + g_aiTr16[KLT1][i][j] = KLT16[1][i][j]; + } + } +#endif +//////////////////////////////////////////////////////////////////////////////////////////////// #endif #if JVET_AA0107_RMVF_AFFINE_MERGE_DERIVATION g_rmvfMultApproxTbl[0] = 0; diff --git a/source/Lib/CommonLib/Slice.cpp b/source/Lib/CommonLib/Slice.cpp index ab0b8a20f9a8cff488b6c2d06cf70fec8a326ce9..08eea3167478e203e4dedb42516039880786033e 100644 --- a/source/Lib/CommonLib/Slice.cpp +++ b/source/Lib/CommonLib/Slice.cpp @@ -3604,7 +3604,9 @@ SPS::SPS() , m_LadfQpOffset { 0 } , m_LadfIntervalLowerBound { 0 } #endif - +#if JVET_AA0133_INTER_MTS_OPT +, m_interMTSMaxSize ( 32 ) +#endif #if MULTI_HYP_PRED , m_InterMultiHyp(false) , m_maxNumAddHyps(0) diff --git a/source/Lib/CommonLib/Slice.h b/source/Lib/CommonLib/Slice.h index 477d3b69177c3c85c77e039009c9dc2cfd4bc31d..1ca1eeebbc16eba3af3dea0c4ecce3e5c2bd663a 100644 --- a/source/Lib/CommonLib/Slice.h +++ b/source/Lib/CommonLib/Slice.h @@ -1715,7 +1715,9 @@ private: int m_LadfQpOffset[MAX_LADF_INTERVALS]; int m_LadfIntervalLowerBound[MAX_LADF_INTERVALS]; #endif - +#if JVET_AA0133_INTER_MTS_OPT + int m_interMTSMaxSize; +#endif #if MULTI_HYP_PRED bool m_InterMultiHyp; // multi hypothesis inter prediction int m_maxNumAddHyps; @@ -2173,7 +2175,10 @@ void setCCALFEnabledFlag( bool b ) void setLadfIntervalLowerBound( int value, int idx ) { m_LadfIntervalLowerBound[ idx ] = value; } int getLadfIntervalLowerBound( int idx ) const { return m_LadfIntervalLowerBound[ idx ]; } #endif - +#if JVET_AA0133_INTER_MTS_OPT + void setInterMTSMaxSize(int size) { m_interMTSMaxSize = size; } + int getInterMTSMaxSize() const { return m_interMTSMaxSize; } +#endif #if MULTI_HYP_PRED bool getUseInterMultiHyp() const { return m_InterMultiHyp; } int getMaxNumAddHyps() const { return m_maxNumAddHyps; } diff --git a/source/Lib/CommonLib/TrQuant.cpp b/source/Lib/CommonLib/TrQuant.cpp index 05bb11fabb968624fa6073d7bd39ee3d49486e8b..f8b39827d48ff6fe169d73854f7c3c7af5261caa 100644 --- a/source/Lib/CommonLib/TrQuant.cpp +++ b/source/Lib/CommonLib/TrQuant.cpp @@ -245,6 +245,10 @@ void TrQuant::init( const Quant* otherQuant, { nullptr, fastForwardDST4_B4, fastForwardDST4_B8, fastForwardDST4_B16, fastForwardDST4_B32, fastForwardDST4_B64, fastForwardDST4_B128, fastForwardDST4_B256 }, { nullptr, fastForwardDST1_B4, fastForwardDST1_B8, fastForwardDST1_B16, fastForwardDST1_B32, fastForwardDST1_B64, fastForwardDST1_B128, fastForwardDST1_B256 }, { nullptr, fastForwardIDTR_B4, fastForwardIDTR_B8, fastForwardIDTR_B16, fastForwardIDTR_B32, fastForwardIDTR_B64, fastForwardIDTR_B128, fastForwardIDTR_B256 }, +#if JVET_AA0133_INTER_MTS_OPT + {nullptr, fastForwardKLT0_B4, fastForwardKLT0_B8, fastForwardKLT0_B16, nullptr, nullptr, nullptr, nullptr }, + {nullptr, fastForwardKLT1_B4, fastForwardKLT1_B8, fastForwardKLT1_B16, nullptr, nullptr, nullptr, nullptr }, +#endif #endif } }; @@ -258,6 +262,10 @@ void TrQuant::init( const Quant* otherQuant, { nullptr, fastInverseDST4_B4, fastInverseDST4_B8, fastInverseDST4_B16, fastInverseDST4_B32, fastInverseDST4_B64, fastInverseDST4_B128, fastInverseDST4_B256 }, { nullptr, fastInverseDST1_B4, fastInverseDST1_B8, fastInverseDST1_B16, fastInverseDST1_B32, fastInverseDST1_B64, fastInverseDST1_B128, fastInverseDST1_B256 }, { nullptr, fastInverseIDTR_B4, fastInverseIDTR_B8, fastInverseIDTR_B16, fastInverseIDTR_B32, fastInverseIDTR_B64, fastInverseIDTR_B128, fastInverseIDTR_B256 }, +#if JVET_AA0133_INTER_MTS_OPT + {nullptr, fastInverseKLT0_B4, fastInverseKLT0_B8, fastInverseKLT0_B16, nullptr, nullptr, nullptr, nullptr }, + {nullptr, fastInverseKLT1_B4, fastInverseKLT1_B8, fastInverseKLT1_B16, nullptr, nullptr, nullptr, nullptr }, +#endif #endif } }; #else @@ -1167,6 +1175,16 @@ void TrQuant::getTrTypes(const TransformUnit tu, const ComponentID compID, int & int indVer = (tu.mtsIdx[compID] - MTS_DST7_DST7) >> 1; trTypeHor = indHor ? DCT8 : DST7; trTypeVer = indVer ? DCT8 : DST7; +#if JVET_AA0133_INTER_MTS_OPT + uint32_t width = tu.blocks[compID].width; + uint32_t height = tu.blocks[compID].height; + CHECK(width < 4 || height < 4, "width < 4 || height < 4 for KLT"); + if (width <= 16 && height <= 16) + { + trTypeHor = indHor ? KLT1 : KLT0; + trTypeVer = indVer ? KLT1 : KLT0; + } +#endif } } } @@ -1409,10 +1427,12 @@ void TrQuant::transformNxN( TransformUnit& tu, const ComponentID& compID, const #else const double facBB[] = { 1.2, 1.3, 1.3, 1.4, 1.5 }; #endif + while( it != trModes->end() ) { tu.mtsIdx[compID] = it->first; CoeffBuf tempCoeff( m_mtsCoeffs[tu.mtsIdx[compID]], rect); + if( tu.noResidual ) { int sumAbs = 0; @@ -1420,7 +1440,6 @@ void TrQuant::transformNxN( TransformUnit& tu, const ComponentID& compID, const it++; continue; } - if ( tu.mtsIdx[compID] == MTS_SKIP ) { xTransformSkip( tu, compID, resiBuf, tempCoeff.buf ); @@ -1451,7 +1470,6 @@ void TrQuant::transformNxN( TransformUnit& tu, const ComponentID& compID, const tu.cu->slice->getSPS()->getMaxLog2TrDynamicRange(toChannelType(compID))); scaleSAD *= pow(2, trShift); } - #if JVET_R0351_HIGH_BIT_DEPTH_SUPPORT_VS trCosts.push_back( TrCost( int(std::min<double>(sumAbs*scaleSAD, std::numeric_limits<int>::max())), pos++ ) ); #else @@ -1459,7 +1477,25 @@ void TrQuant::transformNxN( TransformUnit& tu, const ComponentID& compID, const #endif it++; } - +#if JVET_AA0133_INTER_MTS_OPT + if (CU::isInter(*tu.cu) && tu.cu->mtsFlag && compID == COMPONENT_Y) + { + std::stable_sort(trCosts.begin(), trCosts.end(), [](const TrCost l, const TrCost r) {return l.first < r.first; }); + std::vector<TrMode> trModesTemp; + trModesTemp.resize(trModes->size()); + for (int i = 0; i < trModes->size(); i++) + { + trModesTemp[i] = trModes->at(i); + } + for (int i = 0; i < trModes->size(); i++) + { + int index = trCosts[i].second; + trModes->at(i) = trModesTemp[index]; + } + trModesTemp.resize(0); + return; + } +#endif int numTests = 0; std::vector<TrCost>::iterator itC = trCosts.begin(); const double fac = facBB[std::max(0, floorLog2(std::max(width, height)) - 2)]; diff --git a/source/Lib/CommonLib/TrQuant_EMT.cpp b/source/Lib/CommonLib/TrQuant_EMT.cpp index ef0050b970e5e459b0454840ed5c66968b6650ad..10af6a81c9a766a44186bb9cadbca34922dae917 100644 --- a/source/Lib/CommonLib/TrQuant_EMT.cpp +++ b/source/Lib/CommonLib/TrQuant_EMT.cpp @@ -2258,4 +2258,66 @@ void fastInverseIDTR_B256(const TCoeff *src, TCoeff *dst, int shift, int line, i { _fastInverseMM< 256 >(src, dst, shift, line, iSkipLine, iSkipLine2, outputMinimum, outputMaximum, g_aiTr256[IDTR][0]); } +#if JVET_AA0133_INTER_MTS_OPT +//KLT0 +void fastForwardKLT0_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2) +{ + _fastForwardMM< 4 >(src, dst, shift, line, iSkipLine, iSkipLine2, g_aiTr4[KLT0][0]); +} + +void fastInverseKLT0_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum) +{ + _fastInverseMM< 4 >(src, dst, shift, line, iSkipLine, iSkipLine2, outputMinimum, outputMaximum, g_aiTr4[KLT0][0]); +} + +void fastForwardKLT0_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2) +{ + _fastForwardMM< 8 >(src, dst, shift, line, iSkipLine, iSkipLine2, g_aiTr8[KLT0][0]); +} + +void fastInverseKLT0_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum) +{ + _fastInverseMM< 8 >(src, dst, shift, line, iSkipLine, iSkipLine2, outputMinimum, outputMaximum, g_aiTr8[KLT0][0]); +} + +void fastForwardKLT0_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2) +{ + _fastForwardMM< 16 >(src, dst, shift, line, iSkipLine, iSkipLine2, g_aiTr16[KLT0][0]); +} + +void fastInverseKLT0_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum) +{ + _fastInverseMM< 16 >(src, dst, shift, line, iSkipLine, iSkipLine2, outputMinimum, outputMaximum, g_aiTr16[KLT0][0]); +} +//KLT1 +void fastForwardKLT1_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2) +{ + _fastForwardMM< 4 >(src, dst, shift, line, iSkipLine, iSkipLine2, g_aiTr4[KLT1][0]); +} + +void fastInverseKLT1_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum) +{ + _fastInverseMM< 4 >(src, dst, shift, line, iSkipLine, iSkipLine2, outputMinimum, outputMaximum, g_aiTr4[KLT1][0]); +} + +void fastForwardKLT1_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2) +{ + _fastForwardMM< 8 >(src, dst, shift, line, iSkipLine, iSkipLine2, g_aiTr8[KLT1][0]); +} + +void fastInverseKLT1_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum) +{ + _fastInverseMM< 8 >(src, dst, shift, line, iSkipLine, iSkipLine2, outputMinimum, outputMaximum, g_aiTr8[KLT1][0]); +} + +void fastForwardKLT1_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2) +{ + _fastForwardMM< 16 >(src, dst, shift, line, iSkipLine, iSkipLine2, g_aiTr16[KLT1][0]); +} + +void fastInverseKLT1_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum) +{ + _fastInverseMM< 16 >(src, dst, shift, line, iSkipLine, iSkipLine2, outputMinimum, outputMaximum, g_aiTr16[KLT1][0]); +} +#endif #endif diff --git a/source/Lib/CommonLib/TrQuant_EMT.h b/source/Lib/CommonLib/TrQuant_EMT.h index 5a97b41b6907b466f41e52b5d316b24832b2bd62..a2a71b8fea09d371b5b0734ad21e9b511326babd 100644 --- a/source/Lib/CommonLib/TrQuant_EMT.h +++ b/source/Lib/CommonLib/TrQuant_EMT.h @@ -169,6 +169,23 @@ void fastInverseIDTR_B128(const TCoeff *src, TCoeff *dst, int shift, int line, i void fastForwardIDTR_B256(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2); void fastInverseIDTR_B256(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum); #endif +#if JVET_AA0133_INTER_MTS_OPT +//KLT0 transforms +void fastForwardKLT0_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2); +void fastInverseKLT0_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum); +void fastForwardKLT0_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2); +void fastInverseKLT0_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum); +void fastForwardKLT0_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2); +void fastInverseKLT0_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum); + +//KLT1 transforms +void fastForwardKLT1_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2); +void fastInverseKLT1_B4(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum); +void fastForwardKLT1_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2); +void fastInverseKLT1_B8(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum); +void fastForwardKLT1_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2); +void fastInverseKLT1_B16(const TCoeff *src, TCoeff *dst, int shift, int line, int iSkipLine, int iSkipLine2, const TCoeff outputMinimum, const TCoeff outputMaximum); +#endif #endif #endif // __TRQUANT__ diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h index 4dff1f1e01fa7f52c2d7dbca3d74d69e427fea8e..486ccfba4aa6e4dd55bd3a06daf44cb769947051 100644 --- a/source/Lib/CommonLib/TypeDef.h +++ b/source/Lib/CommonLib/TypeDef.h @@ -236,6 +236,7 @@ #define JVET_Y0159_INTER_MTS 1 // JVET-Y0159: Inter MTS uses fixed 4 candidates #endif #endif +#define JVET_AA0133_INTER_MTS_OPT 1 // JVET-AA0133: Inter MTS optimization // Entropy Coding #define EC_HIGH_PRECISION 1 // CABAC high precision @@ -677,6 +678,19 @@ enum QuantFlags enum TransType { #if JVET_W0103_INTRA_MTS +#if JVET_AA0133_INTER_MTS_OPT + DCT2 = 0, + DCT8 = 1, + DST7 = 2, + DCT5 = 3, + DST4 = 4, + DST1 = 5, + IDTR = 6, + KLT0 = 7, + KLT1 = 8, + NUM_TRANS_TYPE = 9, + DCT2_EMT = 10 +#else DCT2 = 0, DCT8 = 1, DST7 = 2, @@ -686,6 +700,7 @@ enum TransType IDTR = 6, NUM_TRANS_TYPE = 7, DCT2_EMT = 8 +#endif #else DCT2 = 0, DCT8 = 1, diff --git a/source/Lib/CommonLib/UnitTools.cpp b/source/Lib/CommonLib/UnitTools.cpp index 510342dc78f48ee731935231f389bd4e347097e5..5a6a300a6a26b6e28d236ee1bc4d6abda961f980 100644 --- a/source/Lib/CommonLib/UnitTools.cpp +++ b/source/Lib/CommonLib/UnitTools.cpp @@ -13396,7 +13396,11 @@ bool CU::bdpcmAllowed( const CodingUnit& cu, const ComponentID compID ) bool CU::isMTSAllowed(const CodingUnit &cu, const ComponentID compID) { SizeType tsMaxSize = 1 << cu.cs->sps->getLog2MaxTransformSkipBlockSize(); +#if JVET_AA0133_INTER_MTS_OPT + const int maxSize = CU::isIntra(cu) ? MTS_INTRA_MAX_CU_SIZE : cu.cs->sps->getInterMTSMaxSize(); +#else const int maxSize = CU::isIntra( cu ) ? MTS_INTRA_MAX_CU_SIZE : MTS_INTER_MAX_CU_SIZE; +#endif const int cuWidth = cu.blocks[0].lumaSize().width; const int cuHeight = cu.blocks[0].lumaSize().height; bool mtsAllowed = cu.chType == CHANNEL_TYPE_LUMA && compID == COMPONENT_Y; diff --git a/source/Lib/CommonLib/x86/TrQuantX86.h b/source/Lib/CommonLib/x86/TrQuantX86.h index 7dbf585b0f9d4fa7e7ee535573618c5035e3d678..0cc85f201e6b9db7853a016cdedecf1467e32ab5 100644 --- a/source/Lib/CommonLib/x86/TrQuantX86.h +++ b/source/Lib/CommonLib/x86/TrQuantX86.h @@ -640,6 +640,10 @@ void TrQuant::_initTrQuantX86() { nullptr, g_aiTr4[DST4][0], g_aiTr8[DST4][0], g_aiTr16[DST4][0], g_aiTr32[DST4][0], g_aiTr64[DST4][0], g_aiTr128[DST4][0], g_aiTr256[DST4][0] }, { nullptr, g_aiTr4[DST1][0], g_aiTr8[DST1][0], g_aiTr16[DST1][0], g_aiTr32[DST1][0], g_aiTr64[DST1][0], g_aiTr128[DST1][0], g_aiTr256[DST1][0] }, { nullptr, g_aiTr4[IDTR][0], g_aiTr8[IDTR][0], g_aiTr16[IDTR][0], g_aiTr32[IDTR][0], g_aiTr64[IDTR][0], g_aiTr128[IDTR][0], g_aiTr256[IDTR][0] }, +#if JVET_AA0133_INTER_MTS_OPT + { nullptr, g_aiTr4[KLT0][0], g_aiTr8[KLT0][0], g_aiTr16[KLT0][0], nullptr, nullptr, nullptr, nullptr }, + { nullptr, g_aiTr4[KLT1][0], g_aiTr8[KLT1][0], g_aiTr16[KLT1][0], nullptr, nullptr, nullptr, nullptr }, +#endif #endif } }; @@ -653,6 +657,10 @@ void TrQuant::_initTrQuantX86() { nullptr, g_aiTr4[DST4][0], g_aiTr8[DST4][0], g_aiTr16[DST4][0], g_aiTr32[DST4][0], g_aiTr64[DST4][0], g_aiTr128[DST4][0], g_aiTr256[DST4][0] }, { nullptr, g_aiTr4[DST1][0], g_aiTr8[DST1][0], g_aiTr16[DST1][0], g_aiTr32[DST1][0], g_aiTr64[DST1][0], g_aiTr128[DST1][0], g_aiTr256[DST1][0] }, { nullptr, g_aiTr4[IDTR][0], g_aiTr8[IDTR][0], g_aiTr16[IDTR][0], g_aiTr32[IDTR][0], g_aiTr64[IDTR][0], g_aiTr128[IDTR][0], g_aiTr256[IDTR][0] }, +#if JVET_AA0133_INTER_MTS_OPT + { nullptr, g_aiTr4[KLT0][0], g_aiTr8[KLT0][0], g_aiTr16[KLT0][0], nullptr, nullptr, nullptr, nullptr }, + { nullptr, g_aiTr4[KLT1][0], g_aiTr8[KLT1][0], g_aiTr16[KLT1][0], nullptr, nullptr, nullptr, nullptr }, +#endif #endif } }; @@ -719,6 +727,25 @@ void TrQuant::_initTrQuantX86() fastFwdTrans[6][5] = fastForwardTransform_SIMD<IDTR, 64>; fastFwdTrans[6][6] = fastForwardTransform_SIMD<IDTR, 128>; fastFwdTrans[6][7] = fastForwardTransform_SIMD<IDTR, 256>; +#if JVET_AA0133_INTER_MTS_OPT + fastFwdTrans[7][0] = nullptr; + fastFwdTrans[7][1] = fastForwardTransform_SIMD<KLT0, 4>; + fastFwdTrans[7][2] = fastForwardTransform_SIMD<KLT0, 8>; + fastFwdTrans[7][3] = fastForwardTransform_SIMD<KLT0, 16>; + fastFwdTrans[7][4] = nullptr; + fastFwdTrans[7][5] = nullptr; + fastFwdTrans[7][6] = nullptr; + fastFwdTrans[7][7] = nullptr; + + fastFwdTrans[8][0] = nullptr; + fastFwdTrans[8][1] = fastForwardTransform_SIMD<KLT1, 4>; + fastFwdTrans[8][2] = fastForwardTransform_SIMD<KLT1, 8>; + fastFwdTrans[8][3] = fastForwardTransform_SIMD<KLT1, 16>; + fastFwdTrans[8][4] = nullptr; + fastFwdTrans[8][5] = nullptr; + fastFwdTrans[8][6] = nullptr; + fastFwdTrans[8][7] = nullptr; +#endif #endif fastInvTrans[0][0] = fastInverseTransform_SIMD<DCT2, 2>; @@ -784,6 +811,25 @@ void TrQuant::_initTrQuantX86() fastInvTrans[6][5] = fastInverseTransform_SIMD<IDTR, 64>; fastInvTrans[6][6] = fastInverseTransform_SIMD<IDTR, 128>; fastInvTrans[6][7] = fastInverseTransform_SIMD<IDTR, 256>; +#if JVET_AA0133_INTER_MTS_OPT + fastInvTrans[7][0] = nullptr; + fastInvTrans[7][1] = fastInverseTransform_SIMD<KLT0, 4>; + fastInvTrans[7][2] = fastInverseTransform_SIMD<KLT0, 8>; + fastInvTrans[7][3] = fastInverseTransform_SIMD<KLT0, 16>; + fastInvTrans[7][4] = nullptr; + fastInvTrans[7][5] = nullptr; + fastInvTrans[7][6] = nullptr; + fastInvTrans[7][7] = nullptr; + + fastInvTrans[8][0] = nullptr; + fastInvTrans[8][1] = fastInverseTransform_SIMD<KLT1, 4>; + fastInvTrans[8][2] = fastInverseTransform_SIMD<KLT1, 8>; + fastInvTrans[8][3] = fastInverseTransform_SIMD<KLT1, 16>; + fastInvTrans[8][4] = nullptr; + fastInvTrans[8][5] = nullptr; + fastInvTrans[8][6] = nullptr; + fastInvTrans[8][7] = nullptr; +#endif #endif #else m_forwardTransformKernels = diff --git a/source/Lib/DecoderLib/VLCReader.cpp b/source/Lib/DecoderLib/VLCReader.cpp index a192a039b9abbc82dd7aaa7f264a6c5133f3bd88..d3539facebf902f6ad8106aabc73babf93c81611 100644 --- a/source/Lib/DecoderLib/VLCReader.cpp +++ b/source/Lib/DecoderLib/VLCReader.cpp @@ -2041,6 +2041,12 @@ void HLSyntaxReader::parseSPS(SPS* pcSPS) { READ_FLAG(uiCode, "sps_explicit_mts_intra_enabled_flag"); pcSPS->setUseIntraMTS(uiCode != 0); READ_FLAG(uiCode, "sps_explicit_mts_inter_enabled_flag"); pcSPS->setUseInterMTS(uiCode != 0); +#if JVET_AA0133_INTER_MTS_OPT + if (pcSPS->getUseInterMTS()) + { + READ_FLAG(uiCode, "sps_inter_mts_max_size"); pcSPS->setInterMTSMaxSize((uiCode != 0) ? 16 : 32); + } +#endif } READ_FLAG(uiCode, "sps_lfnst_enabled_flag"); pcSPS->setUseLFNST(uiCode != 0); #endif diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h index 85726c6e3c62609a12fd825e4dfc0ab22a6c514d..8a488b86cbd8aec3e0dc18db7a0c2c929535da0d 100644 --- a/source/Lib/EncoderLib/EncCfg.h +++ b/source/Lib/EncoderLib/EncCfg.h @@ -431,6 +431,9 @@ protected: int m_LadfQpOffset[MAX_LADF_INTERVALS]; int m_LadfIntervalLowerBound[MAX_LADF_INTERVALS]; #endif +#if JVET_AA0133_INTER_MTS_OPT + int m_interMTSMaxSize; +#endif #if ENABLE_DIMD bool m_dimd; #endif @@ -1385,6 +1388,10 @@ public: int getLadfIntervalLowerBound ( int idx ) const { return m_LadfIntervalLowerBound[ idx ]; } #endif +#if JVET_AA0133_INTER_MTS_OPT + void setInterMTSMaxSize(int size) { m_interMTSMaxSize = size; } + int getInterMTSMaxSize() const { return m_interMTSMaxSize; } +#endif #if ENABLE_DIMD void setUseDimd ( bool b ) { m_dimd = b; } diff --git a/source/Lib/EncoderLib/EncCu.cpp b/source/Lib/EncoderLib/EncCu.cpp index 8c705f2c8dcaa00ae73807807100b60424029f4a..0c885ca42e873287a243f45e364f6dca33f90d26 100644 --- a/source/Lib/EncoderLib/EncCu.cpp +++ b/source/Lib/EncoderLib/EncCu.cpp @@ -741,6 +741,9 @@ bool EncCu::xCheckBestMode( CodingStructure *&tempCS, CodingStructure *&bestCS, void EncCu::xCompressCU( CodingStructure*& tempCS, CodingStructure*& bestCS, Partitioner& partitioner, double maxCostAllowed ) { CHECK(maxCostAllowed < 0, "Wrong value of maxCostAllowed!"); +#if JVET_AA0133_INTER_MTS_OPT + m_pcInterSearch->setBestCost(maxCostAllowed); +#endif #if ENABLE_SPLIT_PARALLELISM CHECK( m_dataId != tempCS->picture->scheduler.getDataId(), "Working in the wrong dataId!" ); @@ -986,7 +989,9 @@ void EncCu::xCompressCU( CodingStructure*& tempCS, CodingStructure*& bestCS, Par #endif } m_sbtCostSave[0] = m_sbtCostSave[1] = MAX_DOUBLE; - +#if JVET_AA0133_INTER_MTS_OPT + m_mtsCostSave = MAX_DOUBLE; +#endif m_CurrCtx->start = m_CABACEstimator->getCtx(); m_cuChromaQpOffsetIdxPlus1 = 0; @@ -11625,6 +11630,9 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS double bestCostBegin = bestCS->cost; CodingUnit* prevBestCU = bestCS->getCU( partitioner.chType ); uint8_t prevBestSbt = ( prevBestCU == nullptr ) ? 0 : prevBestCU->sbtInfo; +#if JVET_AA0133_INTER_MTS_OPT + bool prevBestMts = (prevBestCU == nullptr) ? 0 : (prevBestCU->firstTU->mtsIdx[COMPONENT_Y] > MTS_SKIP)? true : false ; +#endif bool swapped = false; // avoid unwanted data copy bool reloadCU = false; @@ -11698,7 +11706,13 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS } } } +#if JVET_AA0133_INTER_MTS_OPT + m_pcInterSearch->setBestCost(bestCS->cost); + cu->mtsFlag = false; + const bool mtsAllowed = tempCS->sps->getUseInterMTS() && CU::isInter(*cu) && partitioner.currArea().lwidth() <= tempCS->sps->getInterMTSMaxSize() && partitioner.currArea().lheight() <= tempCS->sps->getInterMTSMaxSize(); +#else const bool mtsAllowed = tempCS->sps->getUseInterMTS() && CU::isInter( *cu ) && partitioner.currArea().lwidth() <= MTS_INTER_MAX_CU_SIZE && partitioner.currArea().lheight() <= MTS_INTER_MAX_CU_SIZE; +#endif uint8_t sbtAllowed = cu->checkAllowedSbt(); //SBT resolution-dependent fast algorithm: not try size-64 SBT in RDO for low-resolution sequences (now resolution below HD) if( tempCS->pps->getPicWidthInLumaSamples() < (uint32_t)m_pcEncCfg->getSBTFast64WidthTh() ) @@ -11711,7 +11725,9 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS double sbtOffCost = MAX_DOUBLE; double currBestCost = MAX_DOUBLE; bool doPreAnalyzeResi = ( sbtAllowed || mtsAllowed ) && residualPass == 0; - +#if JVET_AA0133_INTER_MTS_OPT + double mtsOffCost = MAX_DOUBLE; +#endif m_pcInterSearch->initTuAnalyzer(); if( doPreAnalyzeResi ) { @@ -11790,7 +11806,11 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS bestCS->tmpColorSpaceCost = tempCS->tmpColorSpaceCost; bestCS->firstColorSpaceSelected = tempCS->firstColorSpaceSelected; } +#if JVET_AA0133_INTER_MTS_OPT + numRDOTried += 1; +#else numRDOTried += mtsAllowed ? 2 : 1; +#endif xEncodeDontSplit( *tempCS, partitioner ); xCheckDQP( *tempCS, partitioner ); @@ -11967,11 +11987,116 @@ void EncCu::xEncodeInterResidual( CodingStructure *&tempCS #endif xCheckBestMode( tempCS, bestCS, partitioner, encTestMode ); } +#if JVET_AA0133_INTER_MTS_OPT + if (!skipResidual && mtsAllowed) + { + if (bestCost == bestCS->cost) //The first EMT pass didn't become the bestCS, so we clear the TUs generated + { + tempCS->clearTUs(); + } + else if (false == swapped) + { + tempCS->initStructData(encTestMode.qp); + tempCS->copyStructure(*bestCS, partitioner.chType); + tempCS->getPredBuf().copyFrom(bestCS->getPredBuf()); + bestCost = bestCS->cost; + cu = tempCS->getCU(partitioner.chType); + swapped = true; + } + else + { + tempCS->clearTUs(); + bestCost = bestCS->cost; + cu = tempCS->getCU(partitioner.chType); + } + + //we need to restart the distortion for the new tempCS, the bit count and the cost + tempCS->dist = 0; + tempCS->fracBits = 0; + tempCS->cost = MAX_DOUBLE; + tempCS->costDbOffset = 0; + cu->skip = false; + cu->sbtInfo = 0; + cu->mtsFlag = true; + m_pcInterSearch->setBestCost(bestCS->cost); + mtsOffCost = currBestCost; + bool testMts = true; + if (bestCost != MAX_DOUBLE && mtsOffCost != MAX_DOUBLE) + { + double th = 1.07; + if (!(prevBestMts == 0 || m_mtsCostSave == MAX_DOUBLE)) + { + assert(m_sbtCostSave[1] <= m_mtsCostSave); + th *= (m_mtsCostSave / m_sbtCostSave[1]); + } + if (mtsOffCost > bestCost * th) + { + testMts = false; + } + } + if(testMts) + { + //try residual coding + bool isValid = m_pcInterSearch->encodeResAndCalcRdInterCU(*tempCS, partitioner, skipResidual); + if (isValid) + { + if (tempCS->slice->getSPS()->getUseColorTrans()) + { + bestCS->tmpColorSpaceCost = tempCS->tmpColorSpaceCost; + bestCS->firstColorSpaceSelected = tempCS->firstColorSpaceSelected; + } + numRDOTried++; + + xEncodeDontSplit(*tempCS, partitioner); + + xCheckDQP(*tempCS, partitioner); + xCheckChromaQPOffset(*tempCS, partitioner); + + if (NULL != bestHasNonResi && (bestCostInternal > tempCS->cost)) + { + bestCostInternal = tempCS->cost; + if (!(tempCS->getPU(partitioner.chType)->ciipFlag)) + { + *bestHasNonResi = !cu->rootCbf; + } + } + if (cu->rootCbf == false) + { + if (tempCS->getPU(partitioner.chType)->ciipFlag) + { + tempCS->cost = MAX_DOUBLE; + tempCS->costDbOffset = 0; + return; + } + } + if (tempCS->cost < currBestCost) + { + currBestCost = tempCS->cost; + sbtOffCost = tempCS->cost; + sbtOffDist = tempCS->dist; + sbtOffRootCbf = cu->rootCbf; + currBestSbt = CU::getSbtInfo(cu->firstTU->mtsIdx[COMPONENT_Y] > MTS_SKIP ? SBT_OFF_MTS : SBT_OFF_DCT, 0); + currBestTrs = cu->firstTU->mtsIdx[COMPONENT_Y]; + } + + #if WCG_EXT + DTRACE_MODE_COST(*tempCS, m_pcRdCost->getLambda(true)); + #else + DTRACE_MODE_COST(*tempCS, m_pcRdCost->getLambda()); + #endif + xCheckBestMode(tempCS, bestCS, partitioner, encTestMode); + } + } + } +#endif if( bestCostBegin != bestCS->cost ) { m_sbtCostSave[0] = sbtOffCost; m_sbtCostSave[1] = currBestCost; +#if JVET_AA0133_INTER_MTS_OPT + m_mtsCostSave = mtsOffCost; +#endif } } //end emt loop diff --git a/source/Lib/EncoderLib/EncCu.h b/source/Lib/EncoderLib/EncCu.h index 8274b2870df27fb0ca42b15c4e169045753e1bd7..22ba1bc679e91c7576696986a3f9a9e88e96daf7 100644 --- a/source/Lib/EncoderLib/EncCu.h +++ b/source/Lib/EncoderLib/EncCu.h @@ -370,6 +370,9 @@ private: const bool updateRdCostLambda ); #endif double m_sbtCostSave[2]; +#if JVET_AA0133_INTER_MTS_OPT + double m_mtsCostSave; +#endif #if JVET_W0097_GPM_MMVD_TM MergeCtx m_mergeCand; bool m_mergeCandAvail; diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp index 35376a2d94409f445d531c01e082a5c50e543d87..1994d7e985a391b528bb3ab44c9c8d4f7eae83e3 100644 --- a/source/Lib/EncoderLib/EncLib.cpp +++ b/source/Lib/EncoderLib/EncLib.cpp @@ -1585,6 +1585,9 @@ void EncLib::xInitSPS( SPS& sps ) CHECK( m_LadfIntervalLowerBound[0] != 0, "abnormal value set to LadfIntervalLowerBound[0]" ); } #endif +#if JVET_AA0133_INTER_MTS_OPT + sps.setInterMTSMaxSize(m_interMTSMaxSize); +#endif #if ENABLE_DIMD sps.setUseDimd ( m_dimd ); #endif diff --git a/source/Lib/EncoderLib/InterSearch.cpp b/source/Lib/EncoderLib/InterSearch.cpp index effcbe1f7f593d590a0c9ec1f9aee6265920a044..2ceaced3b9bf0c7af73e712699dbd9c94336838b 100644 --- a/source/Lib/EncoderLib/InterSearch.cpp +++ b/source/Lib/EncoderLib/InterSearch.cpp @@ -8684,11 +8684,17 @@ uint8_t InterSearch::skipSbtByRDCost( int width, int height, int mtDepth, uint8_ } return MAX_UCHAR; } - +#if JVET_AA0133_INTER_MTS_OPT +bool InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &partitioner, Distortion *puiZeroDist /*= NULL*/ + , const bool luma, const bool chroma + , PelUnitBuf* orgResi +) +#else void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &partitioner, Distortion *puiZeroDist /*= NULL*/ , const bool luma, const bool chroma , PelUnitBuf* orgResi ) +#endif { const UnitArea& currArea = partitioner.currArea(); const SPS &sps = *cs.sps; @@ -8798,10 +8804,21 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par } - const bool tsAllowed = TU::isTSAllowed(tu, compID) && (isLuma(compID) || (isChroma(compID) && m_pcEncCfg->getUseChromaTS())); +#if JVET_AA0133_INTER_MTS_OPT + const bool mtsAllowed = CU::isMTSAllowed(*tu.cu, compID) && cu.mtsFlag; + const bool tsAllowed = TU::isTSAllowed(tu, compID) && ((isLuma(compID) && !cu.mtsFlag) || (isChroma(compID) && m_pcEncCfg->getUseChromaTS())); +#else + const bool tsAllowed = TU::isTSAllowed(tu, compID) && (isLuma(compID) || (isChroma(compID) && m_pcEncCfg->getUseChromaTS())); const bool mtsAllowed = CU::isMTSAllowed( *tu.cu, compID ); +#endif uint8_t nNumTransformCands = 1 + ( tsAllowed ? 1 : 0 ) + ( mtsAllowed ? 4 : 0 ); // DCT + TS + 4 MTS = 6 tests +#if JVET_AA0133_INTER_MTS_OPT + if (cu.mtsFlag && compID == COMPONENT_Y) + { + nNumTransformCands = (mtsAllowed ? 4 : 0); + } +#endif std::vector<TrMode> trModes; #if TU_256 if(tu.idx != cu.firstTU->idx) @@ -8816,11 +8833,21 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par { nNumTransformCands = 0; } +#if JVET_AA0133_INTER_MTS_OPT + else if(!(cu.mtsFlag && compID == COMPONENT_Y)) +#else else +#endif { trModes.push_back( TrMode( 0, true ) ); //DCT2 nNumTransformCands = 1; } +#if JVET_AA0133_INTER_MTS_OPT + else + { + nNumTransformCands = 0; + } +#endif //for a SBT-no-residual TU, the RDO process should be called once, in order to get the RD cost if( tsAllowed && !tu.noResidual ) { @@ -8866,11 +8893,16 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par #endif m_pcRdCost->lambdaAdjustColorTrans(true, compID); } - +#if JVET_AA0133_INTER_MTS_OPT + bool skipRemainingMTS = false; + bool skipMTSPass = false; + int countSkipMTSLoop = 0; +#endif const int numTransformCandidates = nNumTransformCands; for( int transformMode = 0; transformMode < numTransformCandidates; transformMode++ ) { const bool isFirstMode = transformMode == 0; + // copy the original residual into the residual buffer #if JVET_S0234_ACT_CRS_FIX if (colorTransFlag) @@ -8898,6 +8930,12 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par } tu.mtsIdx[compID] = trModes[transformMode].first; } +#if JVET_AA0133_INTER_MTS_OPT + if (compID == COMPONENT_Y && cu.mtsFlag && skipRemainingMTS) + { + break; + } +#endif QpParam cQP(tu, compID); // note: uses tu.transformSkip[compID] #if RDOQ_CHROMA_LAMBDA @@ -9185,7 +9223,11 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par #endif } } +#if JVET_AA0133_INTER_MTS_OPT + else if ((cu.mtsFlag && compID == COMPONENT_Y) || (transformMode > 0)) +#else else if( transformMode > 0 ) +#endif { currCompCost = MAX_DOUBLE; } @@ -9197,7 +9239,23 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par tu.cbf[compID] = 0; } - +#if JVET_AA0133_INTER_MTS_OPT + if (compID == COMPONENT_Y && cu.mtsFlag && currCompCost < MAX_DOUBLE) + { + double globalMinCost = std::min(minCost[compID], m_bestDCT2PassLumaCost); + double fac = std::max((1.0 + 1.0 / sqrt(tu.lumaSize().width * tu.lumaSize().height)), 1.06); + if (currCompCost > fac*globalMinCost) + { + skipRemainingMTS = true; + if (countSkipMTSLoop == 0) + { + //skip MTS candidate condition fulfilled for first valid MTS candidate (count = 0), so skip chroma coding. + skipMTSPass = true; + } + } + countSkipMTSLoop++; + } + #endif // evaluate #if TU_256 if( isFirstMode || ( currCompCost < minCost[compID] ) || ( transformMode == 1 && currCompCost == minCost[compID] ) ) @@ -9230,7 +9288,29 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par CHECK( currCompFracBits > 0 || currAbsSum, "currCompFracBits > 0 when tu noResidual" ); } } - +#if JVET_AA0133_INTER_MTS_OPT + if (compID == 0) + { + if (cu.mtsFlag) + { + if (minCost[compID] == MAX_DOUBLE || bestTU.cbf[0] == 0) //When checking only MTS cands, cbf can't be zero, or just only contain DC coefficients. + { + return false; + } + if (skipMTSPass) //Luma is not selecting any MTS (skipping), so skip chroma coding (Encoder speedup). + { + return false; + } + } + else + { + if (!cu.sbtInfo) + { + m_bestDCT2PassLumaCost = minCost[compID]; + } + } + } +#endif // copy component tu.copyComponentFrom( bestTU, compID ); csFull->getResiBuf( compArea ).copyFrom( saveCS.getResiBuf( compArea ) ); @@ -9756,11 +9836,19 @@ void InterSearch::xEstimateInterResidualQT(CodingStructure &cs, Partitioner &par csFull ->releaseIntermediateData(); } } +#if JVET_AA0133_INTER_MTS_OPT + return true; +#endif } - +#if JVET_AA0133_INTER_MTS_OPT +bool InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &partitioner, const bool &skipResidual + , const bool luma, const bool chroma +) +#else void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &partitioner, const bool &skipResidual , const bool luma, const bool chroma ) +#endif { m_pcRdCost->setChromaFormat(cs.sps->getChromaFormatIdc()); @@ -9851,8 +9939,11 @@ void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &pa cs.dist = distortion; cs.fracBits = m_CABACEstimator->getEstFracBits(); cs.cost = m_pcRdCost->calcRdCost(cs.fracBits, cs.dist); - +#if JVET_AA0133_INTER_MTS_OPT + return true; +#else return; +#endif } // Residual coding. @@ -10027,7 +10118,15 @@ void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &pa cs.getOrgResiBuf().bufs[1].copyFrom(orgResidual.bufs[1]); cs.getOrgResiBuf().bufs[2].copyFrom(orgResidual.bufs[2]); } +#if JVET_AA0133_INTER_MTS_OPT + bool isValidReturn = xEstimateInterResidualQT(cs, partitioner, &zeroDistortion, luma, chroma); + if (cu.mtsFlag && !isValidReturn) + { + return false; + } +#else xEstimateInterResidualQT(cs, partitioner, &zeroDistortion, luma, chroma); +#endif } TransformUnit &firstTU = *cs.getTU( partitioner.chType ); @@ -10387,6 +10486,9 @@ void InterSearch::encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &pa } CHECK(cs.tus.size() == 0, "No TUs present"); +#if JVET_AA0133_INTER_MTS_OPT + return true; +#endif } uint64_t InterSearch::xGetSymbolFracBitsInter(CodingStructure &cs, Partitioner &partitioner) diff --git a/source/Lib/EncoderLib/InterSearch.h b/source/Lib/EncoderLib/InterSearch.h index 24b3db71aa15620a3de77282003f35ebca02d67e..7cd239f081b9905d7642d7973963ebfd8a15b18c 100644 --- a/source/Lib/EncoderLib/InterSearch.h +++ b/source/Lib/EncoderLib/InterSearch.h @@ -351,7 +351,10 @@ protected: CABACWriter* m_CABACEstimator; CtxCache* m_CtxCache; DistParam m_cDistParam; - +#if JVET_AA0133_INTER_MTS_OPT + double m_globalBestLumaCost; + double m_bestDCT2PassLumaCost; +#endif RefPicList m_currRefPicList; int m_currRefPicIndex; bool m_skipFracME; @@ -418,6 +421,7 @@ public: m_geoMrgCtx = geoMrgCtx; } #endif + InterSearch(); virtual ~InterSearch(); @@ -458,6 +462,9 @@ public: void resetCtuRecord () { m_ctuRecord.clear(); } #if ENABLE_SPLIT_PARALLELISM void copyState ( const InterSearch& other ); +#endif +#if JVET_AA0133_INTER_MTS_OPT + void setBestCost(double cost) { m_globalBestLumaCost = cost; } #endif void setAffineModeSelected ( bool flag) { m_affineModeSelected = flag; } void resetAffineMVList() { m_affMVListIdx = 0; m_affMVListSize = 0; } @@ -984,15 +991,27 @@ protected: private: void xxIBCHashSearch(PredictionUnit& pu, Mv* mvPred, int numMvPred, Mv &mv, int& idxMvPred, IbcHashMap& ibcHashMap); public: - +#if JVET_AA0133_INTER_MTS_OPT + bool encodeResAndCalcRdInterCU(CodingStructure &cs, Partitioner &partitioner, const bool &skipResidual + , const bool luma = true, const bool chroma = true + ); +#else void encodeResAndCalcRdInterCU (CodingStructure &cs, Partitioner &partitioner, const bool &skipResidual , const bool luma = true, const bool chroma = true ); +#endif void xEncodeInterResidualQT (CodingStructure &cs, Partitioner &partitioner, const ComponentID &compID); +#if JVET_AA0133_INTER_MTS_OPT + bool xEstimateInterResidualQT(CodingStructure &cs, Partitioner &partitioner, Distortion *puiZeroDist = NULL + , const bool luma = true, const bool chroma = true + , PelUnitBuf* orgResi = NULL + ); +#else void xEstimateInterResidualQT (CodingStructure &cs, Partitioner &partitioner, Distortion *puiZeroDist = NULL , const bool luma = true, const bool chroma = true , PelUnitBuf* orgResi = NULL ); +#endif uint64_t xGetSymbolFracBitsInter (CodingStructure &cs, Partitioner &partitioner); uint64_t xCalcPuMeBits (PredictionUnit& pu); diff --git a/source/Lib/EncoderLib/VLCWriter.cpp b/source/Lib/EncoderLib/VLCWriter.cpp index a7bfc77196d9970496be521bbd68f54aaaef0410..0611d87735f235c2e1c738f353f454da6c270974 100644 --- a/source/Lib/EncoderLib/VLCWriter.cpp +++ b/source/Lib/EncoderLib/VLCWriter.cpp @@ -1234,6 +1234,14 @@ void HLSWriter::codeSPS( const SPS* pcSPS ) { WRITE_FLAG(pcSPS->getUseIntraMTS() ? 1 : 0, "sps_explicit_mts_intra_enabled_flag"); WRITE_FLAG(pcSPS->getUseInterMTS() ? 1 : 0, "sps_explicit_mts_inter_enabled_flag"); +#if JVET_AA0133_INTER_MTS_OPT + if (pcSPS->getUseInterMTS()) + { + int interMTSMaxCU = pcSPS->getInterMTSMaxSize(); + CHECK((interMTSMaxCU != 16 && interMTSMaxCU != 32), "interMTSMaxSize != 32 or 16"); + WRITE_FLAG(interMTSMaxCU == 16? 1 : 0, "sps_inter_mts_max_size"); + } +#endif } WRITE_FLAG(pcSPS->getUseLFNST() ? 1 : 0, "sps_lfnst_enabled_flag"); #endif