Skip to content
Snippets Groups Projects
DepQuant.cpp 88.5 KiB
Newer Older
#if JVET_L0274
    m_prevStates[0].checkRdCosts( spt, pqData[0], pqData[2], decisions[0], decisions[2]);
    m_prevStates[1].checkRdCosts( spt, pqData[0], pqData[2], decisions[2], decisions[0]);
    m_prevStates[2].checkRdCosts( spt, pqData[3], pqData[1], decisions[1], decisions[3]);
    m_prevStates[3].checkRdCosts( spt, pqData[3], pqData[1], decisions[3], decisions[1]);
#else
    m_prevStates[0].checkRdCostNonZero<spt> ( pqData[0],  decisions[0] );
    m_prevStates[0].checkRdCostNonZero<spt> ( pqData[2],  decisions[2] );
    m_prevStates[0].checkRdCostZero<spt>                ( decisions[0] );
    m_prevStates[1].checkRdCostNonZero<spt> ( pqData[2],  decisions[0] );
    m_prevStates[1].checkRdCostNonZero<spt> ( pqData[0],  decisions[2] );
    m_prevStates[1].checkRdCostZero<spt>                ( decisions[2] );
    m_prevStates[2].checkRdCostNonZero<spt> ( pqData[3],  decisions[1] );
    m_prevStates[2].checkRdCostNonZero<spt> ( pqData[1],  decisions[3] );
    m_prevStates[2].checkRdCostZero<spt>                ( decisions[1] );
    m_prevStates[3].checkRdCostNonZero<spt> ( pqData[1],  decisions[1] );
    m_prevStates[3].checkRdCostNonZero<spt> ( pqData[3],  decisions[3] );
    m_prevStates[3].checkRdCostZero<spt>                ( decisions[3] );
    if( spt==SCAN_EOCSBB )
    {
      m_skipStates[0].checkRdCostSkipSbb( decisions[0] );
      m_skipStates[1].checkRdCostSkipSbb( decisions[1] );
      m_skipStates[2].checkRdCostSkipSbb( decisions[2] );
      m_skipStates[3].checkRdCostSkipSbb( decisions[3] );
    }
    m_startState.checkRdCostStart( lastOffset, pqData[0], decisions[0] );
    m_startState.checkRdCostStart( lastOffset, pqData[2], decisions[2] );
  }

  // Performs one trellis step for the scan position described by scanInfo:
  // the candidate decisions for the four states are computed by xDecide()
  // and the four current states are then updated from those decisions.
  //
  //   absCoeff  absolute value of the coefficient, forwarded to xDecide()
  //   scanInfo  per-position scan data (scanIdx, sub-block flags, neighbour info)
  //
  // NOTE(review): in this view the #if/#else groups below have no visible
  // #endif, so the exact set of lines that is compiled for each macro setting
  // cannot be determined from here - confirm against the complete file.
  void DepQuant::xDecideAndUpdate( const TCoeff absCoeff, const ScanInfo& scanInfo )
  {
    // Trellis row that receives the decisions of this scan position.
    Decision* decisions = m_trellis[ scanInfo.scanIdx ];

    // The states produced by the previous step become the predecessor states.
    std::swap( m_prevStates, m_currStates );

#if JVET_L0274
    // Scan position type is passed as a run-time argument in this variant.
    xDecide( scanInfo.spt, absCoeff, lastOffset(scanInfo.scanIdx), decisions);
#else
    // Scan position type is selected at compile time via the template argument.
    if     ( scanInfo.socsbb )  { xDecide<SCAN_SOCSBB>( absCoeff, scanInfo.lastOffset, decisions ); }
    else if( scanInfo.eocsbb )  { xDecide<SCAN_EOCSBB>( absCoeff, scanInfo.lastOffset, decisions ); }
    else                        { xDecide<SCAN_ISCSBB>( absCoeff, scanInfo.lastOffset, decisions ); }

    // No state update is performed for scan index 0.
    if( scanInfo.scanIdx )
    {
      if( scanInfo.eosbb )
      {
        // End-of-sub-block flag set: swap the common context, use the EOS
        // update (which also receives the skip states) for all four states ...
        m_commonCtx.swap();
        m_currStates[0].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[0] );
        m_currStates[1].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[1] );
        m_currStates[2].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[2] );
        m_currStates[3].updateStateEOS( scanInfo, m_prevStates, m_skipStates, decisions[3] );
        // ... and duplicate the four decisions into entries 4..7 of the row.
        ::memcpy( decisions+4, decisions, 4*sizeof(Decision) );
      }
      else
      {
        // Dispatch the run-time value nextNbInfoSbb.num to the matching
        // updateState<N> instantiation; every value above 4 uses <5>.
        switch( scanInfo.nextNbInfoSbb.num )
        {
        case 0:
          m_currStates[0].updateState<0>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<0>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<0>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<0>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 1:
          m_currStates[0].updateState<1>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<1>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<1>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<1>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 2:
          m_currStates[0].updateState<2>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<2>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<2>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<2>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 3:
          m_currStates[0].updateState<3>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<3>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<3>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<3>( scanInfo, m_prevStates, decisions[3] );
          break;
        case 4:
          m_currStates[0].updateState<4>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<4>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<4>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<4>( scanInfo, m_prevStates, decisions[3] );
          break;
        default:
          m_currStates[0].updateState<5>( scanInfo, m_prevStates, decisions[0] );
          m_currStates[1].updateState<5>( scanInfo, m_prevStates, decisions[1] );
          m_currStates[2].updateState<5>( scanInfo, m_prevStates, decisions[2] );
          m_currStates[3].updateState<5>( scanInfo, m_prevStates, decisions[3] );
        }
      }

      // Exchange the predecessor states with the skip states. With JVET_L0274
      // this is guarded by the SCAN_SOCSBB test below.
      // NOTE(review): the condition used in the #else variant is not visible here.
