From abe7f177622fa727d1a85791e850ea4a610123a3 Mon Sep 17 00:00:00 2001
From: deluxan <santiago.de.luxan@hhi.fraunhofer.de>
Date: Fri, 15 Feb 2019 16:43:45 +0100
Subject: [PATCH] Speed-Up for ISP when JVET_M0464_UNI_MTS is enabled. It can
 be enabled/disabled with the config. file parameter ISPFast.

When enabled:
- it merges all full RD intra mode lists into one,
- it tests fewer non-DCT-II transforms if ISP is likely to become the best mode, and
- it stops testing intra modes for an ISP split if all sub-partitions obtained a zero-cbf.

Results in CTC:

AI -> 0.04% loss, EncT = 97%, DecT = 100%
RA -> 0.01% loss, EncT = 99%, DecT = 100%
---
 cfg/encoder_intra_vtm.cfg             |  1 +
 cfg/encoder_randomaccess_vtm.cfg      |  1 +
 source/App/EncoderApp/EncApp.cpp      |  3 +
 source/App/EncoderApp/EncAppCfg.cpp   |  6 ++
 source/App/EncoderApp/EncAppCfg.h     |  3 +
 source/Lib/CommonLib/TypeDef.h        | 10 +++
 source/Lib/EncoderLib/EncCfg.h        |  7 ++
 source/Lib/EncoderLib/IntraSearch.cpp | 98 +++++++++++++++++++++++++--
 source/Lib/EncoderLib/IntraSearch.h   |  4 ++
 9 files changed, 129 insertions(+), 4 deletions(-)

diff --git a/cfg/encoder_intra_vtm.cfg b/cfg/encoder_intra_vtm.cfg
index c758b499..0a7fae75 100644
--- a/cfg/encoder_intra_vtm.cfg
+++ b/cfg/encoder_intra_vtm.cfg
@@ -117,6 +117,7 @@ LumaReshapeEnable            : 1      # luma reshaping. 0: disable 1:enable
 
 # Fast tools
 PBIntraFast                  : 1
+ISPFast                      : 1
 FastMrg                      : 1
 AMaxBT                       : 1
 
diff --git a/cfg/encoder_randomaccess_vtm.cfg b/cfg/encoder_randomaccess_vtm.cfg
index b15a4388..beb70d7e 100644
--- a/cfg/encoder_randomaccess_vtm.cfg
+++ b/cfg/encoder_randomaccess_vtm.cfg
@@ -153,6 +153,7 @@ DMVR                         : 1
 
 # Fast tools
 PBIntraFast                  : 1
+ISPFast                      : 1
 FastMrg                      : 1
 AMaxBT                       : 1
 
diff --git a/source/App/EncoderApp/EncApp.cpp b/source/App/EncoderApp/EncApp.cpp
index 26753f94..afb40251 100644
--- a/source/App/EncoderApp/EncApp.cpp
+++ b/source/App/EncoderApp/EncApp.cpp
@@ -326,6 +326,9 @@ void EncApp::xInitLibCfg()
   m_cEncLib.setUseBLambdaForNonKeyLowDelayPictures               ( m_bUseBLambdaForNonKeyLowDelayPictures );
   m_cEncLib.setPCMLog2MinSize                                    ( m_uiPCMLog2MinSize);
   m_cEncLib.setUsePCM                                            ( m_usePCM );
+#if JVET_M0102_INTRA_SUBPARTITIONS
+  m_cEncLib.setUseFastISP                                        ( m_useFastISP );
+#endif
 
   // set internal bit-depth and constants
   for (uint32_t channelType = 0; channelType < MAX_NUM_CHANNEL_TYPE; channelType++)
diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp
index eb18b7e6..251cdce2 100644
--- a/source/App/EncoderApp/EncAppCfg.cpp
+++ b/source/App/EncoderApp/EncAppCfg.cpp
@@ -1027,6 +1027,9 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] )
 #else
   ("TransformSkipFast",                               m_useTransformSkipFast,                           false, "Fast intra transform skipping")
   ("TransformSkipLog2MaxSize",                        m_log2MaxTransformSkipBlockSize,                     2U, "Specify transform-skip maximum size. Minimum 2. (not valid in V1 profiles)")
+#endif
+#if JVET_M0102_INTRA_SUBPARTITIONS
+  ("ISPFast",                                         m_useFastISP,                                     false, "Fast encoder search for ISP")
 #endif
   ("ImplicitResidualDPCM",                            m_rdpcmEnabledFlag[RDPCM_SIGNAL_IMPLICIT],        false, "Enable implicitly signalled residual DPCM for intra (also known as sample-adaptive intra predict) (not valid in V1 profiles)")
   ("ExplicitResidualDPCM",                            m_rdpcmEnabledFlag[RDPCM_SIGNAL_EXPLICIT],        false, "Enable explicitly signalled residual DPCM for inter (not valid in V1 profiles)")
@@ -3235,6 +3238,9 @@ void EncAppCfg::xPrintParameter()
   if( m_MTS ) msg( VERBOSE, "MTSMaxCand: %1d(intra) %1d(inter) ", m_MTSIntraMaxCand, m_MTSInterMaxCand );
 #else
   if( m_EMT ) msg( VERBOSE, "EMTFast: %1d(intra) %1d(inter) ", ( m_FastEMT & m_EMT & 1 ), ( m_FastEMT >> 1 ) & ( m_EMT >> 1 ) & 1 );
+#endif
+#if JVET_M0102_INTRA_SUBPARTITIONS
+  msg( VERBOSE, "ISPFast:%d ", m_useFastISP );
 #endif
   msg( VERBOSE, "AMaxBT:%d ", m_useAMaxBT );
   msg( VERBOSE, "E0023FastEnc:%d ", m_e0023FastEnc );
diff --git a/source/App/EncoderApp/EncAppCfg.h b/source/App/EncoderApp/EncAppCfg.h
index 45895dc4..126e06dc 100644
--- a/source/App/EncoderApp/EncAppCfg.h
+++ b/source/App/EncoderApp/EncAppCfg.h
@@ -150,6 +150,9 @@ protected:
   bool      m_rdpcmEnabledFlag[NUMBER_OF_RDPCM_SIGNALLING_MODES];///< control flags for residual DPCM
   bool      m_persistentRiceAdaptationEnabledFlag;            ///< control flag for Golomb-Rice parameter adaptation over each slice
   bool      m_cabacBypassAlignmentEnabledFlag;
+#if JVET_M0102_INTRA_SUBPARTITIONS
+  bool      m_useFastISP;                                    ///< flag for enabling fast methods for ISP
+#endif
 
   // coding quality
 #if QP_SWITCHING_FOR_PARALLEL
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index b77ddfde..e0a451f5 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -1386,6 +1386,16 @@ public:
                                                   iterator it = const_cast<iterator>( _pos ); _size += numEl;
                                                   while( first != last ) *it++ = *first++;
                                                   return const_cast<iterator>( _pos ); }
+
+#if JVET_M0102_INTRA_SUBPARTITIONS && JVET_M0464_UNI_MTS
+  iterator        insert( const_iterator _pos, size_t numEl, const T& val )
+                                                { // fill-insert: moves the elements from _pos onwards up by numEl slots, writes numEl copies of val at _pos and returns _pos
+                                                  CHECKD( _size + numEl >= N, "capacity exceeded" );
+                                                  for( difference_type i = _size - 1; i >= _pos - _arr; i-- ) _arr[i + numEl] = _arr[i];
+                                                  iterator it = const_cast<iterator>( _pos ); _size += numEl;
+                                                  for( size_t k = 0; k < numEl; k++ ) *it++ = val; // counter has the same unsigned type as numEl
+                                                  return const_cast<iterator>( _pos ); }
+#endif
 };
 
 
diff --git a/source/Lib/EncoderLib/EncCfg.h b/source/Lib/EncoderLib/EncCfg.h
index 3bf9fe72..41445cb4 100644
--- a/source/Lib/EncoderLib/EncCfg.h
+++ b/source/Lib/EncoderLib/EncCfg.h
@@ -390,6 +390,9 @@ protected:
   int*      m_aidQP;
   uint32_t      m_uiDeltaQpRD;
   bool      m_bFastDeltaQP;
+#if JVET_M0102_INTRA_SUBPARTITIONS
+  bool      m_useFastISP;                                    ///< flag for enabling fast methods for ISP (written by setUseFastISP, read by getUseFastISP)
+#endif
 
   bool      m_bUseConstrainedIntraPred;
   bool      m_bFastUDIUseMPMEnabled;
