Skip to content
Snippets Groups Projects
DepQuant.cpp 88.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • #if JVET_L0274
        m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2]);
        m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0]);
        m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3]);
        m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1]);
    #else
    
        m_prevStates[0].checkRdCostNonZero<spt> ( pqData[0],  decisions[0] );
        m_prevStates[0].checkRdCostNonZero<spt> ( pqData[2],  decisions[2] );
        m_prevStates[0].checkRdCostZero<spt>                ( decisions[0] );
        m_prevStates[1].checkRdCostNonZero<spt> ( pqData[2],  decisions[0] );
        m_prevStates[1].checkRdCostNonZero<spt> ( pqData[0],  decisions[2] );
        m_prevStates[1].checkRdCostZero<spt>                ( decisions[2] );
        m_prevStates[2].checkRdCostNonZero<spt> ( pqData[3],  decisions[1] );
        m_prevStates[2].checkRdCostNonZero<spt> ( pqData[1],  decisions[3] );
        m_prevStates[2].checkRdCostZero<spt>                ( decisions[1] );
        m_prevStates[3].checkRdCostNonZero<spt> ( pqData[1],  decisions[1] );
        m_prevStates[3].checkRdCostNonZero<spt> ( pqData[3],  decisions[3] );
        m_prevStates[3].checkRdCostZero<spt>                ( decisions[3] );
    
        if( spt==SCAN_EOCSBB )
        {
          m_skipStates[0].checkRdCostSkipSbb( decisions[0] );
          m_skipStates[1].checkRdCostSkipSbb( decisions[1] );
          m_skipStates[2].checkRdCostSkipSbb( decisions[2] );
          m_skipStates[3].checkRdCostSkipSbb( decisions[3] );
        }
        m_startState.checkRdCostStart( lastOffset, pqData[0], decisions[0] );
        m_startState.checkRdCostStart( lastOffset, pqData[2], decisions[2] );
      }
    
      void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo )
      {
        // Processes one scan position of the trellis: fills this position's row of
        // m_trellis with the decisions for absCoeff, then (except at scan index 0)
        // advances the four current states using those decisions.
        Decision* decisions = m_trellis[ scanInfo.scanIdx ];
    
        // The states produced for the previously processed position become the
        // predecessor states for this one.
        std::swap( m_prevStates, m_currStates );
    
    #if JVET_L0274
        xDecide( scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx), decisions);
    #else
        // Select the xDecide instantiation matching the scan-position flags
        // (socsbb -> SCAN_SOCSBB, eocsbb -> SCAN_EOCSBB, otherwise SCAN_ISCSBB).
        if     ( scanInfo.socsbb )  { xDecide<SCAN_SOCSBB>( absCoeff, scanInfo.lastOffset, decisions ); }
        else if( scanInfo.eocsbb )  { xDecide<SCAN_EOCSBB>( absCoeff, scanInfo.lastOffset, decisions ); }
        else                        { xDecide<SCAN_ISCSBB>( absCoeff, scanInfo.lastOffset, decisions ); }
    #endif // JVET_L0274
    
        // No state update is performed for scan index 0 (presumably because it is the
        // final position the trellis visits).
        if( scanInfo.scanIdx )
        {
          if( scanInfo.eosbb )
          {
            // Sub-block boundary (eosbb; presumably "end of sub-block"): swap the common
            // context, update the states via updateStateEOS (which is also handed
            // m_skipStates), then duplicate the four decisions into decisions[4..7].
            // NOTE(review): the memcpy assumes each trellis row holds at least 8 entries.
            m_commonCtx.swap();
            m_currStates[0].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[0] );
            m_currStates[1].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[1] );
            m_currStates[2].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[2] );
            m_currStates[3].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[3] );
            ::memcpy( decisions+4, decisions, 4*sizeof(Decision) );
          }
          else
          {
            // updateState is templated on the neighbour count so each count gets its
            // own instantiation; any count outside 0..4 uses the <5> instantiation.
            switch( scanInfo.nextNbInfoSbb.num )
            {
            case 0:
              m_currStates[0].updateState<0>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<0>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<0>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<0>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 1:
              m_currStates[0].updateState<1>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<1>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<1>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<1>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 2:
              m_currStates[0].updateState<2>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<2>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<2>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<2>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 3:
              m_currStates[0].updateState<3>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<3>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<3>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<3>( scanInfo, m_prevStates, decisions[3] );
              break;
            case 4:
              m_currStates[0].updateState<4>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<4>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<4>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<4>( scanInfo, m_prevStates, decisions[3] );
              break;
            default:
              m_currStates[0].updateState<5>( scanInfo, m_prevStates, decisions[0] );
              m_currStates[1].updateState<5>( scanInfo, m_prevStates, decisions[1] );
              m_currStates[2].updateState<5>( scanInfo, m_prevStates, decisions[2] );
              m_currStates[3].updateState<5>( scanInfo, m_prevStates, decisions[3] );
            }
          }
    
          // At an SOCSBB position (same condition the legacy dispatch above tests via
          // scanInfo.socsbb), exchange m_prevStates and m_skipStates. The swap must be
          // guarded in BOTH preprocessor branches; it is not meant to run at every position.
    #if JVET_L0274
          if( scanInfo.spt == SCAN_SOCSBB )
    #else
          if( scanInfo.socsbb )
    #endif // JVET_L0274
          {
            std::swap( m_prevStates, m_skipStates );
          }
        }
      }
    
    
      // Dependent-quantization trellis search for one component of a transform unit.
      // Visible steps: clear the output coefficients and absSum, search downward from the
      // last scan position for a coefficient whose magnitude exceeds the quantizer's
      // last-threshold, run the trellis (xDecideAndUpdate) from that position down to scan
      // index 0, then back-trace the cheapest path, writing each chosen level (with the
      // sign of the source coefficient) into tu's coefficient buffer and adding its
      // absolute value to absSum.
      //
      // NOTE(review): this copy of the function appears to be missing lines: several
      // #if/#else groups below have no #endif, the first-test-position loop has no visible
      // break/closing brace, and blkpos is declared only in the
      // JVET_L0274_ENCODER_SPEED_UP branch. Restore from the upstream source before use.
      void DepQuant::quant( TransformUnit& tu, const CCoeffBuf& srcCoeff, const ComponentID compID, const QpParam& cQP, const double lambda, const Ctx& ctx, TCoeff& absSum )
      {
    
    #if JVET_L0274
        CHECKD( tu.cs->sps->getSpsRangeExtension().getExtendedPrecisionProcessingFlag(), "ext precision is not supported" );
    #endif
    
    
    #if JVET_L0274_ENCODER_SPEED_UP
        const TUParameters& tuPars  = *g_Rom.getTUPars( tu.blocks[compID], compID );
    #else
        // NOTE(review): no legacy-branch code and no #endif are visible for this #if.
    
        m_quant.initQuantBlock    ( tu, compID, cQP, lambda );
        TCoeff*       qCoeff      = tu.getCoeffs( compID ).buf;
        const TCoeff* tCoeff      = srcCoeff.buf;
        const int     numCoeff    = tu.blocks[compID].area();
        // Start from an all-zero output; only positions on the chosen path are written later.
        ::memset( tu.getCoeffs( compID ).buf, 0x00, numCoeff*sizeof(TCoeff) );
        absSum          = 0;
    
        //===== find first test position =====
        // Walk scan positions downward from the last one, testing |coefficient| > thres.
        // NOTE(review): the loop's break statement, closing brace and the #endif for the
        // #if inside it are not present in this copy.
        int   firstTestPos = numCoeff - 1;
        const TCoeff thres = m_quant.getLastThreshold();
        for( ; firstTestPos >= 0; firstTestPos-- )
        {
    
    #if JVET_L0274_ENCODER_SPEED_UP
          if( abs( tCoeff[ tuPars.m_scanId2BlkPos[firstTestPos] ] ) > thres )
    #else
    
          if( abs( tCoeff[ rasterPos(firstTestPos) ] ) > thres )
    
    #if JVET_L0274_ENCODER_SPEED_UP
        RateEstimator::initCtx( tuPars, tu, compID, ctx.getFracBitsAcess() );
        m_commonCtx.reset( tuPars, *this );
    #else
        // NOTE(review): #endif missing after this legacy branch.
    
        RateEstimator::initCtx( tu, ctx.getFracBitsAcess() );
        m_commonCtx.reset( *this );
    
        // Reset all 12 state objects and the start state before running the trellis.
        for( int k = 0; k < 12; k++ )
        {
          m_allStates[k].init();
        }
        m_startState.init();
    
    
        //===== populate trellis =====
        // Visit scan positions from firstTestPos down to 0, one xDecideAndUpdate per position.
    
    #if JVET_L0274_ENCODER_SPEED_UP
        for( int scanIdx = firstTestPos; scanIdx >= 0; scanIdx-- )
        {
          const ScanInfo& scanInfo = tuPars.m_scanInfo[ scanIdx ];
          xDecideAndUpdate( abs( tCoeff[ scanInfo.rasterPos ] ), scanInfo );
        }
    #else
        // NOTE(review): #endif missing after this legacy branch.
    
        for( ScanData scanData(*this,firstTestPos); scanData.valid(); scanData.next() )
        {
          xDecideAndUpdate( abs( tCoeff[ scanData.rasterPos ] ), scanData );
        }
    
    
        //===== find best path =====
        // Only a state whose rdCost at scan index 0 is below 0 can be selected. If none is,
        // prevId keeps its initial value (assuming Decision's members are ordered
        // rdCost, absLevel, prevId, that value is -2), the back-trace below does not run,
        // and all coefficients stay zero.
        Decision  decision    = { std::numeric_limits<int64_t>::max(), -1, -2 };
        int64_t   minPathCost =  0;
        for( int8_t stateId = 0; stateId < 4; stateId++ )
        {
          int64_t pathCost = m_trellis[0][stateId].rdCost;
          if( pathCost < minPathCost )
          {
            decision.prevId = stateId;
            minPathCost     = pathCost;
          }
        }
    
        //===== backward scanning =====
        // Follow prevId links from scan index 0 toward higher indices until prevId < 0,
        // writing each decision's level with the sign of the source coefficient.
        int scanIdx = 0;
        for( ; decision.prevId >= 0; scanIdx++ )
        {
          decision          = m_trellis[ scanIdx ][ decision.prevId ];
    
    #if JVET_L0274_ENCODER_SPEED_UP
          int32_t blkpos    = tuPars.m_scanId2BlkPos[ scanIdx ];
    #else
          // NOTE(review): the legacy declaration of blkpos and the #endif are missing here.
    
          qCoeff[ blkpos ]  = ( tCoeff[ blkpos ] < 0 ? -decision.absLevel : decision.absLevel );
          absSum           += decision.absLevel;
        }
      }
    
    }; // namespace DQIntern
    
    
    
    
    //===== interface class =====
    DepQuant::DepQuant( const Quant* other, bool enc ) : QuantRDOQ( other )
    {
      const DepQuant* dq = dynamic_cast<const DepQuant*>( other );
      CHECK( other && !dq, "The DepQuant cast must be successfull!" );
      p = new DQIntern::DepQuant();
      if( enc )
      {
        DQIntern::g_Rom.init();
      }
    }
    
    DepQuant::~DepQuant()
    {
      delete static_cast<DQIntern::DepQuant*>(p);
    }
    
    void DepQuant::quant( TransformUnit &tu, const ComponentID &compID, const CCoeffBuf &pSrc, TCoeff &uiAbsSum, const QpParam &cQP, const Ctx& ctx )
    {
      if( tu.cs->slice->getDepQuantEnabledFlag() )
      {
        static_cast<DQIntern::DepQuant*>(p)->quant( tu, pSrc, compID, cQP, Quant::m_dLambda, ctx, uiAbsSum );
      }
      else
      {
        QuantRDOQ::quant( tu, compID, pSrc, uiAbsSum, cQP, ctx );
      }
    }
    
    void DepQuant::dequant( const TransformUnit &tu, CoeffBuf &dstCoeff, const ComponentID &compID, const QpParam &cQP )
    {
      if( tu.cs->slice->getDepQuantEnabledFlag() )
      {
        static_cast<DQIntern::DepQuant*>(p)->dequant( tu, dstCoeff, compID, cQP );
      }
      else
      {
        QuantRDOQ::dequant( tu, dstCoeff, compID, cQP );
      }
    }