#if JVET_L0274
      if( scanInfo.spt == SCAN_SOCSBB )
#else
      {
        std::swap( m_prevStates, m_skipStates );
      }
    }
  }


  // Quantizes the coefficients of one component of a transform unit with a
  // trellis search: the coefficient buffer is cleared, the trellis is filled
  // from the first tested scan position down to scan index 0, and the chosen
  // path is then walked to write the signed levels.
  //
  //   tu        transform unit; its coefficient buffer for compID is zeroed and written
  //   srcCoeff  source coefficients (read through srcCoeff.buf)
  //   compID    component being processed
  //   cQP       forwarded to m_quant.initQuantBlock()
  //   lambda    forwarded to m_quant.initQuantBlock()
  //   ctx       supplies the fractional-bit accessor for RateEstimator::initCtx()
  //   absSum    output: reset to 0, then incremented by every written absolute level
  //
  // NOTE(review): several #if/#else groups in this view have no visible #endif
  // and some statements are cut off (e.g. the body of the first-position loop
  // and the #else declaration of blkpos) - confirm against the complete file.
  void DepQuant::quant( TransformUnit& tu, const CCoeffBuf& srcCoeff, const ComponentID compID, const QpParam& cQP, const double lambda, const Ctx& ctx, TCoeff& absSum )
  {
#if JVET_L0274
    // Debug check: fires when the extended-precision flag of the SPS is set.
    CHECKD( tu.cs->sps->getSpsRangeExtension().getExtendedPrecisionProcessingFlag(), "ext precision is not supported" );
#endif

#if JVET_L0274_ENCODER_SPEED_UP
    // Look up the TU parameter set for this block in the global ROM object.
    const TUParameters& tuPars  = *g_Rom.getTUPars( tu.blocks[compID], compID );
#else
    m_quant.initQuantBlock    ( tu, compID, cQP, lambda );
    TCoeff*       qCoeff      = tu.getCoeffs( compID ).buf;
    const TCoeff* tCoeff      = srcCoeff.buf;
    const int     numCoeff    = tu.blocks[compID].area();
    // Start from an all-zero result; only positions on the chosen path are written later.
    ::memset( tu.getCoeffs( compID ).buf, 0x00, numCoeff*sizeof(TCoeff) );
    absSum          = 0;

    //===== find first test position =====
    // Walk backwards from the last scan index and compare each coefficient
    // magnitude against the threshold returned by m_quant.getLastThreshold().
    int   firstTestPos = numCoeff - 1;
    const TCoeff thres = m_quant.getLastThreshold();
    for( ; firstTestPos >= 0; firstTestPos-- )
    {
#if JVET_L0274_ENCODER_SPEED_UP
      if( abs( tCoeff[ tuPars.m_scanId2BlkPos[firstTestPos] ] ) > thres )
#else
      if( abs( tCoeff[ rasterPos(firstTestPos) ] ) > thres )
    // Initialise the rate-estimation contexts and the shared context object.
#if JVET_L0274_ENCODER_SPEED_UP
    RateEstimator::initCtx( tuPars, tu, compID, ctx.getFracBitsAcess() );
    m_commonCtx.reset( tuPars, *this );
#else
    RateEstimator::initCtx( tu, ctx.getFracBitsAcess() );
    m_commonCtx.reset( *this );
    // Reset all 12 entries of m_allStates and the start state.
    for( int k = 0; k < 12; k++ )
    {
      m_allStates[k].init();
    }
    m_startState.init();


    //===== populate trellis =====
    // One xDecideAndUpdate() call per scan position, from firstTestPos down to 0.
#if JVET_L0274_ENCODER_SPEED_UP
    for( int scanIdx = firstTestPos; scanIdx >= 0; scanIdx-- )
    {
      const ScanInfo& scanInfo = tuPars.m_scanInfo[ scanIdx ];
      xDecideAndUpdate( abs( tCoeff[ scanInfo.rasterPos ] ), scanInfo );
    }
#else
    for( ScanData scanData(*this,firstTestPos); scanData.valid(); scanData.next() )
    {
      xDecideAndUpdate( abs( tCoeff[ scanData.rasterPos ] ), scanData );
    }

    //===== find best path =====
    // minPathCost starts at 0, so a state of trellis row 0 is selected only
    // when its rdCost is negative; otherwise decision.prevId keeps its initial
    // value (presumably negative, which would skip the loop below - confirm
    // against the member order of Decision).
    Decision  decision    = { std::numeric_limits<int64_t>::max(), -1, -2 };
    int64_t   minPathCost =  0;
    for( int8_t stateId = 0; stateId < 4; stateId++ )
    {
      int64_t pathCost = m_trellis[0][stateId].rdCost;
      if( pathCost < minPathCost )
      {
        decision.prevId = stateId;
        minPathCost     = pathCost;
      }
    }

    //===== backward scanning =====
    // Follow the prevId links through the trellis rows, starting at scan index
    // 0 and moving upwards, until a negative prevId ends the path.
    int scanIdx = 0;
    for( ; decision.prevId >= 0; scanIdx++ )
    {
      decision          = m_trellis[ scanIdx ][ decision.prevId ];
#if JVET_L0274_ENCODER_SPEED_UP
      int32_t blkpos    = tuPars.m_scanId2BlkPos[ scanIdx ];
#else
      // The level takes the sign of the source coefficient at the same position.
      qCoeff[ blkpos ]  = ( tCoeff[ blkpos ] < 0 ? -decision.absLevel : decision.absLevel );
      absSum           += decision.absLevel;
    }
  }

}; // namespace DQIntern




