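// DepQuant.cpp
// NOTE(review): this text was extracted from a repository blame view; page navigation text, blame annotations and some source lines were lost or interleaved during extraction.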
          rdCostA += ( 1 << SCALE_BITS ) + goRiceTab[ pqDataA.absLevel <= m_goRiceZero ? pqDataA.absLevel - 1 : ( pqDataA.absLevel < RICEMAX ? pqDataA.absLevel : RICEMAX - 1 ) ];
          rdCostB += ( 1 << SCALE_BITS ) + goRiceTab[ pqDataB.absLevel <= m_goRiceZero ? pqDataB.absLevel - 1 : ( pqDataB.absLevel < RICEMAX ? pqDataB.absLevel : RICEMAX - 1 ) ];
          rdCostZ += goRiceTab[ m_goRiceZero ];
        if( rdCostA < decisionA.rdCost )
          decisionA.rdCost = rdCostA;
          decisionA.absLevel = pqDataA.absLevel;
          decisionA.prevId = m_stateId;
        if( rdCostZ < decisionA.rdCost )
        {
          decisionA.rdCost = rdCostZ;
          decisionA.absLevel = 0;
          decisionA.prevId = m_stateId;
        }
        if( rdCostB < decisionB.rdCost )
        {
          decisionB.rdCost = rdCostB;
          decisionB.absLevel = pqDataB.absLevel;
          decisionB.prevId = m_stateId;
        }
#if !JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif

    // Tests whether starting a new path at this coefficient (no predecessor state) beats the
    // cost currently stored in `decision`.
    //   lastOffset : cost term added on top of the distortion delta of the candidate
    //   pqData     : candidate level (absLevel) together with its distortion delta (deltaDist)
    //   decision   : overwritten with (cost, absLevel, prevId = -1) when the new cost is smaller
    inline void checkRdCostStart(int32_t lastOffset, const PQData &pqData, Decision &decision) const
    {
      int64_t cost = pqData.deltaDist + lastOffset;
      if (pqData.absLevel >= 4)
      {
        // Levels from 4 upwards: table entry for the reduced level plus Golomb-Rice bits for the
        // halved remainder, with the remainder index clamped to the last table entry.
        const unsigned escValue = (pqData.absLevel - 4) >> 1;
        const auto     riceIdx  = escValue < RICEMAX ? escValue : RICEMAX-1;
        cost += m_coeffFracBits.bits[pqData.absLevel - (escValue << 1)] + g_goRiceBits[m_goRicePar][riceIdx];
      }
      else
      {
        // Levels 0..3 are looked up directly.
        cost += m_coeffFracBits.bits[pqData.absLevel];
      }

      if( !( cost < decision.rdCost ) )
      {
        return;
      }
      decision.prevId   = -1;
      decision.absLevel = pqData.absLevel;
      decision.rdCost   = cost;
    }

    // Candidate "sub-block skipped": the cost of this state plus m_sbbFracBits.intBits[0].
    // `decision` is replaced only if the candidate is strictly cheaper; prevId is stored as
    // 4 + state id, the range that updateStateEOS resolves against the skip states.
    inline void checkRdCostSkipSbb(Decision &decision) const
    {
      const int64_t skipCost = m_rdCost + m_sbbFracBits.intBits[0];
      if( !( skipCost < decision.rdCost ) )
      {
        return;
      }
      decision.prevId   = 4+m_stateId;
      decision.absLevel = 0;
      decision.rdCost   = skipCost;
    }

    // Unconditional variant of checkRdCostSkipSbb: `decision` is always overwritten with the
    // skip candidate (level 0, prevId = 4 + state id), without comparing against its old cost.
    inline void checkRdCostSkipSbbZeroOut(Decision &decision) const
    {
      const int64_t skipCost = m_rdCost + m_sbbFracBits.intBits[0];

      decision.prevId   = 4 + m_stateId;
      decision.absLevel = 0;
      decision.rdCost   = skipCost;
    }

  // NOTE(review): the conditional blocks below do not pair up (#if / #else without a matching #endif) and
  // members used elsewhere in this file (m_numSigSbb, m_refSbbCtxId, m_commonCtx) are not declared here;
  // lines appear to be missing from this extract - confirm against the repository.
  private:
    // RD cost of this state; assigned from Decision::rdCost in updateState and updateStateEOS.
    int64_t                   m_rdCost;
    uint16_t                  m_absLevelsAndCtxInit[24];  // 16x8bit for abs levels + 16x16bit for ctx init id
#if JVET_O0052_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT
    // Counter that is initialised and decremented in updateState and reset in CommonCtx::update
    // (presumably the remaining budget of regular coded bins - TODO confirm).
    int                       m_remRegBins;
#else
    // Bit costs read by checkRdCostSkipSbb (intBits[0]); set by CommonCtx::update.
    BinFracBits               m_sbbFracBits;
    // Entry of m_sigFracBitsArray selected for the next scan position.
    BinFracBits               m_sigFracBits;
    // Entry of m_gtxFracBitsArray selected for the next scan position.
    CoeffFracBits             m_coeffFracBits;
    // First index into g_goRiceBits; read from g_auiGoRiceParsCoeff in updateState.
    int8_t                    m_goRicePar;
    // Value read from m_goRiceZeroArray in updateState.
    int8_t                    m_goRiceZero;
    // State index given to the constructor.
    const int8_t              m_stateId;
    // Per-state tables obtained from the RateEstimator in the constructor.
    const BinFracBits*const   m_sigFracBitsArray;
    const CoeffFracBits*const m_gtxFracBitsArray;
    // Row of g_auiGoRicePosCoeff0 selected by max(0, stateId-1).
    const uint32_t*const      m_goRiceZeroArray;
#if JVET_O0052_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT
  public:
    // Effective block dimensions; written by DepQuant::quant and used to initialise m_remRegBins.
    unsigned                  effWidth;
    unsigned                  effHeight;
#endif
  };


  // Binds a trellis state to its id, to the per-state rate tables of `rateEst` and to the shared
  // `commonCtx`. The Rice "zero position" table row is g_auiGoRicePosCoeff0[stateId - 1], with
  // state ids 0 and 1 both mapping to row 0. Members not listed here are left to init()/update.
  State::State( const RateEstimator& rateEst, CommonCtx& commonCtx, const int stateId )
    : m_sbbFracBits     { { 0, 0 } },
      m_stateId         ( stateId ),
      m_sigFracBitsArray( rateEst.sigFlagBits( stateId ) ),
      m_gtxFracBitsArray( rateEst.gtxFracBits( stateId ) ),
      m_goRiceZeroArray ( g_auiGoRicePosCoeff0[ stateId > 1 ? stateId - 1 : 0 ] ),
      m_commonCtx       ( commonCtx )
  {
  }

  // Refreshes this state after a decision for a scan position that is not the last one of a sub-block
  // (xDecideAndUpdate calls updateStateEOS instead when scanInfo.eosbb is set).
  //   numIPos    : number of template neighbours inside the current sub-block that are accumulated
  //                (the caller dispatches on scanInfo.nextNbInfoSbb.num, values above 4 map to 5)
  //   scanInfo   : description of the current scan position and of the next one
  //   prevStates : the four predecessor states, indexed by decision.prevId
  //   decision   : chosen cost / level / predecessor for this state
  // NOTE(review): several lines seem to be missing from this extract (unterminated #if blocks, an else
  // without a visible if, unusual indentation of some assignments) - confirm against the repository
  // before relying on the control flow shown here.
  template<uint8_t numIPos>
  inline void State::updateState(const ScanInfo &scanInfo, const State *prevStates, const Decision &decision)
  {
    m_rdCost = decision.rdCost;
    // prevId == -2 is the initial value of the first four startDec entries; the state is only refreshed otherwise.
    if( decision.prevId > -2 )
    {
      if( decision.prevId >= 0 )
      {
        // Continue the path of predecessor state prevId: inherit its counters, sub-block bits and
        // the complete level / template buffer (48 bytes == sizeof(m_absLevelsAndCtxInit)).
        const State*  prvState  = prevStates            +   decision.prevId;
        m_numSigSbb             = prvState->m_numSigSbb + !!decision.absLevel;
        m_refSbbCtxId           = prvState->m_refSbbCtxId;
        m_sbbFracBits           = prvState->m_sbbFracBits;
        m_remRegBins            = prvState->m_remRegBins - 1;
        m_goRicePar             = prvState->m_goRicePar;
        // NOTE(review): the deeper indentation suggests this statement was guarded by a condition that is not visible here.
          m_remRegBins -= (decision.absLevel < 2 ? decision.absLevel : 3);
        ::memcpy( m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 48*sizeof(uint8_t) );
      }
      else
      {
        // prevId == -1 (set by checkRdCostStart): the path starts at this coefficient.
        m_numSigSbb     =  1;
        m_refSbbCtxId   = -1;
#if JVET_O0052_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT
        int ctxBinSampleRatio = (scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA;
        m_remRegBins = (effWidth * effHeight *ctxBinSampleRatio) / 16 - (decision.absLevel < 2 ? decision.absLevel : 3);
#else
        // NOTE(review): the two assignments below were presumably alternatives of a condition that is missing here.
          m_remRegBins = MAX_NUM_REG_BINS_2x2SUBBLOCK - (decision.absLevel < 2 ? decision.absLevel : 3);
          m_remRegBins = MAX_NUM_REG_BINS_4x4SUBBLOCK - (decision.absLevel < 2 ? decision.absLevel : 3);
        ::memset( m_absLevelsAndCtxInit, 0, 48*sizeof(uint8_t) );
      }

      // The first bytes of the buffer hold the absolute levels of the current sub-block, clipped to 255.
      uint8_t* levels               = reinterpret_cast<uint8_t*>(m_absLevelsAndCtxInit);
      levels[ scanInfo.insidePos ]  = (uint8_t)std::min<TCoeff>( 255, decision.absLevel );

      {
        // Template start values for the next position, as packed by CommonCtx::update:
        // bits 0..2 = number of non-zero neighbours, bits 3..7 = clipped level sum, bits 8.. = level sum.
        TCoeff  tinit = m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos];
        TCoeff  sumAbs1 = (tinit >> 3) & 31;
        TCoeff  sumNum = tinit & 7;
        // Adds the contribution of in-sub-block neighbour k of the next position.
#define UPDATE(k) {TCoeff t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs1+=std::min<TCoeff>(4+(t&1),t); sumNum+=!!t; }
        if (numIPos == 1)
        {
          UPDATE(0);
        }
        else if (numIPos == 2)
        {
          UPDATE(0);
          UPDATE(1);
        }
        else if (numIPos == 3)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
        }
        else if (numIPos == 4)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
          UPDATE(3);
        }
        else if (numIPos == 5)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
          UPDATE(3);
          UPDATE(4);
        }
#undef UPDATE
        TCoeff sumGt1 = sumAbs1 - sumNum;
        // Select the rate tables for the next position from the accumulated template sums.
#if JVET_O0617_SIG_FLAG_CONTEXT_REDUCTION
        m_sigFracBits = m_sigFracBitsArray[scanInfo.sigCtxOffsetNext + std::min( (sumAbs1+1)>>1, 3 )];
#else
        m_sigFracBits = m_sigFracBitsArray[scanInfo.sigCtxOffsetNext + (sumAbs1 < 5 ? sumAbs1 : 5)];
        m_coeffFracBits = m_gtxFracBitsArray[scanInfo.gtxCtxOffsetNext + (sumGt1 < 4 ? sumGt1 : 4)];
        // Unclipped level sum of the template, used for the Rice parameter lookup below.
        TCoeff  sumAbs = m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos] >> 8;
#define UPDATE(k) {TCoeff t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs+=t; }
        if (numIPos == 1)
        {
          UPDATE(0);
        }
        else if (numIPos == 2)
        {
          UPDATE(0);
          UPDATE(1);
        }
        else if (numIPos == 3)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
        }
        else if (numIPos == 4)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
          UPDATE(3);
        }
        else if (numIPos == 5)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
          UPDATE(3);
          UPDATE(4);
        }
#undef UPDATE
        // Rice parameter from the level sum reduced by 20 and clamped to the range 0..31.
        int sumAll = std::max(std::min(31, (int)sumAbs - 4 * 5), 0);
        m_goRicePar = g_auiGoRiceParsCoeff[sumAll];
      }
      // NOTE(review): the if that this else belongs to is not visible in this extract.
      else
      {
        // Only the level sum is needed here: it selects both the Rice parameter and m_goRiceZero.
        TCoeff  sumAbs = m_absLevelsAndCtxInit[8 + scanInfo.nextInsidePos] >> 8;
#define UPDATE(k) {TCoeff t=levels[scanInfo.nextNbInfoSbb.inPos[k]]; sumAbs+=t; }
        if (numIPos == 1)
        {
          UPDATE(0);
        }
        else if (numIPos == 2)
        {
          UPDATE(0);
          UPDATE(1);
        }
        else if (numIPos == 3)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
        }
        else if (numIPos == 4)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
          UPDATE(3);
        }
        else if (numIPos == 5)
        {
          UPDATE(0);
          UPDATE(1);
          UPDATE(2);
          UPDATE(3);
          UPDATE(4);
        }
#undef UPDATE
        sumAbs = std::min<TCoeff>(31, sumAbs);
        m_goRicePar = g_auiGoRiceParsCoeff[sumAbs];
        m_goRiceZero = m_goRiceZeroArray[sumAbs];
      }
    }
  }

  inline void State::updateStateEOS(const ScanInfo &scanInfo, const State *prevStates, const State *skipStates,
                                    const Decision &decision)
  {
    m_rdCost = decision.rdCost;
    if( decision.prevId > -2 )
    {
      const State* prvState = 0;
      if( decision.prevId  >= 4 )
      {
        CHECK( decision.absLevel != 0, "cannot happen" );
        prvState    = skipStates + ( decision.prevId - 4 );
        m_numSigSbb = 0;
        ::memset( m_absLevelsAndCtxInit, 0, 16*sizeof(uint8_t) );
      }
      else if( decision.prevId  >= 0 )
        prvState    = prevStates            +   decision.prevId;
        m_numSigSbb = prvState->m_numSigSbb + !!decision.absLevel;
        ::memcpy( m_absLevelsAndCtxInit, prvState->m_absLevelsAndCtxInit, 16*sizeof(uint8_t) );
      }
      else
      {
        m_numSigSbb = 1;
        ::memset( m_absLevelsAndCtxInit, 0, 16*sizeof(uint8_t) );
      }
      reinterpret_cast<uint8_t*>(m_absLevelsAndCtxInit)[ scanInfo.insidePos ] = (uint8_t)std::min<TCoeff>( 255, decision.absLevel );

      m_commonCtx.update( scanInfo, prvState, *this );

      TCoeff  tinit   = m_absLevelsAndCtxInit[ 8 + scanInfo.nextInsidePos ];
      TCoeff  sumNum  =   tinit        & 7;
      TCoeff  sumAbs1 = ( tinit >> 3 ) & 31;
      TCoeff  sumGt1  = sumAbs1        - sumNum;
#if JVET_O0617_SIG_FLAG_CONTEXT_REDUCTION
      m_sigFracBits   = m_sigFracBitsArray[ scanInfo.sigCtxOffsetNext + std::min( (sumAbs1+1)>>1, 3 ) ];
#else
      m_sigFracBits   = m_sigFracBitsArray[ scanInfo.sigCtxOffsetNext + ( sumAbs1 < 5 ? sumAbs1 : 5 ) ];
      m_coeffFracBits = m_gtxFracBitsArray[ scanInfo.gtxCtxOffsetNext + ( sumGt1  < 4 ? sumGt1  : 4 ) ];
    }
  }

  // Publishes the sub-block just finished by currState into the per-state sub-block context and
  // prepares currState for the next sub-block.
  //   scanInfo  : description of the current scan position (first position of the sub-block in scan order is scanIdx)
  //   prevState : predecessor state of the chosen path; may be null when the path starts here
  //   currState : state being updated; its level buffer is consumed and then re-initialised
  // NOTE(review): the preprocessor block around m_remRegBins looks incomplete in this extract
  // (no #endif, two consecutive assignments in the #else part) - confirm against the repository.
  inline void CommonCtx::update(const ScanInfo &scanInfo, const State *prevState, State &currState)
  {
    uint8_t*    sbbFlags  = m_currSbbCtx[ currState.m_stateId ].sbbFlags;
    uint8_t*    levels    = m_currSbbCtx[ currState.m_stateId ].levels;
    // Number of level bytes starting at scanIdx that are carried over (or cleared).
    std::size_t setCpSize = m_nbInfo[ scanInfo.scanIdx - 1 ].maxDist * sizeof(uint8_t);
    if( prevState && prevState->m_refSbbCtxId >= 0 )
    {
      // Inherit sub-block flags and already coded levels from the context referenced by the predecessor.
      ::memcpy( sbbFlags,                  m_prevSbbCtx[prevState->m_refSbbCtxId].sbbFlags,                  scanInfo.numSbb*sizeof(uint8_t) );
      ::memcpy( levels + scanInfo.scanIdx, m_prevSbbCtx[prevState->m_refSbbCtxId].levels + scanInfo.scanIdx, setCpSize );
    }
    else
    {
      // No referenced context: start from all-zero flags and levels.
      ::memset( sbbFlags,                  0, scanInfo.numSbb*sizeof(uint8_t) );
      ::memset( levels + scanInfo.scanIdx, 0, setCpSize );
    }
    // Record whether the finished sub-block has a non-zero level and store its levels.
    sbbFlags[ scanInfo.sbbPos ] = !!currState.m_numSigSbb;
    ::memcpy( levels + scanInfo.scanIdx, currState.m_absLevelsAndCtxInit, scanInfo.sbbSize*sizeof(uint8_t) );

    // 1 if the right or the below neighbour of the next sub-block is flagged; an offset of 0 means no such neighbour.
    const int       sigNSbb   = ( ( scanInfo.nextSbbRight ? sbbFlags[ scanInfo.nextSbbRight ] : false ) || ( scanInfo.nextSbbBelow ? sbbFlags[ scanInfo.nextSbbBelow ] : false ) ? 1 : 0 );
    currState.m_numSigSbb     = 0;
#if JVET_O0052_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT
    if (prevState)
    {
      currState.m_remRegBins = prevState->m_remRegBins;
    }
    else
    {
      int ctxBinSampleRatio = (scanInfo.chType == CHANNEL_TYPE_LUMA) ? MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_LUMA : MAX_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT_CHROMA;
      currState.m_remRegBins = (currState.effWidth * currState.effHeight *ctxBinSampleRatio) / 16;
    }
#else
      // NOTE(review): presumably alternatives of a condition that is missing from this extract.
      currState.m_remRegBins  = MAX_NUM_REG_BINS_2x2SUBBLOCK;
      currState.m_remRegBins  = MAX_NUM_REG_BINS_4x4SUBBLOCK;
    currState.m_refSbbCtxId   = currState.m_stateId;
    currState.m_sbbFracBits   = m_sbbFlagBits[ sigNSbb ];

    // Template start values for every position of the next sub-block (scan range [scanBeg, scanIdx)),
    // accumulated from neighbours outside that sub-block.
    // NOTE(review): assumes scanInfo.sbbSize <= 16, the size of templateCtxInit - confirm.
    uint16_t          templateCtxInit[16];
    const int         scanBeg   = scanInfo.scanIdx - scanInfo.sbbSize;
    const NbInfoOut*  nbOut     = m_nbInfo + scanBeg;
    const uint8_t*    absLevels = levels   + scanBeg;
    for( int id = 0; id < scanInfo.sbbSize; id++, nbOut++ )
    {
      if( nbOut->num )
      {
        TCoeff sumAbs = 0, sumAbs1 = 0, sumNum = 0;
        // Adds outside neighbour k: full level, level clipped to 4+(level&1), and non-zero count.
#define UPDATE(k) {TCoeff t=absLevels[nbOut->outPos[k]]; sumAbs+=t; sumAbs1+=std::min<TCoeff>(4+(t&1),t); sumNum+=!!t; }
        UPDATE(0);
        if( nbOut->num > 1 )
        {
          UPDATE(1);
          if( nbOut->num > 2 )
          {
            UPDATE(2);
            if( nbOut->num > 3 )
            {
              UPDATE(3);
              if( nbOut->num > 4 )
              {
                UPDATE(4);
              }
            }
          }
        }
#undef UPDATE
        // Packing: bits 0..2 = sumNum, bits 3..7 = sumAbs1, bits 8.. = sumAbs clipped to 127.
        templateCtxInit[id] = uint16_t(sumNum) + ( uint16_t(sumAbs1) << 3 ) + ( (uint16_t)std::min<TCoeff>( 127, sumAbs ) << 8 );
      }
      else
      {
        templateCtxInit[id] = 0;
      }
    }
    // Re-initialise the state buffer: 16 zeroed level bytes followed by the 16 template start values.
    ::memset( currState.m_absLevelsAndCtxInit,     0,               16*sizeof(uint8_t) );
    ::memcpy( currState.m_absLevelsAndCtxInit + 8, templateCtxInit, 16*sizeof(uint16_t) );
  }



  /*================================================================================*/
  /*=====                                                                      =====*/
  /*=====   T C Q                                                              =====*/
  /*=====                                                                      =====*/
  /*================================================================================*/
  // Trellis-based dependent quantizer. Keeps three banks of four states inside m_allStates
  // (addressed through m_currStates / m_prevStates / m_skipStates), one extra start state, and
  // one row of eight decisions per scan position in m_trellis.
  class DepQuant : private RateEstimator
  {
  public:
    DepQuant();

    // Quantizes the coefficients of one TU component; absSum receives the sum of the chosen absolute levels.
    void quant( TransformUnit&     tu,
                const CCoeffBuf&   srcCoeff,
                const ComponentID  compID,
                const QpParam&     cQP,
                const double       lambda,
                const Ctx&         ctx,
                TCoeff&            absSum,
                bool               enableScalingLists,
                int*               quantCoeff );

    // Reconstructs the coefficients of one TU component into recCoeff.
    void dequant( const TransformUnit& tu,
                  CoeffBuf&            recCoeff,
                  const ComponentID    compID,
                  const QpParam&       cQP,
                  bool                 enableScalingLists,
                  int*                 quantCoeff );

    // Takes the decisions for one scan position and updates the current states accordingly.
    void xDecideAndUpdate( const TCoeff    absCoeff,
                           const ScanInfo& scanInfo,
                           bool            zeroOut,
                           int             quantCoeff );

    // Fills the eight decisions of one scan position.
    void xDecide( const ScanPosType spt,
                  const TCoeff      absCoeff,
                  const int         lastOffset,
                  Decision*         decisions,
                  bool              zeroOut,
                  int               quantCoeff );

  private:
    CommonCtx m_commonCtx;
    State     m_allStates[ 12 ];
    State*    m_currStates;
    State*    m_prevStates;
    State*    m_skipStates;
    State     m_startState;
    Quantizer m_quant;
    Decision  m_trellis[ MAX_TB_SIZEY * MAX_TB_SIZEY ][ 8 ];
  };


#define TINIT(x) {*this,m_commonCtx,x}
  DepQuant::DepQuant()
    : RateEstimator ()
    , m_commonCtx   ()
    , m_allStates   {TINIT(0),TINIT(1),TINIT(2),TINIT(3),TINIT(0),TINIT(1),TINIT(2),TINIT(3),TINIT(0),TINIT(1),TINIT(2),TINIT(3)}
    , m_currStates  (  m_allStates      )
    , m_prevStates  (  m_currStates + 4 )
    , m_skipStates  (  m_prevStates + 4 )
    , m_startState  TINIT(0)
  {}
#undef TINIT


  void DepQuant::dequant( const TransformUnit& tu,  CoeffBuf& recCoeff, const ComponentID compID, const QpParam& cQP, bool enableScalingLists, int* piDequantCoef )
    m_quant.dequantBlock( tu, compID, cQP, recCoeff, enableScalingLists, piDequantCoef );
  }


#define DINIT(l,p) {std::numeric_limits<int64_t>::max()>>2,l,p}
  static const Decision startDec[8] = {DINIT(-1,-2),DINIT(-1,-2),DINIT(-1,-2),DINIT(-1,-2),DINIT(0,4),DINIT(0,5),DINIT(0,6),DINIT(0,7)};
#undef  DINIT


  // Fills the eight decisions of one scan position.
  //   spt        : scan position type; SCAN_EOCSBB is tested in the zero-out branch
  //   absCoeff   : absolute value of the transform coefficient, handed to Quantizer::preQuantCoeff
  //   lastOffset : cost offset passed on to State::checkRdCostStart
  //   decisions  : output row of eight Decision entries, first reset to startDec
  //   zeroOut    : the coefficient lies in a zeroed-out region
  //   quanCoeff  : quantization coefficient handed to Quantizer::preQuantCoeff
  // The candidates are evaluated by the four previous states (checkRdCosts), the four skip states
  // (checkRdCostSkipSbb / checkRdCostSkipSbbZeroOut) and the start state (checkRdCostStart).
  // NOTE(review): this extract contains interleaved page residue and is missing source lines
  // (for example the declaration of pqData, some conditions and closing braces) - confirm the
  // exact control flow against the repository.
  void DepQuant::xDecide( const ScanPosType spt, const TCoeff absCoeff, const int lastOffset, Decision* decisions, bool zeroOut, int quanCoeff)
  {
    // Start from the default row: costs at their initial value, no predecessor selected.
    ::memcpy( decisions, startDec, 8*sizeof(Decision) );

#if JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
    if( zeroOut )
    {
      if( spt==SCAN_EOCSBB )
      {
        m_skipStates[0].checkRdCostSkipSbbZeroOut( decisions[0] );
        m_skipStates[1].checkRdCostSkipSbbZeroOut( decisions[1] );
        m_skipStates[2].checkRdCostSkipSbbZeroOut( decisions[2] );
        m_skipStates[3].checkRdCostSkipSbbZeroOut( decisions[3] );
      }
      return;
    }
#endif
    m_quant.preQuantCoeff( absCoeff, pqData, quanCoeff );
#if JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
    m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2]);
    m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0]);
    m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3]);
    m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1]);
#else
    m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2], zeroOut );
    m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0], zeroOut );
    m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3], zeroOut );
    m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1], zeroOut );
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif
#if !JVET_O0094_LFNST_ZERO_PRIM_COEFFS
      if( zeroOut )
      {
        m_skipStates[0].checkRdCostSkipSbbZeroOut( decisions[0] );
        m_skipStates[1].checkRdCostSkipSbbZeroOut( decisions[1] );
        m_skipStates[2].checkRdCostSkipSbbZeroOut( decisions[2] );
        m_skipStates[3].checkRdCostSkipSbbZeroOut( decisions[3] );
      }
      else
      {
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif
        m_skipStates[0].checkRdCostSkipSbb( decisions[0] );
        m_skipStates[1].checkRdCostSkipSbb( decisions[1] );
        m_skipStates[2].checkRdCostSkipSbb( decisions[2] );
        m_skipStates[3].checkRdCostSkipSbb( decisions[3] );
#if !JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif
#if !JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif
    m_startState.checkRdCostStart( lastOffset, pqData[0], decisions[0] );
    m_startState.checkRdCostStart( lastOffset, pqData[2], decisions[2] );
#if !JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif
  // Processes one scan position: takes the decisions for it and updates the four current states.
  //   absCoeff   : absolute value of the transform coefficient at this position
  //   scanInfo   : description of the scan position (scanIdx selects the trellis row)
  //   zeroOut    : the coefficient lies in a zeroed-out region
  //   quantCoeff : quantization coefficient passed on to xDecide
  // For scanIdx == 0 only the decisions are taken; no state update follows.
  // NOTE(review): this extract contains interleaved page residue and is missing source lines
  // (the #else alternative of the zero-out test and the closing braces) - confirm against the repository.
  void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo, bool zeroOut, int quantCoeff )
  {
    // Row of eight decisions belonging to this scan position.
    Decision* decisions = m_trellis[ scanInfo.scanIdx ];

    // The states updated for the preceding position become the predecessors of this one.
    std::swap( m_prevStates, m_currStates );

    xDecide( scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx), decisions, zeroOut, quantCoeff );

    if( scanInfo.scanIdx )
    {
      if( scanInfo.eosbb )
      {
        // Last position of a sub-block: swap the common context buffers, run the end-of-sub-block
        // update for every state and duplicate decisions 0..3 into entries 4..7.
        m_commonCtx.swap();
        m_currStates[0].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[0] );
        m_currStates[1].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[1] );
        m_currStates[2].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[2] );
        m_currStates[3].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[3] );
        ::memcpy( decisions+4, decisions, 4*sizeof(Decision) );
      }
      // Otherwise updateState is instantiated for the number of in-sub-block template neighbours
      // of the next position (0..4, anything larger uses the 5-neighbour variant).
#if JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
      else if( !zeroOut )
#else
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif
      {
        switch( scanInfo.nextNbInfoSbb.num )
        {
        case 0:
          m_currStates[0].updateState<0>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<0>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<0>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<0>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 1:
          m_currStates[0].updateState<1>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<1>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<1>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<1>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 2:
          m_currStates[0].updateState<2>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<2>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<2>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<2>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 3:
          m_currStates[0].updateState<3>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<3>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<3>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<3>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 4:
          m_currStates[0].updateState<4>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<4>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<4>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<4>( scanInfo, m_prevStates, decisions[3] );
          break;
        default:
          m_currStates[0].updateState<5>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<5>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<5>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<5>( scanInfo, m_prevStates, decisions[3] );
        }
      }

  // Quantizes the coefficients of one TU component with the trellis.
  //   tu                 : transform unit; its coefficient buffer for compID is cleared and then filled
  //   srcCoeff           : transform coefficients to quantize
  //   compID             : component to process
  //   cQP, lambda        : quantization parameters and Lagrange multiplier handed to the quantizer
  //   ctx                : source of the fractional bit estimates for the rate estimator
  //   absSum             : output, sum of the absolute values of the chosen levels
  //   enableScalingLists : use the per-position entries of quantCoeff instead of the default scale
  //   quantCoeff         : per-position quantization coefficients, indexed by block position
  // Visible steps: initialisation and zero-out setup, search for the first scan position worth
  // testing, forward pass over the scan positions (xDecideAndUpdate), selection of the cheapest
  // final state, and a backward pass that follows prevId through m_trellis to write the levels.
  // NOTE(review): this extract contains interleaved page residue and is missing source lines
  // (the opening brace, the declaration of firstTestPos, several #endif lines and loop bodies) -
  // confirm the exact control flow against the repository.
  void DepQuant::quant( TransformUnit& tu, const CCoeffBuf& srcCoeff, const ComponentID compID, const QpParam& cQP, const double lambda, const Ctx& ctx, TCoeff& absSum, bool enableScalingLists, int* quantCoeff )
    CHECKD( tu.cs->sps->getSpsRangeExtension().getExtendedPrecisionProcessingFlag(), "ext precision is not supported" );

    const TUParameters& tuPars  = *g_Rom.getTUPars( tu.blocks[compID], compID );
    m_quant.initQuantBlock    ( tu, compID, cQP, lambda );
    TCoeff*       qCoeff      = tu.getCoeffs( compID ).buf;
    const TCoeff* tCoeff      = srcCoeff.buf;
    const int     numCoeff    = tu.blocks[compID].area();
    // All output levels start at zero; only positions visited by the backward pass are written.
    ::memset( tu.getCoeffs( compID ).buf, 0x00, numCoeff*sizeof(TCoeff) );
    absSum          = 0;

    const CompArea& area     = tu.blocks[ compID ];
    const uint32_t  width    = area.width;
    const uint32_t  height   = area.height;
    const uint32_t  lfnstIdx = tu.cu->lfnstIdx;
    //===== scaling matrix ====
    //const int         qpDQ = cQP.Qp + 1;
    //const int         qpPer = qpDQ / 6;
    //const int         qpRem = qpDQ - 6 * qpPer;
    //TCoeff thresTmp = thres;
    bool zeroOut = false;
    bool zeroOutforThres = false;
    int effWidth = tuPars.m_width, effHeight = tuPars.m_height;
    // For the luma cases selected below a dimension of 32 is reduced to an effective size of 16;
    // zeroOut is set when either effective dimension is smaller than the block dimension.
#if JVET_O0538_SPS_CONTROL_ISP_SBT
    if( ( tu.mtsIdx > MTS_SKIP || ( tu.cs->sps->getUseMTS() && tu.cu->sbtInfo != 0 && tuPars.m_height <= 32 && tuPars.m_width <= 32 ) ) && !tu.cu->transQuantBypass && compID == COMPONENT_Y )
#else
    if ((tu.mtsIdx > MTS_SKIP || (tu.cu->sbtInfo != 0 && tuPars.m_height <= 32 && tuPars.m_width <= 32)) && !tu.cu->transQuantBypass && compID == COMPONENT_Y)
    {
      effHeight = (tuPars.m_height == 32) ? 16 : tuPars.m_height;
      effWidth = (tuPars.m_width == 32) ? 16 : tuPars.m_width;
      zeroOut = (effHeight < tuPars.m_height || effWidth < tuPars.m_width);
    }
    // The threshold search also skips positions at or beyond 32 in blocks larger than 32.
    zeroOutforThres = zeroOut || (32 < tuPars.m_height || 32 < tuPars.m_width);
#if JVET_O0094_LFNST_ZERO_PRIM_COEFFS
Mischa Siekmann's avatar
Mischa Siekmann committed
    if( lfnstIdx > 0 && tu.mtsIdx != MTS_SKIP && width >= 4 && height >= 4 )
    {
      firstTestPos = ( ( width == 4 && height == 4 ) || ( width == 8 && height == 8 ) )  ? 7 : 15 ;
    }
#else
    if( lfnstIdx > 0 && tu.mtsIdx != MTS_SKIP && ( ( width == 4 && height == 4 ) || ( width == 8 && height == 8 ) ) )
Mischa Siekmann's avatar
Mischa Siekmann committed
#endif
    const TCoeff defaultQuantisationCoefficient = (TCoeff)m_quant.getQScale();
    const TCoeff thres = m_quant.getLastThreshold();
    for( ; firstTestPos >= 0; firstTestPos-- )
    {
      if (zeroOutforThres && (tuPars.m_scanId2BlkPos[firstTestPos].x >= ((tuPars.m_width == 32 && zeroOut) ? 16 : 32)
                           || tuPars.m_scanId2BlkPos[firstTestPos].y >= ((tuPars.m_height == 32 && zeroOut) ? 16 : 32)))
        continue;
      TCoeff thresTmp = (enableScalingLists) ? TCoeff(thres / (4 * quantCoeff[tuPars.m_scanId2BlkPos[firstTestPos].idx]))
                                             : TCoeff(thres / (4 * defaultQuantisationCoefficient));

      if (abs(tCoeff[tuPars.m_scanId2BlkPos[firstTestPos].idx]) > thresTmp)
    RateEstimator::initCtx( tuPars, tu, compID, ctx.getFracBitsAcess() );
    m_commonCtx.reset( tuPars, *this );
    for( int k = 0; k < 12; k++ )
    {
      m_allStates[k].init();
    }
    m_startState.init();


#if JVET_O0052_TU_LEVEL_CTX_CODED_BIN_CONSTRAINT
    int effectWidth = std::min(32, effWidth);
    int effectHeight = std::min(32, effHeight);
    for (int k = 0; k < 12; k++)
    {
      m_allStates[k].effWidth = effectWidth;
      m_allStates[k].effHeight = effectHeight;
    }
    m_startState.effWidth = effectWidth;
    m_startState.effHeight = effectHeight;
#endif

    for( int scanIdx = firstTestPos; scanIdx >= 0; scanIdx-- )
    {
      const ScanInfo& scanInfo = tuPars.m_scanInfo[ scanIdx ];
#if !JVET_O0094_LFNST_ZERO_PRIM_COEFFS
      bool lfnstZeroOut = lfnstIdx > 0 && tu.mtsIdx != MTS_SKIP && width >= 4 && height >= 4 &&
        ( ( ( ( width >= 8 && height >= 8 ) && scanIdx >= 16 ) || ( ( ( width == 4 && height == 4 ) || ( width == 8 && height == 8 ) ) && scanIdx >= 8 ) ) && scanIdx < 48 );
      if (enableScalingLists)
      {
        m_quant.initQuantBlock(tu, compID, cQP, lambda, quantCoeff[scanInfo.rasterPos]);
        xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)) || lfnstZeroOut, quantCoeff[scanInfo.rasterPos] );
      else
        xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)) || lfnstZeroOut, defaultQuantisationCoefficient );
Mischa Siekmann's avatar
Mischa Siekmann committed
#else
      if (enableScalingLists)
      {
        m_quant.initQuantBlock(tu, compID, cQP, lambda, quantCoeff[scanInfo.rasterPos]);
        xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)), quantCoeff[scanInfo.rasterPos] );
      }
      else
        xDecideAndUpdate( abs( tCoeff[scanInfo.rasterPos]), scanInfo, (zeroOut && (scanInfo.posX >= effWidth || scanInfo.posY >= effHeight)), defaultQuantisationCoefficient );
#endif

    //===== find best path =====
    Decision  decision    = { std::numeric_limits<int64_t>::max(), -1, -2 };
    int64_t   minPathCost =  0;
    for( int8_t stateId = 0; stateId < 4; stateId++ )
    {
      int64_t pathCost = m_trellis[0][stateId].rdCost;
      if( pathCost < minPathCost )
      {
        decision.prevId = stateId;
        minPathCost     = pathCost;
      }
    }

    //===== backward scanning =====
    int scanIdx = 0;
    for( ; decision.prevId >= 0; scanIdx++ )
    {
      decision          = m_trellis[ scanIdx ][ decision.prevId ];
Frank Bossen's avatar
Frank Bossen committed
      int32_t blkpos    = tuPars.m_scanId2BlkPos[scanIdx].idx;
      qCoeff[ blkpos ]  = ( tCoeff[ blkpos ] < 0 ? -decision.absLevel : decision.absLevel );
      absSum           += decision.absLevel;
    }
  }

}; // namespace DQIntern




//===== interface class =====
// Interface-class constructor.
//   other : quantizer forwarded to the QuantRDOQ base; when non-null it must itself be a DepQuant
//   enc   : when true, the shared DQIntern::g_Rom tables are initialised as well
// Allocates the internal DQIntern::DepQuant implementation, which the destructor releases.
DepQuant::DepQuant( const Quant* other, bool enc ) : QuantRDOQ( other )
{
  const DepQuant* dq = dynamic_cast<const DepQuant*>( other );
  CHECK( other && !dq, "The DepQuant cast must be successfull!" );

  p = new DQIntern::DepQuant();

  if( !enc )
  {
    return;
  }
  DQIntern::g_Rom.init();
}

// Releases the internal implementation object allocated by the constructor.
DepQuant::~DepQuant()
{
  DQIntern::DepQuant* const impl = static_cast<DQIntern::DepQuant*>( p );
  delete impl;
}

// Quantizes one TU component.
//   tu       : transform unit to quantize
//   compID   : component to process
//   pSrc     : transform coefficients
//   uiAbsSum : output, sum of the absolute quantized levels
//   cQP      : quantization parameter set
//   ctx      : entropy coding context used for the rate estimation
// Dependent quantization is used when the slice enables it and the block is not a
// transform-skipped luma block; every other case is delegated to QuantRDOQ::quant.
void DepQuant::quant( TransformUnit &tu, const ComponentID &compID, const CCoeffBuf &pSrc, TCoeff &uiAbsSum, const QpParam &cQP, const Ctx& ctx )
{
  if( tu.cs->slice->getDepQuantEnabledFlag() && (tu.mtsIdx != MTS_SKIP || !isLuma(compID)) )
  {
    //===== scaling matrix ====
    // The quantization coefficient table is selected with the remainder of (QP + 1) modulo 6.
#if JVET_O0919_TS_MIN_QP
    const int         qpDQ            = cQP.Qp(tu.mtsIdx==MTS_SKIP && isLuma(compID)) + 1;
#else
    const int         qpDQ            = cQP.Qp + 1;
#endif
    const int         qpPer           = qpDQ / 6;
    const int         qpRem           = qpDQ - 6 * qpPer;
    const CompArea    &rect           = tu.blocks[compID];
    const int         width           = rect.width;
    const int         height          = rect.height;
    uint32_t          scalingListType = getScalingListType(tu.cu->predMode, compID);
    CHECK(scalingListType >= SCALING_LIST_NUM, "Invalid scaling list");
    const uint32_t    log2TrWidth     = g_aucLog2[width];
    const uint32_t    log2TrHeight    = g_aucLog2[height];
    const bool        enableScalingLists = getUseScalingList(width, height, (tu.mtsIdx == MTS_SKIP && isLuma(compID)));
    static_cast<DQIntern::DepQuant*>(p)->quant( tu, pSrc, compID, cQP, Quant::m_dLambda, ctx, uiAbsSum, enableScalingLists, Quant::getQuantCoeff(scalingListType, qpRem, log2TrWidth, log2TrHeight) );
  }
  else
  {
    QuantRDOQ::quant( tu, compID, pSrc, uiAbsSum, cQP, ctx );
  }
}

void DepQuant::dequant( const TransformUnit &tu, CoeffBuf &dstCoeff, const ComponentID &compID, const QpParam &cQP )
{
Brian Heng's avatar
Brian Heng committed
  if( tu.cs->slice->getDepQuantEnabledFlag() && (tu.mtsIdx != MTS_SKIP || !isLuma(compID)) )
#if JVET_O0919_TS_MIN_QP
    const int         qpDQ            = cQP.Qp(tu.mtsIdx==MTS_SKIP && isLuma(compID)) + 1;
#else
    const int         qpDQ            = cQP.Qp + 1;
    const int         qpPer           = qpDQ / 6;
    const int         qpRem           = qpDQ - 6 * qpPer;
    const CompArea    &rect           = tu.blocks[compID];
    const int         width           = rect.width;
    const int         height          = rect.height;
    uint32_t          scalingListType = getScalingListType(tu.cu->predMode, compID);
    CHECK(scalingListType >= SCALING_LIST_NUM, "Invalid scaling list");
    const uint32_t    log2TrWidth  = g_aucLog2[width];
    const uint32_t    log2TrHeight = g_aucLog2[height];
Chen-Yen Lai's avatar
Chen-Yen Lai committed
    const bool enableScalingLists = getUseScalingList(width, height, (tu.mtsIdx == MTS_SKIP && isLuma(compID)));
    static_cast<DQIntern::DepQuant*>(p)->dequant( tu, dstCoeff, compID, cQP, enableScalingLists, Quant::getDequantCoeff(scalingListType, qpRem, log2TrWidth, log2TrHeight) );