From 402e4b5ac7f5b77028807f288ea441acffdb35e4 Mon Sep 17 00:00:00 2001
From: Xiaoyu Xiu <xiaoyuxiu@kwai.com>
Date: Mon, 21 Oct 2019 23:17:47 +0200
Subject: [PATCH] JVET-P0512: SIMD support for MC at high internal bit-depth

---
 source/Lib/CommonLib/TypeDef.h                |   2 +
 .../CommonLib/x86/InterpolationFilterX86.h    | 119 ++++++++++++++++++
 2 files changed, 121 insertions(+)

diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 9ebce3ec6..0488ef90f 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,6 +50,8 @@
 #include <assert.h>
 #include <cassert>
 
+#define JVET_P0512_SIMD_HIGH_BITDEPTH                     1 // JVET-P0512: MC SIMD support for high internal bit-depthf
+
 #define JVET_P0491_BDOFPROF_MVD_RANGE                     1 // JVET-P0491: clip the MVD in BDOF/PROF to [-31 31]
 
 #define JVET_P0460_PLT_TS_MIN_QP                          1 // JVET-P0460: Use TS min QP for Palette Escape mode
diff --git a/source/Lib/CommonLib/x86/InterpolationFilterX86.h b/source/Lib/CommonLib/x86/InterpolationFilterX86.h
index 6d94bd153..02a84cf6a 100644
--- a/source/Lib/CommonLib/x86/InterpolationFilterX86.h
+++ b/source/Lib/CommonLib/x86/InterpolationFilterX86.h
@@ -1008,6 +1008,112 @@ static inline __m128i simdInterpolateLuma10Bit2P4(int16_t const *src, int srcStr
   return sumLo;
 }
 