//===== interface class =====
// Constructs the interface object on top of QuantRDOQ and creates the internal
// DQIntern::DepQuant implementation that is stored in p.
//   other  optional quantizer handed to the QuantRDOQ base; when it is non-null
//          it has to be a DepQuant, otherwise the CHECK below fires
//   enc    when true, the global DQIntern::g_Rom object is initialised as well
DepQuant::DepQuant( const Quant* other, bool enc )
  : QuantRDOQ( other )
{
  // The cast result is only needed for the sanity check on 'other'.
  const auto* dq = dynamic_cast<const DepQuant*>( other );
  CHECK( other && !dq, "The DepQuant cast must be successfull!" );

  // Internal implementation object; it is deleted again in ~DepQuant().
  p = new DQIntern::DepQuant();

  if( !enc )
  {
    return;
  }
  DQIntern::g_Rom.init();
}

// Destroys the internal DQIntern::DepQuant object referenced by p.
DepQuant::~DepQuant()
{
  DQIntern::DepQuant* const impl = static_cast<DQIntern::DepQuant*>( p );
  delete impl;
}

// Quantizes one component of a transform unit. The internal DQIntern::DepQuant
// object is used when the slice of the TU reports the dependent-quantization
// flag as enabled; otherwise the call is passed on to QuantRDOQ::quant().
void DepQuant::quant( TransformUnit &tu, const ComponentID &compID, const CCoeffBuf &pSrc, TCoeff &uiAbsSum, const QpParam &cQP, const Ctx& ctx )
{
  if( !tu.cs->slice->getDepQuantEnabledFlag() )
  {
    // Flag not set: fall back to the base-class implementation.
    QuantRDOQ::quant( tu, compID, pSrc, uiAbsSum, cQP, ctx );
    return;
  }

  // The internal implementation takes its arguments in a different order and
  // additionally receives the lambda stored in Quant::m_dLambda.
  DQIntern::DepQuant* const impl = static_cast<DQIntern::DepQuant*>( p );
  impl->quant( tu, pSrc, compID, cQP, Quant::m_dLambda, ctx, uiAbsSum );
}

void DepQuant::dequant( const TransformUnit &tu, CoeffBuf &dstCoeff, const ComponentID &compID, const QpParam &cQP )
{
  if( tu.cs->slice->getDepQuantEnabledFlag() )
  {
    static_cast<DQIntern::DepQuant*>(p)->dequant( tu, dstCoeff, compID, cQP );
  }
  else
  {
    QuantRDOQ::dequant( tu, dstCoeff, compID, cQP );
  }
}