Commit bb5469ef authored by Xiang Li

Merge branch 'JVET-R0164-mean_scaled_SATD' into 'master'

JVET-R0164: Use a mean scaled version of SATD in encoder decisions

See merge request jvet/VVCSoftware_VTM!1544
parents d15f4116 155be26b
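JVET-R0164 leaves the Hadamard transforms themselves untouched and only changes how the DC (mean) coefficient enters the cost: its absolute value is subtracted at full weight and added back at quarter weight, so a residual that differs mainly by a constant offset contributes less to the distortion measure. A minimal standalone sketch of that adjustment (the helper name and arguments are illustrative, not part of the patch):

```cpp
// Sketch of the JVET-R0164 adjustment (hypothetical helper, not VTM code):
// the DC Hadamard coefficient carries the block mean, and its contribution
// to the SATD sum is reduced from full weight to one quarter.
#include <cstdint>

static uint32_t meanScaledSatd(uint32_t hadamardAbsSum, uint32_t absDc)
{
  uint32_t satd = hadamardAbsSum; // sum of |c| over all Hadamard coefficients, DC included
  satd -= absDc;                  // drop the full-weight DC contribution
  satd += absDc >> 2;             // re-add it at quarter weight
  return satd;                    // |DC| now enters the cost with weight 1/4 instead of 1
}
```

Each scalar and SIMD variant below inserts exactly this pair of lines (or, for the 2x2 block, an equivalent in-place shift) just before its existing rounding or sqrt-based normalization, guarded by the new JVET_R0164_MEAN_SCALED_SATD macro.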
@@ -2151,7 +2151,11 @@ Distortion RdCost::xCalcHADs2x2( const Pel *piOrg, const Pel *piCur, int iStride
m[2] = diff[0] - diff[2];
m[3] = diff[1] - diff[3];
#if JVET_R0164_MEAN_SCALED_SATD
satd += abs(m[0] + m[1]) >> 2;
#else
satd += abs(m[0] + m[1]);
#endif
satd += abs(m[0] - m[1]);
satd += abs(m[2] + m[3]);
satd += abs(m[2] - m[3]);
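For the 2x2 block the patch uses an #if/#else instead of the subtract/re-add pair: the guarded term abs(m[0] + m[1]) is the one treated as the DC contribution, and the quarter weight is applied in place with a single shift. A tiny self-contained check that the in-place shift gives the same total as the pattern used for the larger sizes (names and values are made up):

```cpp
// Equivalence of the 2x2 in-place shift and the general subtract/re-add
// pattern (standalone illustration; inputs are arbitrary):
#include <cassert>
#include <cstdlib>

static void check2x2Shortcut(int m0, int m1, int restOfAbsSum)
{
  int absDc    = abs(m0 + m1);                                  // term guarded by the #if/#else
  int shortcut = (absDc >> 2) + restOfAbsSum;                   // satd += abs(m[0]+m[1]) >> 2
  int general  = (absDc + restOfAbsSum) - absDc + (absDc >> 2); // pattern used for 4x4 and up
  assert(shortcut == general);
}

int main() { check2x2Shortcut(7, -3, 42); return 0; }
```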
@@ -2250,7 +2254,12 @@ Distortion RdCost::xCalcHADs4x4( const Pel *piOrg, const Pel *piCur, int iStride
{
satd += abs(d[k]);
}
satd = ((satd+1)>>1);
#if JVET_R0164_MEAN_SCALED_SATD
satd -= abs(d[0]);
satd += abs(d[0]) >> 2;
#endif
satd = ((satd+1)>>1);
return satd;
}
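In the 4x4 path (and likewise in the 8x8 one below) the adjustment sits before the existing (satd+1)>>1 rounding, so the final scaling acts on the already mean-scaled sum. A worked example with invented numbers:

```cpp
// Worked example with invented values (not taken from the patch):
#include <cstdio>

int main()
{
  unsigned satd  = 100;        // raw sum of |Hadamard coefficients| of a 4x4 residual
  unsigned absDc = 40;         // |d[0]|, the DC term
  satd -= absDc;               // 60
  satd += absDc >> 2;          // 70: DC now weighted by 1/4
  satd  = (satd + 1) >> 1;     // 35: same final rounding as before the change
  std::printf("%u\n", satd);   // the unscaled version would give (100 + 1) >> 1 = 50
  return 0;
}
```

The rectangular scalar paths (16x8, 8x16, 4x8, 8x4) do the same thing before their sqrt(W*H)-based normalization.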
@@ -2347,7 +2356,11 @@ Distortion RdCost::xCalcHADs8x8( const Pel *piOrg, const Pel *piCur, int iStride
}
}
sad=((sad+2)>>2);
#if JVET_R0164_MEAN_SCALED_SATD
sad -= abs(m2[0][0]);
sad += abs(m2[0][0]) >> 2;
#endif
sad = ((sad+2)>>2);
return sad;
}
@@ -2493,7 +2506,11 @@ Distortion RdCost::xCalcHADs16x8( const Pel *piOrg, const Pel *piCur, int iStrid
}
}
sad = ( int ) ( sad / sqrt( 16.0 * 8 ) * 2 );
#if JVET_R0164_MEAN_SCALED_SATD
sad -= abs(m2[0][0]);
sad += abs(m2[0][0]) >> 2;
#endif
sad = ( int ) ( sad / sqrt( 16.0 * 8 ) * 2 );
return sad;
}
@@ -2630,7 +2647,11 @@ Distortion RdCost::xCalcHADs8x16( const Pel *piOrg, const Pel *piCur, int iStrid
}
}
sad = ( int ) ( sad / sqrt( 16.0 * 8 ) * 2 );
#if JVET_R0164_MEAN_SCALED_SATD
sad -= abs(m2[0][0]);
sad += abs(m2[0][0]) >> 2;
#endif
sad = ( int ) ( sad / sqrt( 16.0 * 8 ) * 2 );
return sad;
}
@@ -2703,7 +2724,11 @@ Distortion RdCost::xCalcHADs4x8( const Pel *piOrg, const Pel *piCur, int iStride
}
}
sad = ( int ) ( sad / sqrt( 4.0 * 8 ) * 2 );
#if JVET_R0164_MEAN_SCALED_SATD
sad -= abs(m2[0][0]);
sad += abs(m2[0][0]) >> 2;
#endif
sad = ( int ) ( sad / sqrt( 4.0 * 8 ) * 2 );
return sad;
}
@@ -2782,7 +2807,11 @@ Distortion RdCost::xCalcHADs8x4( const Pel *piOrg, const Pel *piCur, int iStride
}
}
sad = ( int ) ( sad / sqrt( 4.0 * 8 ) * 2 );
#if JVET_R0164_MEAN_SCALED_SATD
sad -= abs(m2[0][0]);
sad += abs(m2[0][0]) >> 2;
#endif
sad = ( int ) ( sad / sqrt( 4.0 * 8 ) * 2 );
return sad;
}
@@ -79,6 +79,8 @@
//########### place macros to be be kept below this line ###############
#define JVET_R0164_MEAN_SCALED_SATD 1 // JVET-R0164: Use a mean scaled version of SATD in encoder decisions
#define JVET_M0497_MATRIX_MULT 0 // 0: Fast method; 1: Matrix multiplication
#define APPLY_SBT_SL_ON_MTS 1 // apply save & load fast algorithm on inter MTS when SBT is on
@@ -523,6 +523,9 @@ static uint32_t xCalcHAD4x4_SSE( const Torg *piOrg, const Tcur *piCur, const int
// abs
__m128i Sum = _mm_abs_epi16( r0 );
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc = _mm_cvtsi128_si32( Sum ) & 0x0000ffff;
#endif
Sum = _mm_add_epi16( Sum, _mm_abs_epi16( r2 ) );
Sum = _mm_add_epi16( Sum, _mm_abs_epi16( r3 ) );
Sum = _mm_add_epi16( Sum, _mm_abs_epi16( r5 ) );
@@ -534,7 +537,11 @@ static uint32_t xCalcHAD4x4_SSE( const Torg *piOrg, const Tcur *piCur, const int
uint32_t sad = _mm_cvtsi128_si32( Sum );
sad = ( ( sad + 1 ) >> 1 );
#if JVET_R0164_MEAN_SCALED_SATD
sad -= absDc;
sad += absDc >> 2;
#endif
sad = ( ( sad + 1 ) >> 1 );
return sad;
}
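The SSE 4x4 kernel needs |DC| as a scalar before the lanes are summed up: right after the first _mm_abs_epi16( r0 ), lane 0 of Sum holds that value, so the guarded line reads the low 32 bits and masks them down to one 16-bit lane. A standalone sketch of that extraction pattern (illustrative helper, not the VTM code; SSSE3 is required for _mm_abs_epi16):

```cpp
// Read the absolute value held in the lowest 16-bit lane of a vector,
// the pattern used above to obtain absDc (illustrative helper):
#include <tmmintrin.h>
#include <cstdint>

static uint32_t absLane0_epi16(__m128i v)
{
  __m128i a = _mm_abs_epi16(v);                    // per-lane 16-bit absolute value
  return (uint32_t)_mm_cvtsi128_si32(a) & 0xffffu; // low 32 bits, masked to lane 0
}
```

The wider SSE kernels (8x8, 16x8, 8x16, 8x4) already hold their coefficients in 32-bit lanes at the point where the DC is captured, so their reads (_mm_cvtsi128_si32( n1[0][0] ) and the like) need no mask; in the two-pass 16x8 and 8x16 variants the value is latched only on the first pass (l == 0, respectively l + i == 0), which, per those guards, is where the DC-carrying vector appears.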
@@ -663,7 +670,12 @@ static uint32_t xCalcHAD8x8_SSE( const Torg *piOrg, const Tcur *piCur, const int
iSum = _mm_hadd_epi32( iSum, iSum );
uint32_t sad = _mm_cvtsi128_si32( iSum );
sad = ( ( sad + 2 ) >> 2 );
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc = _mm_cvtsi128_si32( n1[0][0] );
sad -= absDc;
sad += absDc >> 2;
#endif
sad = ( ( sad + 2 ) >> 2 );
return sad;
}
@@ -725,6 +737,9 @@ static uint32_t xCalcHAD16x8_SSE( const Torg *piOrg, const Tcur *piCur, const in
// 4 x 8x4 blocks
// 0 1
// 2 3
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc = 0;
#endif
// transpose and do horizontal in two steps
for( int l = 0; l < 2; l++ )
@@ -841,6 +856,11 @@ static uint32_t xCalcHAD16x8_SSE( const Torg *piOrg, const Tcur *piCur, const in
n1[14] = _mm_abs_epi32( _mm_add_epi32( n2[14], n2[15] ) );
n1[15] = _mm_abs_epi32( _mm_sub_epi32( n2[14], n2[15] ) );
#if JVET_R0164_MEAN_SCALED_SATD
if (l == 0)
absDc = _mm_cvtsi128_si32( n1[0] );
#endif
// sum up
n1[0] = _mm_add_epi32( n1[0], n1[1] );
n1[2] = _mm_add_epi32( n1[2], n1[3] );
@@ -868,7 +888,11 @@ static uint32_t xCalcHAD16x8_SSE( const Torg *piOrg, const Tcur *piCur, const in
uint32_t sad = _mm_cvtsi128_si32( iSum );
sad = (uint32_t)(sad / sqrt(16.0 * 8) * 2);
#if JVET_R0164_MEAN_SCALED_SATD
sad -= absDc;
sad += absDc >> 2;
#endif
sad = (uint32_t)(sad / sqrt(16.0 * 8) * 2);
return sad;
}
@@ -984,6 +1008,10 @@ static uint32_t xCalcHAD8x16_SSE( const Torg *piOrg, const Tcur *piCur, const in
}
}
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc = 0;
#endif
for( int l = 0; l < 2; l++ )
{
int off = l * 8;
@@ -1028,6 +1056,11 @@ static uint32_t xCalcHAD8x16_SSE( const Torg *piOrg, const Tcur *piCur, const in
n1[i][5] = _mm_abs_epi32( _mm_sub_epi32( n2[i][4], n2[i][5] ) );
n1[i][6] = _mm_abs_epi32( _mm_add_epi32( n2[i][6], n2[i][7] ) );
n1[i][7] = _mm_abs_epi32( _mm_sub_epi32( n2[i][6], n2[i][7] ) );
#if JVET_R0164_MEAN_SCALED_SATD
if ( l + i == 0 )
absDc = _mm_cvtsi128_si32( n1[i][0] );
#endif
}
for( int i = 0; i < 8; i++ )
@@ -1050,7 +1083,11 @@ static uint32_t xCalcHAD8x16_SSE( const Torg *piOrg, const Tcur *piCur, const in
uint32_t sad = _mm_cvtsi128_si32( iSum );
sad = (uint32_t)(sad / sqrt(16.0 * 8) * 2);
#if JVET_R0164_MEAN_SCALED_SATD
sad -= absDc;
sad += absDc >> 2;
#endif
sad = (uint32_t)(sad / sqrt(16.0 * 8) * 2);
return sad;
}
@@ -1177,6 +1214,9 @@ static uint32_t xCalcHAD8x4_SSE( const Torg *piOrg, const Tcur *piCur, const int
}
}
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc = _mm_cvtsi128_si32( m1[0] );
#endif
m1[0] = _mm_add_epi32( m1[0], m1[1] );
m1[1] = _mm_add_epi32( m1[2], m1[3] );
@@ -1193,7 +1233,11 @@ static uint32_t xCalcHAD8x4_SSE( const Torg *piOrg, const Tcur *piCur, const int
uint32_t sad = _mm_cvtsi128_si32( iSum );
//sad = ((sad + 2) >> 2);
sad = (uint32_t)(sad / sqrt(4.0 * 8) * 2);
#if JVET_R0164_MEAN_SCALED_SATD
sad -= absDc;
sad += absDc >> 2;
#endif
sad = (uint32_t)(sad / sqrt(4.0 * 8) * 2);
return sad;
}
@@ -1261,6 +1305,10 @@ static uint32_t xCalcHAD4x8_SSE( const Torg *piOrg, const Tcur *piCur, const int
m2[3] = _mm_unpackhi_epi64( m1[1], m1[3] );
}
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc = 0;
#endif
if( iBitDepth >= 10 /*sizeof( Torg ) > 1 || sizeof( Tcur ) > 1*/ )
{
__m128i n1[4][2];
@@ -1288,6 +1336,10 @@ static uint32_t xCalcHAD4x8_SSE( const Torg *piOrg, const Tcur *piCur, const int
{
m1[i] = _mm_add_epi32( n1[i][0], n1[i][1] );
}
#if JVET_R0164_MEAN_SCALED_SATD
absDc = _mm_cvtsi128_si32( n1[0][0] );
#endif
}
else
{
@@ -1310,6 +1362,10 @@ static uint32_t xCalcHAD4x8_SSE( const Torg *piOrg, const Tcur *piCur, const int
ma2 = _mm_unpackhi_epi16( m2[i], vzero );
m1[i] = _mm_add_epi32( ma1, ma2 );
}
#if JVET_R0164_MEAN_SCALED_SATD
absDc = _mm_cvtsi128_si32( m2[0] ) & 0x0000ffff;
#endif
}
m1[0] = _mm_add_epi32( m1[0], m1[1] );
@@ -1323,7 +1379,11 @@ static uint32_t xCalcHAD4x8_SSE( const Torg *piOrg, const Tcur *piCur, const int
uint32_t sad = _mm_cvtsi128_si32( iSum );
//sad = ((sad + 2) >> 2);
sad = (uint32_t)(sad / sqrt(4.0 * 8) * 2);
#if JVET_R0164_MEAN_SCALED_SATD
sad -= absDc;
sad += absDc >> 2;
#endif
sad = (uint32_t)(sad / sqrt(4.0 * 8) * 2);
return sad;
}
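The 4x8 kernel has two arithmetic branches, and the patch reads the DC differently in each: for bit depths >= 10 the sums are already widened to 32-bit lanes, so absDc comes straight from n1[0][0]; in the 16-bit branch the coefficients are still packed, so the read from m2[0] is masked with 0x0000ffff as in the 4x4 kernel. A side-by-side sketch of the two reads (illustrative helpers, not the VTM code):

```cpp
// The two forms of the absDc read in the 4x8 path, depending on lane width
// (standalone illustration):
#include <emmintrin.h>
#include <cstdint>

static uint32_t dcFrom32BitLanes(__m128i v)        // high-bit-depth branch
{
  return (uint32_t)_mm_cvtsi128_si32(v);           // lane 0 already fills 32 bits
}

static uint32_t dcFrom16BitLanes(__m128i v)        // low-bit-depth branch
{
  return (uint32_t)_mm_cvtsi128_si32(v) & 0xffffu; // keep only the 16-bit lane 0
}
```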
@@ -1462,6 +1522,11 @@ static uint32_t xCalcHAD16x16_AVX2( const Torg *piOrg, const Tcur *piCur, const
m2[i][7] = _mm256_abs_epi32( _mm256_sub_epi32( m1[i][6], m1[i][7] ) );
}
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc0 = _mm_cvtsi128_si32( _mm256_castsi256_si128( m2[0][0] ) );
uint32_t absDc1 = _mm_cvtsi128_si32( _mm256_castsi256_si128( _mm256_permute2x128_si256( m2[0][0], m2[0][0], 0x11 ) ) );
#endif
for( int i = 0; i < 8; i++ )
{
m1[0][i] = _mm256_add_epi32( m2[0][i], m2[1][i] );
@@ -1481,12 +1546,20 @@ static uint32_t xCalcHAD16x16_AVX2( const Torg *piOrg, const Tcur *piCur, const
iSum = _mm256_hadd_epi32( iSum, iSum );
uint32_t tmp;
tmp = _mm_cvtsi128_si32( _mm256_castsi256_si128( iSum ) );
tmp = ( ( tmp + 2 ) >> 2 );
tmp = _mm_cvtsi128_si32( _mm256_castsi256_si128( iSum ) );
#if JVET_R0164_MEAN_SCALED_SATD
tmp -= absDc0;
tmp += absDc0 >> 2;
#endif
tmp = ( ( tmp + 2 ) >> 2 );
sad += tmp;
tmp = _mm_cvtsi128_si32( _mm256_castsi256_si128( _mm256_permute2x128_si256( iSum, iSum, 0x11 ) ) );
tmp = ( ( tmp + 2 ) >> 2 );
tmp = _mm_cvtsi128_si32( _mm256_castsi256_si128( _mm256_permute2x128_si256( iSum, iSum, 0x11 ) ) );
#if JVET_R0164_MEAN_SCALED_SATD
tmp -= absDc1;
tmp += absDc1 >> 2;
#endif
tmp = ( ( tmp + 2 ) >> 2 );
sad += tmp;
}
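The AVX2 16x16 kernel accumulates two partial sums, one in each 128-bit half of the ymm registers, and rounds each with its own (tmp + 2) >> 2; the patch therefore extracts two DC values, absDc0 from the low half and absDc1 from the high half of m2[0][0], and applies each to its own partial sum before that rounding. A sketch of reading lane 0 of both halves (illustrative helper; _mm256_extracti128_si256 is an AVX2 equivalent of the _mm256_permute2x128_si256 + cast pair used in the patch):

```cpp
// Read lane 0 of each 128-bit half of a __m256i, as done for absDc0/absDc1:
#include <immintrin.h>
#include <cstdint>

static void dcOfEachHalf(__m256i v, uint32_t &dcLo, uint32_t &dcHi)
{
  dcLo = (uint32_t)_mm_cvtsi128_si32(_mm256_castsi256_si128(v));      // low half, lane 0
  dcHi = (uint32_t)_mm_cvtsi128_si32(_mm256_extracti128_si256(v, 1)); // high half, lane 0
}
```

The AVX2 16x8 and 8x16 kernels collapse both halves into a single sum before normalizing, so they each need only one absDc.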
@@ -1700,6 +1773,10 @@ static uint32_t xCalcHAD16x8_AVX2( const Torg *piOrg, const Tcur *piCur, const i
m1[15] = _mm256_abs_epi32( _mm256_sub_epi32( m2[14], m2[15] ) );
}
#if JVET_R0164_MEAN_SCALED_SATD
uint32_t absDc = _mm_cvtsi128_si32( _mm256_castsi256_si128( m1[0] ) );
#endif
// sum up
m1[ 0] = _mm256_add_epi32( m1[ 0], m1[ 1] );
m1[ 2] = _mm256_add_epi32( m1[ 2], m1[ 3] );
@@ -1723,9 +1800,12 @@ static uint32_t xCalcHAD16x8_AVX2( const Torg *piOrg, const Tcur *piCur, const i
iSum = _mm256_hadd_epi32( iSum, iSum );
iSum = _mm256_add_epi32( iSum, _mm256_permute2x128_si256( iSum, iSum, 0x11 ) );
sad = _mm_cvtsi128_si32( _mm256_castsi256_si128( iSum ) );
sad = (uint32_t)(sad / sqrt(16.0 * 8) * 2);
sad = _mm_cvtsi128_si32( _mm256_castsi256_si128( iSum ) );
#if JVET_R0164_MEAN_SCALED_SATD
sad -= absDc;
sad += absDc >> 2;
#endif
sad = (uint32_t)(sad / sqrt(16.0 * 8) * 2);
}
#endif //USE_AVX2
@@ -1911,6 +1991,10 @@ static uint32_t xCalcHAD8x16_AVX2( const Pel* piOrg, const Pel* piCur, const int
m1[6] = _mm256_abs_epi32( _mm256_add_epi32( m2[6], m2[7] ) );
m1[7] = _mm256_abs_epi32( _mm256_sub_epi32( m2[6], m2[7] ) );
#if JVET_R0164_MEAN_SCALED_SATD
int absDc = _mm_cvtsi128_si32( _mm256_castsi256_si128( m1[0] ) );
#endif
m1[0 + 8] = _mm256_add_epi32( m2[0 + 8], m2[4 + 8] );
m1[1 + 8] = _mm256_add_epi32( m2[1 + 8], m2[5 + 8] );
m1[2 + 8] = _mm256_add_epi32( m2[2 + 8], m2[6 + 8] );
@@ -1965,7 +2049,11 @@ static uint32_t xCalcHAD8x16_AVX2( const Pel* piOrg, const Pel* piCur, const int
int sad2 = _mm_cvtsi128_si32( _mm256_castsi256_si128( iSum ) );
sad = (uint32_t)(sad2 / sqrt(16.0 * 8) * 2);
#if JVET_R0164_MEAN_SCALED_SATD
sad2 -= absDc;
sad2 += absDc >> 2;
#endif
sad = (uint32_t)(sad2 / sqrt(16.0 * 8) * 2);
}
#endif //USE_AVX2