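// DepQuant.cpp (70.3 KiB) -- excerpt captured from a GitLab blame page. The page's navigation
// text ("Skip to content", "Snippets Groups Projects", "Newer Older",
// "Learn to ignore specific revisions") was not part of the source file.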
  •     SbbCtx                      m_allSbbCtx  [8];
        SbbCtx*                     m_currSbbCtx;
        SbbCtx*                     m_prevSbbCtx;
        uint8_t                     m_memory[ 8 * ( MAX_TU_SIZE * MAX_TU_SIZE + MLS_GRP_NUM ) ];
      };
    
    
    #if JVET_M0470
      // Rate of a Golomb-Rice coded value, indexed as [Rice parameter][value].
      // The Rice parameter row is State::m_goRicePar (kept in 0..3); most call sites clamp the
      // value column to RICEMAX-1. Every entry is an integer multiple of 32768 (1 << 15), i.e. a
      // whole number of bits in fixed point -- presumably the same scale as (1 << SCALE_BITS); confirm.
      const int32_t g_goRiceBits[4][RICEMAX] =
      {
        {  32768,  65536,  98304, 131072, 163840, 196608, 262144, 262144, 327680, 327680, 327680, 327680, 393216, 393216, 393216, 393216, 393216, 393216, 393216, 393216, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752, 458752 },
        {  65536,  65536,  98304,  98304, 131072, 131072, 163840, 163840, 196608, 196608, 229376, 229376, 294912, 294912, 294912, 294912, 360448, 360448, 360448, 360448, 360448, 360448, 360448, 360448, 425984, 425984, 425984, 425984, 425984, 425984, 425984, 425984 },
        {  98304,  98304,  98304,  98304, 131072, 131072, 131072, 131072, 163840, 163840, 163840, 163840, 196608, 196608, 196608, 196608, 229376, 229376, 229376, 229376, 262144, 262144, 262144, 262144, 327680, 327680, 327680, 327680, 327680, 327680, 327680, 327680 },
        { 131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072, 163840, 163840, 163840, 163840, 163840, 163840, 163840, 163840, 196608, 196608, 196608, 196608, 196608, 196608, 196608, 196608, 229376, 229376, 229376, 229376, 229376, 229376, 229376, 229376 }
      };
    #else
      // Same layout as above; the tables differ only in the large-value columns of rows 0 and 2.
      const int32_t g_goRiceBits[4][RICEMAX] =
      {
        {  32768,  65536,  98304, 131072, 163840, 196608, 229376, 294912, 294912, 360448, 360448, 360448, 360448, 425984, 425984, 425984, 425984, 425984, 425984, 425984, 425984, 491520, 491520, 491520, 491520, 491520, 491520, 491520, 491520, 491520, 491520, 491520 },
        {  65536,  65536,  98304,  98304, 131072, 131072, 163840, 163840, 196608, 196608, 229376, 229376, 294912, 294912, 294912, 294912, 360448, 360448, 360448, 360448, 360448, 360448, 360448, 360448, 425984, 425984, 425984, 425984, 425984, 425984, 425984, 425984 },
        {  98304,  98304,  98304,  98304, 131072, 131072, 131072, 131072, 163840, 163840, 163840, 163840, 196608, 196608, 196608, 196608, 229376, 229376, 229376, 229376, 262144, 262144, 262144, 262144, 294912, 294912, 294912, 294912, 360448, 360448, 360448, 360448 },
        { 131072, 131072, 131072, 131072, 131072, 131072, 131072, 131072, 163840, 163840, 163840, 163840, 163840, 163840, 163840, 163840, 196608, 196608, 196608, 196608, 196608, 196608, 196608, 196608, 229376, 229376, 229376, 229376, 229376, 229376, 229376, 229376 }
      };
    #endif
    
    
      class State
      {
        friend class CommonCtx;
      public:
        State( const RateEstimator& rateEst, CommonCtx& commonCtx, const int stateId );
    
        template<uint8_t numIPos>
        inline void updateState(const ScanInfo &scanInfo, const State *prevStates, const Decision &decision);
        inline void updateStateEOS(const ScanInfo &scanInfo, const State *prevStates, const State *skipStates,
                                   const Decision &decision);
    
        inline void init()
        {
          m_rdCost        = std::numeric_limits<int64_t>::max()>>1;
          m_numSigSbb     = 0;
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
          m_remRegBins    = 4;  // just large enough for last scan pos
    #else
    
          m_remRegBins    = 3;  // just large enough for last scan pos
    
          m_refSbbCtxId   = -1;
          m_sigFracBits   = m_sigFracBitsArray[ 0 ];
          m_coeffFracBits = m_gtxFracBitsArray[ 0 ];
          m_goRicePar     = 0;
    
        void checkRdCosts( const ScanPosType spt, const PQData& pqDataA, const PQData& pqDataB, Decision& decisionA, Decision& decisionB) const
        {
          const int32_t*  goRiceTab = g_goRiceBits[m_goRicePar];
          int64_t         rdCostA   = m_rdCost + pqDataA.deltaDist;
          int64_t         rdCostB   = m_rdCost + pqDataB.deltaDist;
          int64_t         rdCostZ   = m_rdCost;
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
          if( m_remRegBins >= 4 )
    #else
    
          {
            if( pqDataA.absLevel < 4 )
              rdCostA += m_coeffFracBits.bits[pqDataA.absLevel];
            else
            {
              const unsigned value = (pqDataA.absLevel - 4) >> 1;
              rdCostA += m_coeffFracBits.bits[pqDataA.absLevel - (value << 1)] + goRiceTab[value<RICEMAX ? value : RICEMAX-1];
            }
            if( pqDataB.absLevel < 4 )
              rdCostB += m_coeffFracBits.bits[pqDataB.absLevel];
            else
            {
              const unsigned value = (pqDataB.absLevel - 4) >> 1;
              rdCostB += m_coeffFracBits.bits[pqDataB.absLevel - (value << 1)] + goRiceTab[value<RICEMAX ? value : RICEMAX-1];
            }
            if( spt == SCAN_ISCSBB )
            {
              rdCostA += m_sigFracBits.intBits[1];
              rdCostB += m_sigFracBits.intBits[1];
              rdCostZ += m_sigFracBits.intBits[0];
            }
            else if( spt == SCAN_SOCSBB )
            {
              rdCostA += m_sbbFracBits.intBits[1] + m_sigFracBits.intBits[1];
              rdCostB += m_sbbFracBits.intBits[1] + m_sigFracBits.intBits[1];
              rdCostZ += m_sbbFracBits.intBits[1] + m_sigFracBits.intBits[0];
            }
            else if( m_numSigSbb )
            {
              rdCostA += m_sigFracBits.intBits[1];
              rdCostB += m_sigFracBits.intBits[1];
              rdCostZ += m_sigFracBits.intBits[0];
            }
            else
            {
              rdCostZ = decisionA.rdCost;
            }
          }
          else
          {
            rdCostA += (1 << SCALE_BITS) + goRiceTab[pqDataA.absLevel <= m_goRiceZero ? pqDataA.absLevel - 1 : (pqDataA.absLevel<RICEMAX ? pqDataA.absLevel : RICEMAX-1)];
            rdCostB += (1 << SCALE_BITS) + goRiceTab[pqDataB.absLevel <= m_goRiceZero ? pqDataB.absLevel - 1 : (pqDataB.absLevel<RICEMAX ? pqDataB.absLevel : RICEMAX-1)];
            rdCostZ += goRiceTab[m_goRiceZero];
          }
          if( rdCostA < decisionA.rdCost )
          {
            decisionA.rdCost   = rdCostA;
            decisionA.absLevel = pqDataA.absLevel;
            decisionA.prevId   = m_stateId;
          }
          if( rdCostZ < decisionA.rdCost )
          {
            decisionA.rdCost   = rdCostZ;
            decisionA.absLevel = 0;
            decisionA.prevId   = m_stateId;
          }
          if( rdCostB < decisionB.rdCost )
          {
            decisionB.rdCost   = rdCostB;
            decisionB.absLevel = pqDataB.absLevel;
            decisionB.prevId   = m_stateId;
          }
        }
    
    
        inline void checkRdCostStart(int32_t lastOffset, const PQData &pqData, Decision &decision) const
        {
    
          int64_t rdCost = pqData.deltaDist + lastOffset;
          if (pqData.absLevel < 4)
          {
            rdCost += m_coeffFracBits.bits[pqData.absLevel];
          }
          else
          {
            const unsigned value = (pqData.absLevel - 4) >> 1;
            rdCost += m_coeffFracBits.bits[pqData.absLevel - (value << 1)] + g_goRiceBits[m_goRicePar][value < RICEMAX ? value : RICEMAX-1];
          }
    
          if( rdCost < decision.rdCost )
          {
            decision.rdCost   = rdCost;
            decision.absLevel = pqData.absLevel;
            decision.prevId   = -1;
          }
        }
    
        inline void checkRdCostSkipSbb(Decision &decision) const
        {
          int64_t rdCost = m_rdCost + m_sbbFracBits.intBits[0];
          if( rdCost < decision.rdCost )
          {
            decision.rdCost   = rdCost;
            decision.absLevel = 0;
            decision.prevId   = 4+m_stateId;
          }
        }
    
    
    #if JVET_M0297_32PT_MTS_ZERO_OUT
    
        inline void checkRdCostSkipSbbZeroOut(Decision &decision) const
    
        {
          int64_t rdCost = m_rdCost + m_sbbFracBits.intBits[0];
          decision.rdCost = rdCost;
          decision.absLevel = 0;
          decision.prevId = 4 + m_stateId;
        }
    #endif
    
    
      private:
        int64_t                   m_rdCost;
        uint16_t                  m_absLevelsAndCtxInit[24];  // 16x8bit for abs levels + 16x16bit for ctx init id
    
        int8_t                    m_numSigSbb;
        int8_t                    m_remRegBins;
        int8_t                    m_refSbbCtxId;
    
        BinFracBits               m_sbbFracBits;
        BinFracBits               m_sigFracBits;
        CoeffFracBits             m_coeffFracBits;
    
        int8_t                    m_goRicePar;
        int8_t                    m_goRiceZero;
        const int8_t              m_stateId;
    
        const BinFracBits*const   m_sigFracBitsArray;
        const CoeffFracBits*const m_gtxFracBitsArray;
    
        const uint32_t*const      m_goRiceZeroArray;
    
        CommonCtx&                m_commonCtx;
      };
    
    
      // Binds state `stateId` to its rate tables and to the shared sub-block context. The table of
      // zero-level Rice positions is selected by stateId - 1, clamped at 0, so states 0 and 1 share
      // a table. Most mutable fields are reset later by init().
      State::State( const RateEstimator& rateEst, CommonCtx& commonCtx, const int stateId )
        : m_sbbFracBits     { { 0, 0 } }
        , m_stateId         ( stateId )
        , m_sigFracBitsArray( rateEst.sigFlagBits( stateId ) )
        , m_gtxFracBitsArray( rateEst.gtxFracBits( stateId ) )
        , m_goRiceZeroArray ( g_auiGoRicePosCoeff0[ stateId > 0 ? stateId - 1 : 0 ] )
        , m_commonCtx       ( commonCtx )
      {
      }
    
      template<uint8_t numIPos>
      inline void State::updateState(const ScanInfo &scanInfo, const State *prevStates, const Decision &decision)
      {
        m_rdCost = decision.rdCost;
        if( decision.prevId > -2 )
        {
          if( decision.prevId >= 0 )
          {
            const State*  prvState  = prevStates            +   decision.prevId;
            m_numSigSbb             = prvState->m_numSigSbb + !!decision.absLevel;
            m_refSbbCtxId           = prvState->m_refSbbCtxId;
            m_sbbFracBits           = prvState->m_sbbFracBits;
    
            m_remRegBins            = prvState->m_remRegBins - 1;
            m_goRicePar             = prvState->m_goRicePar;
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
            if( m_remRegBins >= 4 )
    #else
    
            {
              TCoeff rem = (decision.absLevel - 4) >> 1;
              if( m_goRicePar < 3 && rem > (3<<m_goRicePar)-1 )
              {
                m_goRicePar++;
              }
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
              m_remRegBins -= (decision.absLevel < 2 ? decision.absLevel : 3);
    #else
    
              m_remRegBins -= std::min<TCoeff>( decision.absLevel, 2 );
    
            ::memcpy( m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 48*sizeof(uint8_t) );
          }
          else
          {
            m_numSigSbb     =  1;
            m_refSbbCtxId   = -1;
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
              m_remRegBins = MAX_NUM_REG_BINS_2x2SUBBLOCK - (decision.absLevel < 2 ? decision.absLevel : 3);
    #else
    
              m_remRegBins  = MAX_NUM_REG_BINS_2x2SUBBLOCK - MAX_NUM_GT2_BINS_2x2SUBBLOCK - std::min<TCoeff>( decision.absLevel, 2 );
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
              m_remRegBins = MAX_NUM_REG_BINS_4x4SUBBLOCK - (decision.absLevel < 2 ? decision.absLevel : 3);
    #else
    
              m_remRegBins  = MAX_NUM_REG_BINS_4x4SUBBLOCK - MAX_NUM_GT2_BINS_4x4SUBBLOCK - std::min<TCoeff>( decision.absLevel, 2 );
    
            }
            m_goRicePar     = ( ((decision.absLevel - 4) >> 1) > (3<<0)-1 ? 1 : 0 );
    
            ::memset( m_absLevelsAndCtxInit, 0, 48*sizeof(uint8_t) );
          }
    
          uint8_t* levels               = reinterpret_cast<uint8_t*>(m_absLevelsAndCtxInit);
          levels[ scanInfo.insidePos ]  = (uint8_t)std::min<TCoeff>( 255, decision.absLevel );
    
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
          if (m_remRegBins >= 4)
    #else
    
          {
            TCoeff  tinit = m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos];
            TCoeff  sumAbs1 = (tinit >> 3) & 31;
            TCoeff  sumNum = tinit & 7;
    
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
    #define UPDATE(k) {TCoeff t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs1+=std::min<TCoeff>(4+(t&1),t); sumNum+=!!t; }
    #else
    
    #define UPDATE(k) {TCoeff t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs1+=std::min<TCoeff>(2+(t&1),t); sumNum+=!!t; }
    
            if (numIPos == 1)
            {
              UPDATE(0);
            }
            else if (numIPos == 2)
            {
              UPDATE(0);
              UPDATE(1);
            }
            else if (numIPos == 3)
            {
              UPDATE(0);
              UPDATE(1);
              UPDATE(2);
            }
            else if (numIPos == 4)
            {
              UPDATE(0);
              UPDATE(1);
              UPDATE(2);
              UPDATE(3);
            }
            else if (numIPos == 5)
            {
              UPDATE(0);
              UPDATE(1);
              UPDATE(2);
              UPDATE(3);
              UPDATE(4);
            }
    #undef UPDATE
            TCoeff sumGt1 = sumAbs1 - sumNum;
            m_sigFracBits = m_sigFracBitsArray[scanInfo.sigCtxOffsetNext + (sumAbs1 < 5 ? sumAbs1 : 5)];
            m_coeffFracBits = m_gtxFracBitsArray[scanInfo.gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)];
          }
          else
          {
            TCoeff  sumAbs = m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos] >> 8;
    #define UPDATE(k) {TCoeff t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs+=t; }
            if (numIPos == 1)
            {
              UPDATE(0);
            }
            else if (numIPos == 2)
            {
              UPDATE(0);
              UPDATE(1);
            }
            else if (numIPos == 3)
            {
              UPDATE(0);
              UPDATE(1);
              UPDATE(2);
            }
            else if (numIPos == 4)
            {
              UPDATE(0);
              UPDATE(1);
              UPDATE(2);
              UPDATE(3);
            }
            else if (numIPos == 5)
            {
              UPDATE(0);
              UPDATE(1);
              UPDATE(2);
              UPDATE(3);
              UPDATE(4);
            }
    #undef UPDATE
            sumAbs = std::min(31, sumAbs);
            m_goRicePar = g_auiGoRiceParsCoeff[sumAbs];
            m_goRiceZero = m_goRiceZeroArray[sumAbs];
          }
    
        }
      }
    
      inline void State::updateStateEOS(const ScanInfo &scanInfo, const State *prevStates, const State *skipStates,
                                        const Decision &decision)
      {
        m_rdCost = decision.rdCost;
        if( decision.prevId > -2 )
        {
          const State* prvState = 0;
    
          if( decision.prevId  >= 4 )
          {
            CHECK( decision.absLevel != 0, "cannot happen" );
            prvState    = skipStates + ( decision.prevId - 4 );
            m_numSigSbb = 0;
            ::memset( m_absLevelsAndCtxInit, 0, 16*sizeof(uint8_t) );
          }
          else if( decision.prevId  >= 0 )
    
            prvState    = prevStates            +   decision.prevId;
            m_numSigSbb = prvState->m_numSigSbb + !!decision.absLevel;
    
            ::memcpy( m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 16*sizeof(uint8_t) );
          }
          else
          {
            m_numSigSbb = 1;
            ::memset( m_absLevelsAndCtxInit, 0, 16*sizeof(uint8_t) );
          }
          reinterpret_cast<uint8_t*>(m_absLevelsAndCtxInit)[ scanInfo.insidePos ] = (uint8_t)std::min<TCoeff>( 255, decision.absLevel );
    
          m_commonCtx.update( scanInfo, prvState, *this );
    
          TCoeff  tinit   = m_absLevelsAndCtxInit[ 8 + scanInfo.nextInsidePos ];
          TCoeff  sumNum  =   tinit        & 7;
          TCoeff  sumAbs1 = ( tinit >> 3 ) & 31;
          TCoeff  sumGt1  = sumAbs1        - sumNum;
          m_sigFracBits   = m_sigFracBitsArray[ scanInfo.sigCtxOffsetNext + ( sumAbs1 < 5 ? sumAbs1 : 5 ) ];
          m_coeffFracBits = m_gtxFracBitsArray[ scanInfo.gtxCtxOffsetNext + ( sumGt1  < 4 ? sumGt1  : 4 ) ];
        }
      }
    
      /// Finalizes the sub-block that currState has just completed (called from updateStateEOS).
      /// The steps are:
      ///   1. Seed the per-state sub-block context m_currSbbCtx[stateId] from the context that
      ///      prevState refers to, or clear it.
      ///   2. Store this sub-block's coded flag and levels into that context.
      ///   3. Reset currState for the next sub-block: regular-bin budget, coded-sub-block-flag
      ///      rate (selected by the right/below neighbour flags), and the packed template
      ///      context-init word of every position of the next sub-block.
      inline void CommonCtx::update(const ScanInfo &scanInfo, const State *prevState, State &currState)
      {
        uint8_t*    sbbFlags  = m_currSbbCtx[ currState.m_stateId ].sbbFlags;
        uint8_t*    levels    = m_currSbbCtx[ currState.m_stateId ].levels;
        std::size_t setCpSize = m_nbInfo[ scanInfo.scanIdx - 1 ].maxDist * sizeof(uint8_t);
        if( prevState && prevState->m_refSbbCtxId >= 0 )
        {
          ::memcpy( sbbFlags,                  m_prevSbbCtx[prevState->m_refSbbCtxId].sbbFlags,                  scanInfo.numSbb*sizeof(uint8_t) );
          ::memcpy( levels + scanInfo.scanIdx, m_prevSbbCtx[prevState->m_refSbbCtxId].levels + scanInfo.scanIdx, setCpSize );
        }
        else
        {
          ::memset( sbbFlags,                  0, scanInfo.numSbb*sizeof(uint8_t) );
          ::memset( levels + scanInfo.scanIdx, 0, setCpSize );
        }
        sbbFlags[ scanInfo.sbbPos ] = !!currState.m_numSigSbb;
        ::memcpy( levels + scanInfo.scanIdx, currState.m_absLevelsAndCtxInit, scanInfo.sbbSize*sizeof(uint8_t) );

        // 1 if the right or the below neighbouring sub-block is coded
        const int       sigNSbb   = ( ( scanInfo.nextSbbRight ? sbbFlags[ scanInfo.nextSbbRight ] : false ) || ( scanInfo.nextSbbBelow ? sbbFlags[ scanInfo.nextSbbBelow ] : false ) ? 1 : 0 );
        currState.m_numSigSbb     = 0;

        // NOTE(review): the sub-block size test is missing from the captured copy; a sub-block of
        // 4 coefficients is taken to be the 2x2 case -- confirm against upstream.
        if( scanInfo.sbbSize == 4 )
        {
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
          currState.m_remRegBins  = MAX_NUM_REG_BINS_2x2SUBBLOCK;
    #else
          currState.m_remRegBins  = MAX_NUM_REG_BINS_2x2SUBBLOCK - MAX_NUM_GT2_BINS_2x2SUBBLOCK;
    #endif
        }
        else
        {
    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
          currState.m_remRegBins  = MAX_NUM_REG_BINS_4x4SUBBLOCK;
    #else
          currState.m_remRegBins  = MAX_NUM_REG_BINS_4x4SUBBLOCK - MAX_NUM_GT2_BINS_4x4SUBBLOCK;
    #endif
        }
        currState.m_refSbbCtxId   = currState.m_stateId;
        currState.m_sbbFracBits   = m_sbbFlagBits[ sigNSbb ];

        // Packed init word per position: sumNum (bits 0..2) | sumAbs1 (bits 3..7) | min(127, sumAbs) (bits 8..).
        uint16_t          templateCtxInit[16];
        const int         scanBeg   = scanInfo.scanIdx - scanInfo.sbbSize;
        const NbInfoOut*  nbOut     = m_nbInfo + scanBeg;
        const uint8_t*    absLevels = levels   + scanBeg;
        for( int id = 0; id < scanInfo.sbbSize; id++, nbOut++ )
        {
          if( nbOut->num )
          {
            TCoeff sumAbs = 0, sumAbs1 = 0, sumNum = 0;

    #if JVET_M0173_MOVE_GT2_TO_FIRST_PASS
    #define UPDATE(k) {TCoeff t=absLevels[nbOut->outPos[k]]; sumAbs+=t; sumAbs1+=std::min<TCoeff>(4+(t&1),t); sumNum+=!!t; }
    #else
    #define UPDATE(k) {TCoeff t=absLevels[nbOut->outPos[k]]; sumAbs+=t; sumAbs1+=std::min<TCoeff>(2+(t&1),t); sumNum+=!!t; }
    #endif
            UPDATE(0);
            if( nbOut->num > 1 )
            {
              UPDATE(1);
              if( nbOut->num > 2 )
              {
                UPDATE(2);
                if( nbOut->num > 3 )
                {
                  UPDATE(3);
                  if( nbOut->num > 4 )
                  {
                    UPDATE(4);
                  }
                }
              }
            }
    #undef UPDATE
            templateCtxInit[id] = uint16_t(sumNum) + ( uint16_t(sumAbs1) << 3 ) + ( (uint16_t)std::min<TCoeff>( 127, sumAbs ) << 8 );
          }
          else
          {
            templateCtxInit[id] = 0;
          }
        }
        // clear the 16 level bytes, then place the 16 init words behind them
        ::memset( currState.m_absLevelsAndCtxInit,     0,               16*sizeof(uint8_t) );
        ::memcpy( currState.m_absLevelsAndCtxInit + 8, templateCtxInit, 16*sizeof(uint16_t) );
      }
    
    
    
      /*================================================================================*/
      /*=====                                                                      =====*/
      /*=====   T C Q                                                              =====*/
      /*=====                                                                      =====*/
      /*================================================================================*/
      class DepQuant : private RateEstimator
      {
      public:
        DepQuant();
    
        void    quant   ( TransformUnit& tu, const CCoeffBuf& srcCoeff, const ComponentID compID, const QpParam& cQP, const double lambda, const Ctx& ctx, TCoeff& absSum );
        void    dequant ( const TransformUnit& tu,  CoeffBuf& recCoeff, const ComponentID compID, const QpParam& cQP )  const;
    
      private:
    
    #if JVET_M0297_32PT_MTS_ZERO_OUT
    
        void    xDecideAndUpdate  ( const TCoeff absCoeff, const ScanInfo& scanInfo, bool zeroOut );
        void    xDecide           ( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions, bool zeroOut );
    
        void    xDecideAndUpdate  ( const TCoeff absCoeff, const ScanInfo& scanInfo );
    
        void    xDecide           ( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions );
    
    #endif
    
    
      private:
        CommonCtx   m_commonCtx;
        State       m_allStates[ 12 ];
        State*      m_currStates;
        State*      m_prevStates;
        State*      m_skipStates;
        State       m_startState;
        Quantizer   m_quant;
        Decision    m_trellis[ MAX_TU_SIZE * MAX_TU_SIZE ][ 8 ];
      };
    
    
    #define TINIT(x) {*this,m_commonCtx,x}
      // Builds three groups of four trellis states (ids 0..3 each) in m_allStates. The groups are
      // addressed as current (elements 0..3), previous (4..7) and skip (8..11) states. A separate
      // start state uses the tables of state 0.
      DepQuant::DepQuant()
        : RateEstimator ()
        , m_commonCtx   ()
        , m_allStates   { TINIT(0), TINIT(1), TINIT(2), TINIT(3),
                          TINIT(0), TINIT(1), TINIT(2), TINIT(3),
                          TINIT(0), TINIT(1), TINIT(2), TINIT(3) }
        , m_currStates  ( &m_allStates[0] )
        , m_prevStates  ( &m_allStates[4] )
        , m_skipStates  ( &m_allStates[8] )
        , m_startState  TINIT(0)
      {
      }
    #undef TINIT
    
    
      void DepQuant::dequant( const TransformUnit& tu,  CoeffBuf& recCoeff, const ComponentID compID, const QpParam& cQP ) const
      {
        m_quant.dequantBlock( tu, compID, cQP, recCoeff );
      }
    
    
    #define DINIT(absLvl,prvId) { std::numeric_limits<int64_t>::max() >> 2, absLvl, prvId }
      // Template copied into each trellis column by xDecide() before any candidate is priced.
      // Entries 0..3: no transition yet (absLevel -1, prevId -2), at a very large cost.
      // Entries 4..7: zero level carried through a skipped sub-block from skip state 0..3 (prevId 4..7).
      static const Decision startDec[8] =
      {
        DINIT( -1, -2 ), DINIT( -1, -2 ), DINIT( -1, -2 ), DINIT( -1, -2 ),
        DINIT(  0,  4 ), DINIT(  0,  5 ), DINIT(  0,  6 ), DINIT(  0,  7 )
      };
    #undef  DINIT
    
    
    
    #if JVET_M0297_32PT_MTS_ZERO_OUT
    
      void DepQuant::xDecide( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions, bool zeroOut)
    
      void DepQuant::xDecide( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions)
    
    #endif
    
      {
        ::memcpy( decisions, startDec, 8*sizeof(Decision) );
    
    
    #if JVET_M0297_32PT_MTS_ZERO_OUT
        if( zeroOut )
        {
          if( spt==SCAN_EOCSBB )
          {
            m_skipStates[0].checkRdCostSkipSbbZeroOut( decisions[0] );
            m_skipStates[1].checkRdCostSkipSbbZeroOut( decisions[1] );
            m_skipStates[2].checkRdCostSkipSbbZeroOut( decisions[2] );
            m_skipStates[3].checkRdCostSkipSbbZeroOut( decisions[3] );
          }
          return;
        }
    #endif
    
    
        PQData  pqData[4];
        m_quant.preQuantCoeff( absCoeff, pqData );
    
        m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2]);
        m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0]);
        m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3]);
        m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1]);
    
        if( spt==SCAN_EOCSBB )
        {
          m_skipStates[0].checkRdCostSkipSbb( decisions[0] );
          m_skipStates[1].checkRdCostSkipSbb( decisions[1] );
          m_skipStates[2].checkRdCostSkipSbb( decisions[2] );
          m_skipStates[3].checkRdCostSkipSbb( decisions[3] );
        }
        m_startState.checkRdCostStart( lastOffset, pqData[0], decisions[0] );
        m_startState.checkRdCostStart( lastOffset, pqData[2], decisions[2] );
      }
    
    
    #if JVET_M0297_32PT_MTS_ZERO_OUT
    
      void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo, bool zeroOut )
    
      void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo )
    
    #endif
    
      {
        Decision* decisions = m_trellis[ scanInfo.scanIdx ];
    
        std::swap( m_prevStates, m_currStates );
    
    
    #if JVET_M0297_32PT_MTS_ZERO_OUT
    
        xDecide( scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx), decisions, zeroOut );
    
        xDecide( scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx), decisions);
    
    #endif
    
    
        if( scanInfo.scanIdx )
        {
          if( scanInfo.eosbb )
          {
            m_commonCtx.swap();
            m_currStates[0].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[0] );
            m_currStates[1].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[1] );
            m_currStates[2].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[2] );
            m_currStates[3].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[3] );
            ::memcpy( decisions+4, decisions, 4*sizeof(Decision) );
          }
    
    #if JVET_M0297_32PT_MTS_ZERO_OUT
          else if( !zeroOut )
    #else
    
          {
            switch( scanInfo.nextNbInfoSbb.num )
            {
            case 0:
              m_currStates[0].updateState<0>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<0>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<0>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<0>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 1:
              m_currStates[0].updateState<1>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<1>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<1>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<1>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 2:
              m_currStates[0].updateState<2>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<2>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<2>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<2>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 3:
              m_currStates[0].updateState<3>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<3>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<3>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<3>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 4:
              m_currStates[0].updateState<4>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<4>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<4>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<4>( scanInfo, m_prevStates, decisions[3] );
              break;
            default:
              m_currStates[0].updateState<5>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<5>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<5>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<5>( scanInfo, m_prevStates, decisions[3] );
            }
          }
    
    
          {
            std::swap( m_prevStates, m_skipStates );
          }
        }
      }
    
    
      /// Dependent (trellis) quantization of one transform-block component. The steps are:
      ///   1. Find the last scan position whose coefficient exceeds the quantizer's last-position
      ///      threshold.
      ///   2. Run the 4-state trellis from there down to scan position 0.
      ///   3. Back-track the cheapest path, writing signed levels into tu's coefficient buffer.
      /// absSum receives the sum of the absolute levels. If no coefficient exceeds the threshold,
      /// the block stays all-zero.
      void DepQuant::quant( TransformUnit& tu, const CCoeffBuf& srcCoeff, const ComponentID compID, const QpParam& cQP, const double lambda, const Ctx& ctx, TCoeff& absSum )
      {
        CHECKD( tu.cs->sps->getSpsRangeExtension().getExtendedPrecisionProcessingFlag(), "ext precision is not supported" );

        //===== init =====
        const TUParameters& tuPars  = *g_Rom.getTUPars( tu.blocks[compID], compID );

        m_quant.initQuantBlock    ( tu, compID, cQP, lambda );
        TCoeff*       qCoeff      = tu.getCoeffs( compID ).buf;
        const TCoeff* tCoeff      = srcCoeff.buf;
        const int     numCoeff    = tu.blocks[compID].area();
        ::memset( tu.getCoeffs( compID ).buf, 0x00, numCoeff*sizeof(TCoeff) );
        absSum          = 0;

        //===== find first test position =====
        int   firstTestPos = numCoeff - 1;
        const TCoeff thres = m_quant.getLastThreshold();
        for( ; firstTestPos >= 0; firstTestPos-- )
        {
          if (abs(tCoeff[tuPars.m_scanId2BlkPos[firstTestPos].idx]) > thres)
          {
            break;
          }
        }
        if( firstTestPos < 0 )
        {
          // Nothing to quantize: the zeroed output above is final. Running the back-tracking
          // below would otherwise read a trellis column left over from a previous block.
          return;
        }

        //===== real init =====
        RateEstimator::initCtx( tuPars, tu, compID, ctx.getFracBitsAcess() );
        m_commonCtx.reset( tuPars, *this );

        for( int k = 0; k < 12; k++ )
        {
          m_allStates[k].init();
        }
        m_startState.init();


    #if JVET_M0297_32PT_MTS_ZERO_OUT
        // Luma blocks using a non-default transform keep only the top-left 16 columns/rows of a
        // 32-wide/high block. Positions outside that area are forced to zero.
        int effWidth = tuPars.m_width, effHeight = tuPars.m_height;
        bool zeroOut = false;

    #if JVET_M0464_UNI_MTS
        if( tu.mtsIdx > 1 && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
    #else
        if( tu.cu->emtFlag && !tu.transformSkip[compID] && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
    #endif
        {
          effHeight = ( tuPars.m_height == 32 ) ? 16 : tuPars.m_height;
          effWidth = ( tuPars.m_width == 32 ) ? 16 : tuPars.m_width;
          zeroOut  = ( effHeight < tuPars.m_height || effWidth < tuPars.m_width );
        }
    #endif

        //===== populate trellis =====
        for( int scanIdx = firstTestPos; scanIdx >= 0; scanIdx-- )
        {
          const ScanInfo& scanInfo = tuPars.m_scanInfo[ scanIdx ];

    #if JVET_M0297_32PT_MTS_ZERO_OUT
          xDecideAndUpdate( abs( tCoeff[ scanInfo.rasterPos ] ), scanInfo, zeroOut && ( scanInfo.posX >= effWidth || scanInfo.posY >= effHeight ) );
    #else
          xDecideAndUpdate( abs( tCoeff[ scanInfo.rasterPos ] ), scanInfo );
    #endif
        }

        //===== find best path =====
        // Only negative path costs are accepted (minPathCost starts at 0); otherwise prevId stays
        // -2 and no level is written.
        Decision  decision    = { std::numeric_limits<int64_t>::max(), -1, -2 };
        int64_t   minPathCost =  0;
        for( int8_t stateId = 0; stateId < 4; stateId++ )
        {
          int64_t pathCost = m_trellis[0][stateId].rdCost;
          if( pathCost < minPathCost )
          {
            decision.prevId = stateId;
            minPathCost     = pathCost;
          }
        }

        //===== backward scanning =====
        int scanIdx = 0;
        for( ; decision.prevId >= 0; scanIdx++ )
        {
          decision          = m_trellis[ scanIdx ][ decision.prevId ];
          int32_t blkpos    = tuPars.m_scanId2BlkPos[scanIdx].idx;
          qCoeff[ blkpos ]  = ( tCoeff[ blkpos ] < 0 ? -decision.absLevel : decision.absLevel );
          absSum           += decision.absLevel;
        }
      }
    
    }; // namespace DQIntern
    
    
    
    
    //===== interface class =====
    // Creates the dependent-quantization engine behind the Quant interface. When `other` is given
    // it must itself be a DepQuant. DQIntern::g_Rom is initialized only when enc is true.
    DepQuant::DepQuant( const Quant* other, bool enc ) : QuantRDOQ( other )
    {
      const bool otherIsDepQuant = ( dynamic_cast<const DepQuant*>( other ) != nullptr );
      CHECK( other && !otherIsDepQuant, "The DepQuant cast must be successfull!" );
      p = new DQIntern::DepQuant();
      if( !enc )
      {
        return;
      }
      DQIntern::g_Rom.init();
    }
    
    // Releases the engine allocated by the constructor.
    DepQuant::~DepQuant()
    {
      DQIntern::DepQuant* impl = static_cast<DQIntern::DepQuant*>( p );
      delete impl;
    }
    
    // Uses trellis quantization when the slice enables dependent quantization, and RDOQ otherwise.
    void DepQuant::quant( TransformUnit &tu, const ComponentID &compID, const CCoeffBuf &pSrc, TCoeff &uiAbsSum, const QpParam &cQP, const Ctx& ctx )
    {
      if( !tu.cs->slice->getDepQuantEnabledFlag() )
      {
        QuantRDOQ::quant( tu, compID, pSrc, uiAbsSum, cQP, ctx );
        return;
      }
      DQIntern::DepQuant* impl = static_cast<DQIntern::DepQuant*>( p );
      impl->quant( tu, pSrc, compID, cQP, Quant::m_dLambda, ctx, uiAbsSum );
    }
    
    void DepQuant::dequant( const TransformUnit &tu, CoeffBuf &dstCoeff, const ComponentID &compID, const QpParam &cQP )
    {
      if( tu.cs->slice->getDepQuantEnabledFlag() )
      {
        static_cast<DQIntern::DepQuant*>(p)->dequant( tu, dstCoeff, compID, cQP );
      }
      else
      {
        QuantRDOQ::dequant( tu, dstCoeff, compID, cQP );
      }
    }