@@ -1104,6 +1107,10 @@ public:
   void setLog2MaxTransformSkipBlockSize                ( uint32_t u )    { m_log2MaxTransformSkipBlockSize  = u;       }
   bool getIntraSmoothingDisabledFlag               ()      const { return m_intraSmoothingDisabledFlag; }
   void setIntraSmoothingDisabledFlag               (bool bValue) { m_intraSmoothingDisabledFlag=bValue; }
+#if JVET_M0102_INTRA_SUBPARTITIONS
+  bool getUseFastISP                                   ()   const { return m_useFastISP;    } ///< returns the fast-ISP flag; const like the neighbouring getters
+  void setUseFastISP                                   ( bool b ) { m_useFastISP  = b;   } ///< stores the fast-ISP flag
+#endif
 
   const int* getdQPs                        () const { return m_aidQP;       }
   uint32_t      getDeltaQpRD                    () const { return m_uiDeltaQpRD; }
diff --git a/source/Lib/EncoderLib/IntraSearch.cpp b/source/Lib/EncoderLib/IntraSearch.cpp
index d4cab58f..6248857e 100644
--- a/source/Lib/EncoderLib/IntraSearch.cpp
+++ b/source/Lib/EncoderLib/IntraSearch.cpp
@@ -811,6 +811,45 @@ void IntraSearch::estIntraPredLumaQT( CodingUnit &cu, Partitioner &partitioner )
       }
     }
 
+#if JVET_M0102_INTRA_SUBPARTITIONS && JVET_M0464_UNI_MTS
+    if ( nOptionsForISP > 1 )
+    {
+      //we create a single full RD list that includes all intra modes using regular intra, MRL and ISP
+      auto* firstIspList  = ispOptions[1] == HOR_INTRA_SUBPARTITIONS ? &m_rdModeListWithoutMrlHor : &m_rdModeListWithoutMrlVer;
+      auto* secondIspList = ispOptions[1] == HOR_INTRA_SUBPARTITIONS ? &m_rdModeListWithoutMrlVer : &m_rdModeListWithoutMrlHor;
+
+      if ( m_pcEncCfg->getUseFastISP() )
+      {
+        // find the first non-MRL mode
+        size_t indexFirstMode = std::find( extendRefList.begin(), extendRefList.end(), 0 ) - extendRefList.begin();
+        // if not found, just take the last mode
+        if( indexFirstMode >= extendRefList.size() ) indexFirstMode = extendRefList.size() - 1;
+        // move the mode indicated by indexFirstMode to the beginning
+        for( int idx = ((int)indexFirstMode) - 1; idx >= 0; idx-- )
+        {
+          std::swap( extendRefList[idx], extendRefList[idx + 1] );
+          std::swap( uiRdModeList [idx], uiRdModeList [idx + 1] );
+        }
+        //insert all ISP modes after the first non-mrl mode
+        uiRdModeList.insert( uiRdModeList.begin() + 1, secondIspList->begin(), secondIspList->end() );
+        uiRdModeList.insert( uiRdModeList.begin() + 1, firstIspList->begin() , firstIspList->end()  );
+
+        extendRefList.insert( extendRefList.begin() + 1, secondIspList->size(), MRL_NUM_REF_LINES + ispOptions[2] ); 
+        extendRefList.insert( extendRefList.begin() + 1, firstIspList->size() , MRL_NUM_REF_LINES + ispOptions[1] ); 
+      }
+      else
+      {
+        //insert all ISP modes at the end of the current list
+        uiRdModeList.insert( uiRdModeList.end(), secondIspList->begin(), secondIspList->end() );
+        uiRdModeList.insert( uiRdModeList.end(), firstIspList->begin() , firstIspList->end()  );
+
+        extendRefList.insert( extendRefList.end(), secondIspList->size(), MRL_NUM_REF_LINES + ispOptions[2] );
+        extendRefList.insert( extendRefList.end(), firstIspList->size() , MRL_NUM_REF_LINES + ispOptions[1] );
+      }
+    }
+    CHECKD(uiRdModeList.size() != extendRefList.size(),"uiRdModeList and extendRefList do not have the same size!");
+#endif
+
     //===== check modes (using r-d costs) =====
     uint32_t       uiBestPUMode  = 0;
     int            bestExtendRef = 0;
@@ -830,12 +869,21 @@ void IntraSearch::estIntraPredLumaQT( CodingUnit &cu, Partitioner &partitioner )
     int       bestNormalIntraModeIndex    = -1;
     uint8_t   bestIspOption               = NOT_INTRA_SUBPARTITIONS;
     TUIntraSubPartitioner subTuPartitioner( partitioner );