+#if JVET_P0512_SIMD_HIGH_BITDEPTH
+#ifdef USE_AVX2
+static inline __m256i simdInterpolateLumaHighBit2P16(int16_t const *src1, int srcStride, __m256i *mmCoeff, const __m256i & mmOffset, __m128i &mmShift)
+{
+  __m256i mm_mul_lo = _mm256_setzero_si256();
+  __m256i mm_mul_hi = _mm256_setzero_si256();
+
+  for (int coefIdx = 0; coefIdx < 2; coefIdx++)
+  {
+    __m256i mmPix = _mm256_lddqu_si256((__m256i*)(src1 + coefIdx * srcStride));
+    __m256i mm_hi = _mm256_mulhi_epi16(mmPix, mmCoeff[coefIdx]);
+    __m256i mm_lo = _mm256_mullo_epi16(mmPix, mmCoeff[coefIdx]);
+    mm_mul_lo = _mm256_add_epi32(mm_mul_lo, _mm256_unpacklo_epi16(mm_lo, mm_hi));
+    mm_mul_hi = _mm256_add_epi32(mm_mul_hi, _mm256_unpackhi_epi16(mm_lo, mm_hi));
+  }
+  mm_mul_lo = _mm256_sra_epi32(_mm256_add_epi32(mm_mul_lo, mmOffset), mmShift);
+  mm_mul_hi = _mm256_sra_epi32(_mm256_add_epi32(mm_mul_hi, mmOffset), mmShift);
+  __m256i mm_sum = _mm256_packs_epi32(mm_mul_lo, mm_mul_hi);
+  return (mm_sum);
+}
+#endif
+
+static inline __m128i simdInterpolateLumaHighBit2P8(int16_t const *src1, int srcStride, __m128i *mmCoeff, const __m128i & mmOffset, __m128i &mmShift)
+{
+  __m128i mm_mul_lo = _mm_setzero_si128();
+  __m128i mm_mul_hi = _mm_setzero_si128();
+
+  for (int coefIdx = 0; coefIdx < 2; coefIdx++)
+  {
+    __m128i mmPix = _mm_loadu_si128((__m128i*)(src1 + coefIdx * srcStride));
+    __m128i mm_hi = _mm_mulhi_epi16(mmPix, mmCoeff[coefIdx]);
+    __m128i mm_lo = _mm_mullo_epi16(mmPix, mmCoeff[coefIdx]);
+    mm_mul_lo = _mm_add_epi32(mm_mul_lo, _mm_unpacklo_epi16(mm_lo, mm_hi));
+    mm_mul_hi = _mm_add_epi32(mm_mul_hi, _mm_unpackhi_epi16(mm_lo, mm_hi));
+  }
+  mm_mul_lo = _mm_sra_epi32(_mm_add_epi32(mm_mul_lo, mmOffset), mmShift);
+  mm_mul_hi = _mm_sra_epi32(_mm_add_epi32(mm_mul_hi, mmOffset), mmShift);
+  __m128i mm_sum = _mm_packs_epi32(mm_mul_lo, mm_mul_hi);
+  return(mm_sum);
+}
+
+static inline __m128i simdInterpolateLumaHighBit2P4(int16_t const *src1, int srcStride, __m128i *mmCoeff, const __m128i & mmOffset, __m128i &mmShift)
+{
+  __m128i mm_sum = _mm_setzero_si128();
+  __m128i mm_zero = _mm_setzero_si128();
+  for (int coefIdx = 0; coefIdx < 2; coefIdx++)
+  {
+    __m128i mmPix = _mm_loadl_epi64((__m128i*)(src1 + coefIdx * srcStride));
+    __m128i mm_hi = _mm_mulhi_epi16(mmPix, mmCoeff[coefIdx]);
+    __m128i mm_lo = _mm_mullo_epi16(mmPix, mmCoeff[coefIdx]);
+    __m128i mm_mul = _mm_unpacklo_epi16(mm_lo, mm_hi);
+    mm_sum = _mm_add_epi32(mm_sum, mm_mul);
+  }
+  mm_sum = _mm_sra_epi32(_mm_add_epi32(mm_sum, mmOffset), mmShift);
+  mm_sum = _mm_packs_epi32(mm_sum, mm_zero);
+  return(mm_sum);
+}
+
+template<X86_VEXT vext, bool isLast>
+static void simdInterpolateN2_HIGHBIT_M4(const int16_t* src, int srcStride, int16_t *dst, int dstStride, int cStride, int width, int height, int shift, int offset, const ClpRng& clpRng, int16_t const *c)
+{
+#if USE_AVX2
+  __m256i mm256Offset = _mm256_set1_epi32(offset);
+  __m256i mm256Coeff[2];
+  for (int n = 0; n < 2; n++)
+  {
+    mm256Coeff[n] = _mm256_set1_epi16(c[n]);
+  }
+#endif
+  __m128i mmOffset = _mm_set1_epi32(offset);
+  __m128i mmCoeff[2];
+  for (int n = 0; n < 2; n++)
+    mmCoeff[n] = _mm_set1_epi16(c[n]);
+
+  __m128i mmShift = _mm_cvtsi64_si128(shift);
+
+  CHECK(isLast, "Not Supported");
+  CHECK(width % 4 != 0, "Not Supported");
+
+  for (int row = 0; row < height; row++)
+  {
+    int col = 0;
+#if USE_AVX2
+    for (; col < ((width >> 4) << 4); col += 16)
+    {
+      __m256i mmFiltered = simdInterpolateLumaHighBit2P16(src + col, cStride, mm256Coeff, mm256Offset, mmShift);
+      _mm256_storeu_si256((__m256i *)(dst + col), mmFiltered);
+    }
+#endif
+    for (; col < ((width >> 3) << 3); col += 8)
+    {
+      __m128i mmFiltered = simdInterpolateLumaHighBit2P8(src + col, cStride, mmCoeff, mmOffset, mmShift);
+      _mm_storeu_si128((__m128i *)(dst + col), mmFiltered);
+    }
+
+    for (; col < ((width >> 2) << 2); col += 4)
+    {
+      __m128i mmFiltered = simdInterpolateLumaHighBit2P4(src + col, cStride, mmCoeff, mmOffset, mmShift);
+      _mm_storel_epi64((__m128i *)(dst + col), mmFiltered);
+    }
+    src += srcStride;
+    dst += dstStride;
+  }
+}
+#endif
+
 template<X86_VEXT vext, bool isLast>
 static void simdInterpolateN2_10BIT_M4(const int16_t* src, int srcStride, int16_t *dst, int dstStride, int cStride, int width, int height, int shift, int offset, const ClpRng& clpRng, int16_t const *c)
 {
@@ -1112,7 +1218,9 @@ static void simdFilter( const ClpRng& clpRng, Pel const *src, int srcStride, Pel
       offset = 1 << (shift - 1);
     }
   }
+#if !JVET_P0512_SIMD_HIGH_BITDEPTH
   if( clpRng.bd <= 10 )
+#endif
   {
     if( N == 8 && !( width & 0x07 ) )
     {
@@ -1164,7 +1272,18 @@ static void simdFilter( const ClpRng& clpRng, Pel const *src, int srcStride, Pel
     {
       if (N == 2 && !(width & 0x03))
       {
+#if JVET_P0512_SIMD_HIGH_BITDEPTH
+        if (clpRng.bd <= 10)
+        {
+#endif
         simdInterpolateN2_10BIT_M4<vext, isLast>(src, srcStride, dst, dstStride, cStride, width, height, shift, offset, clpRng, c);
+#if JVET_P0512_SIMD_HIGH_BITDEPTH
+        }
+        else
+        {
+          simdInterpolateN2_HIGHBIT_M4<vext, isLast>(src, srcStride, dst, dstStride, cStride, width, height, shift, offset, clpRng, c);
+        }
+#endif
         return;
       }
     }
-- 
GitLab