Commit 87a15f25 authored by Xiaoyu Xiu

HBD_SIMD_FIX

parent e23eeafa
@@ -62,7 +62,7 @@
//########### place macros to be kept below this line ###############
#define JVET_S0257_DUMP_360SEI_MESSAGE 1 // Software support of 360 SEI messages
#define JVET_R0351_HIGH_BIT_DEPTH_ENABLED 0 // JVET-R0351: high bit depth coding enabled (increases accuracies of some calculations, e.g. transforms)
#define JVET_R0351_HIGH_BIT_DEPTH_ENABLED 1 // JVET-R0351: high bit depth coding enabled (increases accuracies of some calculations, e.g. transforms)
#define JVET_R0164_MEAN_SCALED_SATD 1 // JVET-R0164: Use a mean scaled version of SATD in encoder decisions
......
@@ -555,185 +555,6 @@ static void simdFilter5x5Blk_HBD_AVX2(AlfClassifier **classifier, const PelUnitB
dst += dstStride * step_y;
}
}
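// AVX2 variant of the ALF block classification for high bit depth (32-bit Pel) buffers: for each
// 4x4 luma block it derives an activity/direction class index and a transpose index from
// Laplacian gradients, respecting the ALF virtual boundary inside each CTU.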
static void simdDeriveClassificationBlk_HBD_AVX2(AlfClassifier **classifier, int **laplacian[NUM_DIRECTIONS],
const CPelBuf &srcLuma, const Area &blkDst, const Area &blk, const int shift,
const int vbCTUHeight, int vbPos)
{
CHECK((blk.height & 7) != 0, "Block height must be a multiple of 8");
CHECK((blk.width & 7) != 0, "Block width must be a multiple of 8");
CHECK((vbCTUHeight & (vbCTUHeight - 1)) != 0, "vbCTUHeight must be a power of 2");
const size_t imgStride = srcLuma.stride;
const Pel * srcExt = srcLuma.buf;
const int imgHExtended = blk.height + 4;
const int imgWExtended = blk.width + 4;
const int posX = blk.pos().x;
const int posY = blk.pos().y;
// 18x40 array
uint32_t colSums[(AdaptiveLoopFilter::m_CLASSIFICATION_BLK_SIZE + 4) >> 1]
[AdaptiveLoopFilter::m_CLASSIFICATION_BLK_SIZE + 8];
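// Each colSums row caches, for one pair of picture rows, partially accumulated gradient sums
// (vertical, horizontal and the two diagonals, interleaved); the second loop below combines
// them into per-4x4-block totals.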
for (int i = 0; i < imgHExtended; i += 2)
{
const size_t offset = (i + posY - 3) * imgStride + posX - 3;
const Pel *imgY0 = &srcExt[offset];
const Pel *imgY1 = &srcExt[offset + imgStride];
const Pel *imgY2 = &srcExt[offset + imgStride * 2];
const Pel *imgY3 = &srcExt[offset + imgStride * 3];
// pixel padding for gradient calculation
int pos = blkDst.pos().y - 2 + i;
int posInCTU = pos & (vbCTUHeight - 1);
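// Rows that would cross the ALF virtual boundary are replaced by the nearest allowed row,
// so no gradient mixes samples from the two sides of the boundary.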
if (pos > 0 && posInCTU == vbPos - 2)
{
imgY3 = imgY2;
}
else if (pos > 0 && posInCTU == vbPos)
{
imgY0 = imgY1;
}
__m256i prev = _mm256_setzero_si256();
for (int j = 0; j < imgWExtended; j += 8)
{
const __m256i x0 = _mm256_lddqu_si256((const __m256i *) (imgY0 + j));
const __m256i x1 = _mm256_lddqu_si256((const __m256i *) (imgY1 + j));
const __m256i x2 = _mm256_lddqu_si256((const __m256i *) (imgY2 + j));
const __m256i x3 = _mm256_lddqu_si256((const __m256i *) (imgY3 + j));
const __m256i x4 = _mm256_lddqu_si256((const __m256i *) (imgY0 + j + 2));
const __m256i x5 = _mm256_lddqu_si256((const __m256i *) (imgY1 + j + 2));
const __m256i x6 = _mm256_lddqu_si256((const __m256i *) (imgY2 + j + 2));
const __m256i x7 = _mm256_lddqu_si256((const __m256i *) (imgY3 + j + 2));
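// x0..x3 are eight consecutive 32-bit samples from four neighbouring rows and x4..x7 the same
// rows shifted right by two columns; the blends below assemble the eight neighbours
// (nw, n, ne, w, e, sw, s, se) of the centre samples.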
const __m256i nw = _mm256_blend_epi32(x0, x1, 0xaa);
const __m256i n = _mm256_blend_epi32(x0, x5, 0x55);
const __m256i ne = _mm256_blend_epi32(x4, x5, 0xaa);
const __m256i w = _mm256_blend_epi32(x1, x2, 0xaa);
const __m256i e = _mm256_blend_epi32(x5, x6, 0xaa);
const __m256i sw = _mm256_blend_epi32(x2, x3, 0xaa);
const __m256i s = _mm256_blend_epi32(x2, x7, 0x55);
const __m256i se = _mm256_blend_epi32(x6, x7, 0xaa);
__m256i c = _mm256_slli_epi32(_mm256_blend_epi32(x1, x6, 0x55), 1);
__m256i d = _mm256_shuffle_epi32(c, 0xb1);
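// Compute the absolute Laplacian responses |2*c - a - b| along the vertical, horizontal and
// the two diagonal directions; c and its element-swapped copy d cover the two interleaved
// phases of centre samples handled in this iteration.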
const __m256i ver = _mm256_abs_epi32(_mm256_sub_epi32(c, _mm256_add_epi32(n, s)));
const __m256i hor = _mm256_abs_epi32(_mm256_sub_epi32(d, _mm256_add_epi32(w, e)));
const __m256i di0 = _mm256_abs_epi32(_mm256_sub_epi32(d, _mm256_add_epi32(nw, se)));
const __m256i di1 = _mm256_abs_epi32(_mm256_sub_epi32(d, _mm256_add_epi32(ne, sw)));
const __m256i hv = _mm256_hadd_epi32(ver, hor);
const __m256i di = _mm256_hadd_epi32(di0, di1);
const __m256i all = _mm256_shuffle_epi32(_mm256_permute4x64_epi64(_mm256_hadd_epi32(hv, di), 0xd8), 0xd8);
const __m256i t = _mm256_blend_epi32(all, prev, 0xaa);
_mm256_store_si256((__m256i *) &colSums[i >> 1][j], _mm256_permute4x64_epi64(_mm256_hadd_epi32(t, all), 0xd8));
prev = all;
}
}
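// Second pass: combine the cached column sums into gradient totals per 4x4 block and derive
// each block's class index and transpose index.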
__m256i zeros = _mm256_setzero_si256();
for (int i = 0; i < (blk.height >> 1); i += 4) // 2 4x4 vertical
{
for (int j = 0; j < blk.width; j += 8) // 2 4x4 horizontal
{
__m256i x0, x1, x2, x3, x4, x5, x6, x7;
const uint32_t z = (2 * i + blkDst.pos().y) & (vbCTUHeight - 1);
const uint32_t z2 = (2 * i + 4 + blkDst.pos().y) & (vbCTUHeight - 1);
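// Load the column sums for the two vertically stacked 4x4 blocks; rows that lie beyond the
// ALF virtual boundary contribute zeros.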
x0 = (z == vbPos) ? zeros : _mm256_lddqu_si256((__m256i *) &colSums[i + 0][j + 4]);
x1 = _mm256_lddqu_si256((__m256i *) &colSums[i + 1][j + 4]);
x2 = _mm256_lddqu_si256((__m256i *) &colSums[i + 2][j + 4]);
x3 = (z == vbPos - 4) ? zeros : _mm256_lddqu_si256((__m256i *) &colSums[i + 3][j + 4]);
x4 = (z2 == vbPos) ? zeros : _mm256_lddqu_si256((__m256i *) &colSums[i + 2][j + 4]);
x5 = _mm256_lddqu_si256((__m256i *) &colSums[i + 3][j + 4]);
x6 = _mm256_lddqu_si256((__m256i *) &colSums[i + 4][j + 4]);
x7 = (z2 == vbPos - 4) ? zeros : _mm256_lddqu_si256((__m256i *) &colSums[i + 5][j + 4]);
x0 = _mm256_add_epi32(x0, x1);
x2 = _mm256_add_epi32(x2, x3);
x4 = _mm256_add_epi32(x4, x5);
x6 = _mm256_add_epi32(x6, x7);
x0 = _mm256_add_epi32(x0, x2);
x4 = _mm256_add_epi32(x4, x6);
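// Transpose the accumulated sums so that sumV, sumH, sumD0 and sumD1 each hold the
// per-direction totals for the four 4x4 blocks (two across, two down) of this iteration.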
x2 = _mm256_unpacklo_epi32(x0, x4);
x3 = _mm256_permute4x64_epi64(x2, 0x4e);
x6 = _mm256_unpackhi_epi32(x0, x4);
x7 = _mm256_permute4x64_epi64(x6, 0x4e);
__m128i sumV = _mm256_castsi256_si128(_mm256_unpacklo_epi32(x2, x3));
__m128i sumH = _mm256_castsi256_si128(_mm256_unpackhi_epi32(x2, x3));
__m128i sumD0 = _mm256_castsi256_si128(_mm256_unpacklo_epi32(x6, x7));
__m128i sumD1 = _mm256_castsi256_si128(_mm256_unpackhi_epi32(x6, x7));
__m128i tempAct = _mm_add_epi32(sumV, sumH);
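// Activity: (sumV + sumH) is scaled (96 rather than 64 presumably compensates for the rows
// lost at the virtual boundary), right-shifted, clipped to [0,15] and mapped through the byte
// lookup table to one of five activity classes.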
const uint32_t scale = (z == vbPos - 4 || z == vbPos) ? 96 : 64;
const uint32_t scale2 = (z2 == vbPos - 4 || z2 == vbPos) ? 96 : 64;
__m128i activity = _mm_mullo_epi32(tempAct, _mm_unpacklo_epi64(_mm_set1_epi32(scale), _mm_set1_epi32(scale2)));
activity = _mm_srl_epi32(activity, _mm_cvtsi32_si128(shift));
activity = _mm_min_epi32(activity, _mm_set1_epi32(15));
__m128i classIdx = _mm_shuffle_epi8(_mm_setr_epi8(0, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4), activity);
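// Directionality: the dominant of {V,H} is weighed against the dominant of {D0,D1} by
// cross-multiplication (the xor with 0x80000000 makes the unsigned products comparable with
// signed cmpgt); the strength tests then raise the class index by 5/10 for a strongly
// diagonal block, or by 15/20 for a strongly horizontal/vertical one.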
__m128i dirTempHVMinus1 = _mm_cmpgt_epi32(sumV, sumH);
__m128i hv1 = _mm_max_epi32(sumV, sumH);
__m128i hv0 = _mm_min_epi32(sumV, sumH);
__m128i dirTempDMinus1 = _mm_cmpgt_epi32(sumD0, sumD1);
__m128i d1 = _mm_max_epi32(sumD0, sumD1);
__m128i d0 = _mm_min_epi32(sumD0, sumD1);
__m128i a = _mm_xor_si128(_mm_mullo_epi32(d1, hv0), _mm_set1_epi32(0x80000000));
__m128i b = _mm_xor_si128(_mm_mullo_epi32(hv1, d0), _mm_set1_epi32(0x80000000));
__m128i dirIdx = _mm_cmpgt_epi32(a, b);
__m128i hvd1 = _mm_blendv_epi8(hv1, d1, dirIdx);
__m128i hvd0 = _mm_blendv_epi8(hv0, d0, dirIdx);
__m128i strength1 = _mm_cmpgt_epi32(hvd1, _mm_add_epi32(hvd0, hvd0));
__m128i strength2 = _mm_cmpgt_epi32(_mm_add_epi32(hvd1, hvd1), _mm_add_epi32(hvd0, _mm_slli_epi32(hvd0, 3)));
__m128i offset = _mm_and_si128(strength1, _mm_set1_epi32(5));
classIdx = _mm_add_epi32(classIdx, offset);
classIdx = _mm_add_epi32(classIdx, _mm_and_si128(strength2, _mm_set1_epi32(5)));
offset = _mm_andnot_si128(dirIdx, offset);
offset = _mm_add_epi32(offset, offset);
classIdx = _mm_add_epi32(classIdx, offset);
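// transposeIdx = 3 - (sumV > sumH) - 2 * (sumD0 > sumD1): one of the four transpose states
// applied when the filter coefficients are used.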
__m128i transposeIdx = _mm_set1_epi32(3);
transposeIdx = _mm_add_epi32(transposeIdx, dirTempHVMinus1);
transposeIdx = _mm_add_epi32(transposeIdx, dirTempDMinus1);
transposeIdx = _mm_add_epi32(transposeIdx, dirTempDMinus1);
int yOffset = 2 * i + blkDst.pos().y;
int xOffset = j + blkDst.pos().x;
static_assert(sizeof(AlfClassifier) == 2, "ALFClassifier type must be 16 bits wide");
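// Pack each block's (classIdx, transposeIdx) into 16-bit AlfClassifier entries and replicate
// the pair over every sample of its 4x4 block: the low lanes go to the two upper blocks
// (rows yOffset..yOffset+3), the high lanes to the two lower ones.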
__m128i v;
v = _mm_unpacklo_epi8(classIdx, transposeIdx);
v = _mm_shuffle_epi8(v, _mm_setr_epi8(0, 1, 0, 1, 0, 1, 0, 1, 8, 9, 8, 9, 8, 9, 8, 9));
_mm_storeu_si128((__m128i *) (classifier[yOffset] + xOffset), v);
_mm_storeu_si128((__m128i *) (classifier[yOffset + 1] + xOffset), v);
_mm_storeu_si128((__m128i *) (classifier[yOffset + 2] + xOffset), v);
_mm_storeu_si128((__m128i *) (classifier[yOffset + 3] + xOffset), v);
v = _mm_unpackhi_epi8(classIdx, transposeIdx);
v = _mm_shuffle_epi8(v, _mm_setr_epi8(0, 1, 0, 1, 0, 1, 0, 1, 8, 9, 8, 9, 8, 9, 8, 9));
_mm_storeu_si128((__m128i *) (classifier[yOffset + 4] + xOffset), v);
_mm_storeu_si128((__m128i *) (classifier[yOffset + 5] + xOffset), v);
_mm_storeu_si128((__m128i *) (classifier[yOffset + 6] + xOffset), v);
_mm_storeu_si128((__m128i *) (classifier[yOffset + 7] + xOffset), v);
}
}
}
#endif
#else
template<X86_VEXT vext>
@@ -1853,17 +1674,16 @@ template <X86_VEXT vext>
void AdaptiveLoopFilter::_initAdaptiveLoopFilterX86()
{
#if RExt__HIGH_BIT_DEPTH_SUPPORT
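// High bit depth builds: bind the HBD SIMD kernels, switching to the AVX2 variants when the
// CPU supports them.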
m_deriveClassificationBlk = simdDeriveClassificationBlk_HBD;
#ifdef USE_AVX2
if (vext >= AVX2)
{
m_deriveClassificationBlk = simdDeriveClassificationBlk_HBD_AVX2;
m_filter5x5Blk = simdFilter5x5Blk_HBD_AVX2;
m_filter7x7Blk = simdFilter7x7Blk_HBD_AVX2;
}
else
#endif
{
m_deriveClassificationBlk = simdDeriveClassificationBlk_HBD;
m_filter5x5Blk = simdFilter5x5Blk_HBD;
m_filter7x7Blk = simdFilter7x7Blk_HBD;
}
......