-#if !JVET_M0464_UNI_MTS
+#if JVET_M0464_UNI_MTS
+    bool      ispHorAllZeroCbfs = false, ispVerAllZeroCbfs = false;
+
+    for (uint32_t uiMode = 0; uiMode < numModesForFullRD; uiMode++)
+    {
+      // set luma prediction mode
+      uint32_t uiOrgMode = uiRdModeList[uiMode];
+
+      cu.ispMode = extendRefList[uiMode] > MRL_NUM_REF_LINES ? extendRefList[uiMode] - MRL_NUM_REF_LINES : NOT_INTRA_SUBPARTITIONS;
+#else
     if ( !cu.ispMode && !cu.emtFlag )
     {
       m_modeCtrl->setEmtFirstPassNoIspCost( MAX_DOUBLE );
     }
-#endif
+
     for( uint32_t ispOptionIdx = 0; ispOptionIdx < nOptionsForISP; ispOptionIdx++ )
     {
       cu.ispMode = ispOptions[ispOptionIdx];
@@ -844,7 +892,7 @@ void IntraSearch::estIntraPredLumaQT( CodingUnit &cu, Partitioner &partitioner )
       {
         // set luma prediction mode
         uint32_t uiOrgMode = cu.ispMode == NOT_INTRA_SUBPARTITIONS ? uiRdModeList[uiMode] : cu.ispMode == HOR_INTRA_SUBPARTITIONS ? m_rdModeListWithoutMrlHor[uiMode] : m_rdModeListWithoutMrlVer[uiMode];
-
+#endif
         pu.intraDir[0] = uiOrgMode;
 
         int multiRefIdx = 0;
@@ -853,6 +901,12 @@ void IntraSearch::estIntraPredLumaQT( CodingUnit &cu, Partitioner &partitioner )
         {
           intraSubPartitionsProcOrder = CU::getISPType( cu, COMPONENT_Y );
           bool tuIsDividedInRows = CU::divideTuInRows( cu );
+#if JVET_M0464_UNI_MTS
+          if ( ( tuIsDividedInRows && ispHorAllZeroCbfs ) || ( !tuIsDividedInRows && ispVerAllZeroCbfs ) )
+          {
+            continue;
+          }
+#endif
           if( m_intraModeDiagRatio.at( bestNormalIntraModeIndex ) > 1.25 )
           {
             continue;
@@ -894,11 +948,25 @@ void IntraSearch::estIntraPredLumaQT( CodingUnit &cu, Partitioner &partitioner )
       }
       else
       {
+#if JVET_M0464_UNI_MTS
+        xRecurIntraCodingLumaQT( *csTemp, partitioner, bestIspOption ? bestCurrentCost : MAX_DOUBLE, -1, TU_NO_ISP, bestIspOption );
+#else
         xRecurIntraCodingLumaQT( *csTemp, partitioner, MAX_DOUBLE, -1 );
+#endif
       }
 
       if( cu.ispMode && !csTemp->cus[0]->firstTU->cbf[COMPONENT_Y] )
       {
+#if JVET_M0464_UNI_MTS
+        if ( cu.ispMode == HOR_INTRA_SUBPARTITIONS )
+        {
+          ispHorAllZeroCbfs |= ( m_pcEncCfg->getUseFastISP() && csTemp->tus[0]->lheight() > 2 && csTemp->cost >= bestCurrentCost );
+        }
+        else
+        {
+          ispVerAllZeroCbfs |= ( m_pcEncCfg->getUseFastISP() && csTemp->tus[0]->lwidth() > 2 && csTemp->cost >= bestCurrentCost );
+        }
+#endif
         csTemp->cost = MAX_DOUBLE;
       }
 #else
@@ -959,8 +1027,8 @@ void IntraSearch::estIntraPredLumaQT( CodingUnit &cu, Partitioner &partitioner )
     {
       m_modeCtrl->setEmtFirstPassNoIspCost(csBest->cost);
     }
-#endif
     }
+#endif
     cu.ispMode = bestIspOption;
 #endif
 
@@ -1980,7 +2048,11 @@ void IntraSearch::xIntraCodingTUBlock(TransformUnit &tu, const ComponentID &comp
 }
 
 #if JVET_M0102_INTRA_SUBPARTITIONS
+#if JVET_M0464_UNI_MTS
+void IntraSearch::xRecurIntraCodingLumaQT( CodingStructure &cs, Partitioner &partitioner, const double bestCostSoFar, const int subTuIdx, const PartSplit ispType, const bool ispIsCurrentWinnder )
+#else
 void IntraSearch::xRecurIntraCodingLumaQT( CodingStructure &cs, Partitioner &partitioner, const double bestCostSoFar, const int subTuIdx, const PartSplit ispType )
+#endif
 {
         int   subTuCounter = subTuIdx;
   const UnitArea &currArea = partitioner.currArea();
@@ -2122,6 +2194,10 @@ void IntraSearch::xRecurIntraCodingLumaQT( CodingStructure &cs, Partitioner &par
 #endif
 
 #if JVET_M0464_UNI_MTS
+#if JVET_M0102_INTRA_SUBPARTITIONS
+    double bestDCT2cost = MAX_DOUBLE;
+    double threshold = m_pcEncCfg->getUseFastISP() && !cu.ispMode && ispIsCurrentWinnder && nNumTransformCands > 1 ? 1 + 1.4 / sqrt( cu.lwidth() * cu.lheight() ) : 1;
+#endif
     for( int modeId = firstCheckId; modeId < nNumTransformCands; modeId++ )
     {
       if( !cbfDCT2 || ( m_pcEncCfg->getUseTransformSkipFast() && bestModeId[COMPONENT_Y] == 1 ) )
@@ -2132,6 +2208,13 @@ void IntraSearch::xRecurIntraCodingLumaQT( CodingStructure &cs, Partitioner &par
       {
         continue;
       }
+#if JVET_M0102_INTRA_SUBPARTITIONS
+      //we compare the DCT-II cost against the best ISP cost so far (except for TS)
+      if ( m_pcEncCfg->getUseFastISP() && !cu.ispMode && ispIsCurrentWinnder && trModes[modeId].first != 0 && ( trModes[modeId].first != 1 || !tsAllowed ) && bestDCT2cost > bestCostSoFar * threshold )
+      {
+        continue;
+      }
+#endif
       tu.mtsIdx = trModes[modeId].first;
 #else
     for( int modeId = firstCheckId; modeId <= lastCheckId; modeId++ )
@@ -2248,6 +2331,13 @@ void IntraSearch::xRecurIntraCodingLumaQT( CodingStructure &cs, Partitioner &par
         singleCostTmp     = m_pcRdCost->calcRdCost( singleTmpFracBits, singleDistTmpLuma );
       }
 
+#if JVET_M0102_INTRA_SUBPARTITIONS && JVET_M0464_UNI_MTS
+      if ( !cu.ispMode && nNumTransformCands > 1 && modeId == firstCheckId )
+      {
+        bestDCT2cost = singleCostTmp;
+      }
+#endif
+
       if (singleCostTmp < dSingleCost)
       {
         dSingleCost       = singleCostTmp;
diff --git a/source/Lib/EncoderLib/IntraSearch.h b/source/Lib/EncoderLib/IntraSearch.h
index 1879d06c..f4dd9f19 100644
--- a/source/Lib/EncoderLib/IntraSearch.h
+++ b/source/Lib/EncoderLib/IntraSearch.h
@@ -195,7 +195,11 @@ protected:
 
 #if JVET_M0102_INTRA_SUBPARTITIONS
   ChromaCbfs xRecurIntraChromaCodingQT( CodingStructure &cs, Partitioner& pm, const double bestCostSoFar = MAX_DOUBLE,                          const PartSplit ispType = TU_NO_ISP );
+#if JVET_M0464_UNI_MTS
+  void       xRecurIntraCodingLumaQT  ( CodingStructure &cs, Partitioner& pm, const double bestCostSoFar = MAX_DOUBLE, const int subTuIdx = -1, const PartSplit ispType = TU_NO_ISP, const bool ispIsCurrentWinnder = false );
+#else
   void       xRecurIntraCodingLumaQT  ( CodingStructure &cs, Partitioner& pm, const double bestCostSoFar = MAX_DOUBLE, const int subTuIdx = -1, const PartSplit ispType = TU_NO_ISP );
+#endif
 #else
   ChromaCbfs xRecurIntraChromaCodingQT  (CodingStructure &cs, Partitioner& pm);
 
-- 
GitLab