Commit 0a77c669 authored by Frank Bossen's avatar Frank Bossen
Browse files

Merge branch 'JVET_N0327_CIIP_BITCALC_BF' into 'master'

JVET_N0327: bug fix on bit estimation of merge mode RD

See merge request jvet/VVCSoftware_VTM!496
parents a4ded762 cf3989fd
......@@ -62,6 +62,7 @@
#define JVET_N0340_TRI_MERGE_CAND 1
#define JVET_N0302_SIMPLFIED_CIIP 1
#define JVET_N0327_MERGE_BIT_CALC_FIX 1
#define JVET_N0324_REGULAR_MRG_FLAG 1
......
......@@ -1843,10 +1843,14 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
{
RdModeList.clear();
mrgTempBufSet = true;
#if JVET_N0327_MERGE_BIT_CALC_FIX
const TempCtx ctxStart(m_CtxCache, m_CABACEstimator->getCtx());
#else
const double sqrtLambdaForFirstPass = m_pcRdCost->getMotionLambda( encTestMode.lossless );
#endif
CodingUnit &cu = tempCS->addCU( tempCS->area, partitioner.chType );
#if !JVET_N0302_SIMPLFIED_CIIP
#if !JVET_N0302_SIMPLFIED_CIIP || JVET_N0327_MERGE_BIT_CALC_FIX
const double sqrtLambdaForFirstPassIntra = m_pcRdCost->getMotionLambda(cu.transQuantBypass) / double(1 << SCALE_BITS);
#endif
partitioner.setCUData( cu );
......@@ -1908,6 +1912,11 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
}
Distortion uiSad = distParam.distFunc(distParam);
#if JVET_N0327_MERGE_BIT_CALC_FIX
m_CABACEstimator->getCtx() = ctxStart;
uint64_t fracBits = m_pcInterSearch->xCalcPuMeBits(pu);
double cost = (double)uiSad + (double)fracBits * sqrtLambdaForFirstPassIntra;
#else
uint32_t uiBitsCand = uiMergeCand + 1;
if( uiMergeCand == tempCS->slice->getMaxNumMergeCand() - 1 )
{
......@@ -1922,6 +1931,7 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
#endif
#endif
double cost = (double)uiSad + (double)uiBitsCand * sqrtLambdaForFirstPass;
#endif
insertPos = -1;
updateDoubleCandList(uiMergeCand, cost, RdModeList, candCostList, RdModeList2, (uint32_t)NUM_LUMA_MODE, uiNumMrgSATDCand, &insertPos);
if (insertPos != -1)
......@@ -1948,7 +1958,9 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
int numTestIntraMode = 4;
#endif
// prepare for Intra bits calculation
#if !JVET_N0327_MERGE_BIT_CALC_FIX
const TempCtx ctxStart(m_CtxCache, m_CABACEstimator->getCtx());
#endif
#if !JVET_N0302_SIMPLFIED_CIIP
const TempCtx ctxStartIntraMode(m_CtxCache, SubCtx(Ctx::MHIntraPredMode, m_CABACEstimator->getCtx()));
......@@ -1970,11 +1982,15 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
acMergeBuffer[mergeCand] = m_acRealMergeBuffer[mergeCand].getBuf(localUnitArea);
// estimate merge bits
#if JVET_N0327_MERGE_BIT_CALC_FIX
mergeCtx.setMergeInfo(pu, mergeCand);
#else
uint32_t bitsCand = mergeCand + 1;
if (mergeCand == pu.cs->slice->getMaxNumMergeCand() - 1)
{
bitsCand--;
}
#endif
// first round
#if JVET_N0302_SIMPLFIED_CIIP
......@@ -2005,10 +2021,16 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
{
pu.cs->getPredBuf(pu).Y().rspSignal(m_pcReshape->getFwdLUT());
}
#if JVET_N0327_MERGE_BIT_CALC_FIX
m_CABACEstimator->getCtx() = ctxStart;
uint64_t fracBits = m_pcInterSearch->xCalcPuMeBits(pu);
double cost = (double)sadValue + (double)fracBits * sqrtLambdaForFirstPassIntra;
#else
#if JVET_N0324_REGULAR_MRG_FLAG
double cost = (double)sadValue + (double)(bitsCand + 9) * sqrtLambdaForFirstPass;
#else
double cost = (double)sadValue + (double)(bitsCand + 1) * sqrtLambdaForFirstPass;
#endif
#endif
insertPos = -1;
updateDoubleCandList(mergeCand + MRG_MAX_NUM_CANDS + MMVD_ADD_NUM, cost, RdModeList, candCostList, RdModeList2, pu.intraDir[0], uiNumMrgSATDCand, &insertPos);
......@@ -2064,12 +2086,18 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
{
pu.cs->getPredBuf(pu).Y().rspSignal(m_pcReshape->getFwdLUT());
}
#if JVET_N0327_MERGE_BIT_CALC_FIX
m_CABACEstimator->getCtx() = ctxStart;
uint64_t fracBits = m_pcInterSearch->xCalcPuMeBits(pu);
double cost = (double)sadValue + (double)fracBits * sqrtLambdaForFirstPassIntra;
#else
m_CABACEstimator->getCtx() = SubCtx(Ctx::MHIntraPredMode, ctxStartIntraMode);
uint64_t fracModeBits = m_pcIntraSearch->xFracModeBitsIntra(pu, pu.intraDir[0], CHANNEL_TYPE_LUMA);
#if JVET_N0324_REGULAR_MRG_FLAG
double cost = (double)sadValue + (double)(bitsCand + 9) * sqrtLambdaForFirstPass + (double)fracModeBits * sqrtLambdaForFirstPassIntra;
#else
double cost = (double)sadValue + (double)(bitsCand + 1) * sqrtLambdaForFirstPass + (double)fracModeBits * sqrtLambdaForFirstPassIntra;
#endif
#endif
insertPos = -1;
updateDoubleCandList(mergeCand + MRG_MAX_NUM_CANDS + MMVD_ADD_NUM, cost, RdModeList, candCostList, RdModeList2, pu.intraDir[0], uiNumMrgSATDCand, &insertPos);
......@@ -2091,7 +2119,9 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
#endif
}
pu.mhIntraFlag = false;
#if !JVET_N0327_MERGE_BIT_CALC_FIX
m_CABACEstimator->getCtx() = ctxStart;
#endif
}
#if !JVET_MMVD_OFF_MACRO
#if JVET_N0127_MMVD_SPS_FLAG
......@@ -2107,6 +2137,10 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
for (uint32_t mergeCand = mergeCtx.numValidMergeCand; mergeCand < mergeCtx.numValidMergeCand + tempNum; mergeCand++)
{
const int mmvdMergeCand = mergeCand - mergeCtx.numValidMergeCand;
#if JVET_N0327_MERGE_BIT_CALC_FIX
int baseIdx = mmvdMergeCand / MMVD_MAX_REFINE_NUM;
int refineStep = (mmvdMergeCand - (baseIdx * MMVD_MAX_REFINE_NUM)) / 4;
#else
int bitsBaseIdx = 0;
int bitsRefineStep = 0;
int bitsDirection = 2;
......@@ -2115,10 +2149,12 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
int refineStep;
baseIdx = mmvdMergeCand / MMVD_MAX_REFINE_NUM;
refineStep = (mmvdMergeCand - (baseIdx * MMVD_MAX_REFINE_NUM)) / 4;
#endif
#if JVET_N0449_MMVD_SIMP
if (refineStep >= m_pcEncCfg->getMmvdDisNum())
continue;
#endif
#if !JVET_N0327_MERGE_BIT_CALC_FIX
bitsBaseIdx = baseIdx + 1;
if (baseIdx == MMVD_BASE_MV_NUM - 1)
{
......@@ -2136,6 +2172,7 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
bitsCand += 7;
#else
bitsCand++; // for mmvd_flag
#endif
#endif
mergeCtx.setMmvdMergeCandiInfo(pu, mmvdMergeCand);
......@@ -2150,8 +2187,13 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
pu.mvRefine = false;
Distortion uiSad = distParam.distFunc(distParam);
#if JVET_N0327_MERGE_BIT_CALC_FIX
m_CABACEstimator->getCtx() = ctxStart;
uint64_t fracBits = m_pcInterSearch->xCalcPuMeBits(pu);
double cost = (double)uiSad + (double)fracBits * sqrtLambdaForFirstPassIntra;
#else
double cost = (double)uiSad + (double)bitsCand * sqrtLambdaForFirstPass;
#endif
insertPos = -1;
updateDoubleCandList(mergeCand, cost, RdModeList, candCostList, RdModeList2, (uint32_t)NUM_LUMA_MODE, uiNumMrgSATDCand, &insertPos);
if (insertPos != -1)
......@@ -2206,6 +2248,9 @@ void EncCu::xCheckRDCostMerge2Nx2N( CodingStructure *&tempCS, CodingStructure *&
}
tempCS->initStructData( encTestMode.qp, encTestMode.lossless );
#if JVET_N0327_MERGE_BIT_CALC_FIX
m_CABACEstimator->getCtx() = ctxStart;
#endif
}
else
{
......
......@@ -7842,3 +7842,70 @@ void InterSearch::symmvdCheckBestMvp(
}
}
}
#if JVET_N0327_MERGE_BIT_CALC_FIX
uint64_t InterSearch::xCalcPuMeBits(PredictionUnit& pu)
{
assert(pu.mergeFlag);
assert(!CU::isIBC(*pu.cu));
m_CABACEstimator->resetBits();
m_CABACEstimator->merge_flag(pu);
if (pu.mergeFlag)
{
if (CU::isIBC(*pu.cu))
{
m_CABACEstimator->merge_idx(pu);
return m_CABACEstimator->getEstFracBits();
}
#if JVET_N0324_REGULAR_MRG_FLAG
if (pu.regularMergeFlag)
{
m_CABACEstimator->merge_idx(pu);
}
else
{
#endif
m_CABACEstimator->subblock_merge_flag(*pu.cu);
m_CABACEstimator->MHIntra_flag(pu);
#if !JVET_N0302_SIMPLFIED_CIIP
if (pu.mhIntraFlag)
{
MHIntra_luma_pred_modes(*pu.cu);
}
#if JVET_N0324_REGULAR_MRG_FLAG
else
{
if (!pu.cu->affine && !pu.mmvdMergeFlag && !pu.cu->mmvdSkip)
{
CHECK(!pu.cu->triangle, "triangle_flag must be true");
}
}
#else
triangle_mode(*pu.cu);
#endif
#else
#if JVET_N0324_REGULAR_MRG_FLAG
if (!pu.mhIntraFlag)
{
if (!pu.cu->affine && !pu.mmvdMergeFlag && !pu.cu->mmvdSkip)
{
CHECK(!pu.cu->triangle, "triangle_flag must be true");
}
}
#else
triangle_mode(*pu.cu);
#endif
#endif
if (pu.mmvdMergeFlag)
{
m_CABACEstimator->mmvd_merge_idx(pu);
}
else
m_CABACEstimator->merge_idx(pu);
#if JVET_N0324_REGULAR_MRG_FLAG
}
#endif
}
return m_CABACEstimator->getEstFracBits();
}
#endif
\ No newline at end of file
......@@ -529,6 +529,9 @@ public:
, const bool luma = true, const bool chroma = true
);
uint64_t xGetSymbolFracBitsInter (CodingStructure &cs, Partitioner &partitioner);
#if JVET_N0327_MERGE_BIT_CALC_FIX
uint64_t xCalcPuMeBits (PredictionUnit& pu);
#endif
};// END CLASS DEFINITION EncSearch
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment