From 18411bc05f50b5ecb281ad38ec522c600f827329 Mon Sep 17 00:00:00 2001
From: Frank Bossen <fbossen@gmail.com>
Date: Wed, 6 Jul 2022 12:34:21 -0400
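Subject: [PATCH] Clean up BCW target block computation for bipred

- Compute the weighted target as src + w * (dst - src) using a single
  inverse weight with BCW_INV_BITS fractional bits.
- Always clamp the weighted result: to the nominal range when requested,
  otherwise to [5 * min - 4 * max, 5 * max - 4 * min].
- Remove the duplicated C implementations from Buffer.cpp; the function
  pointers now default to nullptr and the code in Buffer.h is used
  whenever a pointer is null.
- Replace g_bcwLog2WeightBase / g_bcwWeightBase by the constants
  BCW_LOG2_WEIGHT_BASE / BCW_WEIGHT_BASE.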

---
 source/Lib/CommonLib/Buffer.cpp       |  91 +-----------
 source/Lib/CommonLib/Buffer.h         | 118 ++++++----------
 source/Lib/CommonLib/CommonDef.h      |   9 +-
 source/Lib/CommonLib/Rom.cpp          |   5 +-
 source/Lib/CommonLib/Rom.h            |   2 -
 source/Lib/CommonLib/x86/BufferX86.h  | 192 ++++++++++++--------------
 source/Lib/EncoderLib/InterSearch.cpp |   2 +-
 7 files changed, 148 insertions(+), 271 deletions(-)

diff --git a/source/Lib/CommonLib/Buffer.cpp b/source/Lib/CommonLib/Buffer.cpp
index f8bf39bc9..50fa1889f 100644
--- a/source/Lib/CommonLib/Buffer.cpp
+++ b/source/Lib/CommonLib/Buffer.cpp
@@ -226,80 +226,6 @@ void calcBlkGradientCore(int sx, int sy, int     *arraysGx2, int     *arraysGxGy
   }
 }
 
-#if ENABLE_SIMD_OPT_BCW
-void removeWeightHighFreq(int16_t *dst, ptrdiff_t dstStride, const int16_t *src, ptrdiff_t srcStride, int width,
-                          int height, int shift, int bcwWeight)
-{
-  const int normalizer = ((1 << 16) + (bcwWeight > 0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight;
-
-  const int weight0 = normalizer << g_bcwLog2WeightBase;
-  const int weight1 = (g_bcwWeightBase - bcwWeight) * normalizer;
-
-#define REM_HF_INC  \
-  src += srcStride; \
-  dst += dstStride; \
-
-#define REM_HF_OP( ADDR )      dst[ADDR] =             (dst[ADDR]*weight0 - src[ADDR]*weight1 + (1<<15))>>16
-
-  SIZE_AWARE_PER_EL_OP(REM_HF_OP, REM_HF_INC);
-
-#undef REM_HF_INC
-#undef REM_HF_OP
-#undef REM_HF_OP_CLIP
-}
-
-void removeHighFreq(int16_t *dst, ptrdiff_t dstStride, const int16_t *src, ptrdiff_t srcStride, int width, int height)
-{
-#define REM_HF_INC  \
-  src += srcStride; \
-  dst += dstStride; \
-
-#define REM_HF_OP( ADDR )      dst[ADDR] =             2 * dst[ADDR] - src[ADDR]
-
-  SIZE_AWARE_PER_EL_OP(REM_HF_OP, REM_HF_INC);
-
-#undef REM_HF_INC
-#undef REM_HF_OP
-#undef REM_HF_OP_CLIP
-}
-#if RExt__HIGH_BIT_DEPTH_SUPPORT
-void removeWeightHighFreq_HBD(Pel *dst, ptrdiff_t dstStride, const Pel *src, ptrdiff_t srcStride, int width, int height,
-                              int shift, int bcwWeight)
-{
-  Intermediate_Int normalizer = ((1 << 16) + (bcwWeight > 0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight;
-
-  Intermediate_Int weight0 = normalizer << g_bcwLog2WeightBase;
-  Intermediate_Int weight1 = (g_bcwWeightBase - bcwWeight) * normalizer;
-#define REM_HF_INC  \
-  src += srcStride; \
-  dst += dstStride; \
-
-#define REM_HF_OP( ADDR )      dst[ADDR] =             (Pel)((dst[ADDR]*weight0 - src[ADDR]*weight1 + (1<<15))>>16)
-
-  SIZE_AWARE_PER_EL_OP(REM_HF_OP, REM_HF_INC);
-
-#undef REM_HF_INC
-#undef REM_HF_OP
-#undef REM_HF_OP_CLIP
-}
-
-void removeHighFreq_HBD(Pel *dst, ptrdiff_t dstStride, const Pel *src, ptrdiff_t srcStride, int width, int height)
-{
-#define REM_HF_INC  \
-  src += srcStride; \
-  dst += dstStride; \
-
-#define REM_HF_OP( ADDR )      dst[ADDR] =             2 * dst[ADDR] - src[ADDR]
-
-  SIZE_AWARE_PER_EL_OP(REM_HF_OP, REM_HF_INC);
-
-#undef REM_HF_INC
-#undef REM_HF_OP
-#undef REM_HF_OP_CLIP
-}
-#endif
-#endif
-
 template<typename T>
 void reconstructCore(const T *src1, ptrdiff_t src1Stride, const T *src2, ptrdiff_t src2Stride, T *dest,
                      ptrdiff_t dstStride, int width, int height, const ClpRng &clpRng)
@@ -349,17 +275,10 @@ PelBufferOps::PelBufferOps()
   copyBuffer = copyBufferCore;
   padding = paddingCore;
 #if ENABLE_SIMD_OPT_BCW
-#if RExt__HIGH_BIT_DEPTH_SUPPORT
-  removeWeightHighFreq8 = removeWeightHighFreq_HBD;
-  removeWeightHighFreq4 = removeWeightHighFreq_HBD;
-  removeHighFreq8 = removeHighFreq_HBD;
-  removeHighFreq4 = removeHighFreq_HBD;
-#else
-  removeWeightHighFreq8 = removeWeightHighFreq;
-  removeWeightHighFreq4 = removeWeightHighFreq;
-  removeHighFreq8 = removeHighFreq;
-  removeHighFreq4 = removeHighFreq;
-#endif
+  removeWeightHighFreq8 = nullptr;   // null: AreaBuf<T>::removeWeightHighFreq uses its generic code path
+  removeWeightHighFreq4 = nullptr;
+  removeHighFreq8       = nullptr;   // null: AreaBuf<T>::removeHighFreq uses its generic code path
+  removeHighFreq4       = nullptr;
 #endif
 
   profGradFilter = gradFilterCore <false>;
@@ -407,7 +326,7 @@ void AreaBuf<Pel>::addWeightedAvg(const AreaBuf<const Pel> &other1, const AreaBu
 {
   const int8_t w0 = getBcwWeight(bcwIdx, REF_PIC_LIST_0);
   const int8_t w1 = getBcwWeight(bcwIdx, REF_PIC_LIST_1);
-  const int8_t log2WeightBase = g_bcwLog2WeightBase;
+  const int8_t log2WeightBase = BCW_LOG2_WEIGHT_BASE;
 
   const Pel* src0 = other1.buf;
   const Pel* src2 = other2.buf;
diff --git a/source/Lib/CommonLib/Buffer.h b/source/Lib/CommonLib/Buffer.h
index 646b130a0..dd3547073 100644
--- a/source/Lib/CommonLib/Buffer.h
+++ b/source/Lib/CommonLib/Buffer.h
@@ -93,9 +93,9 @@ struct PelBufferOps
   void (*padding)(Pel *dst, ptrdiff_t stride, int width, int height, int padSize);
 #if ENABLE_SIMD_OPT_BCW
   void (*removeWeightHighFreq8)(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, ptrdiff_t src1Stride, int width,
-                                int height, int shift, int bcwWeight);
+                                int height, int bcwWeight, const Pel minVal, const Pel maxVal);
   void (*removeWeightHighFreq4)(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, ptrdiff_t src1Stride, int width,
-                                int height, int shift, int bcwWeight);
+                                int height, int bcwWeight, const Pel minVal, const Pel maxVal);
   void (*removeHighFreq8)(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, ptrdiff_t src1Stride, int width,
                           int height);
   void (*removeHighFreq4)(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, ptrdiff_t src1Stride, int width,
@@ -409,115 +409,85 @@ template<>
 void AreaBuf<Pel>::toLast( const ClpRng& clpRng );
 
 template<typename T>
-void AreaBuf<T>::removeWeightHighFreq(const AreaBuf<T>& other, const bool bClip, const ClpRng& clpRng, const int8_t bcwWeight)
+void AreaBuf<T>::removeWeightHighFreq(const AreaBuf<T> &other, const bool clampToNominalRange, const ClpRng &clpRng,
+                                      const int8_t bcwWeight)
 {
-  const int8_t bcwWeightOther = g_bcwWeightBase - bcwWeight;
-  const int8_t log2WeightBase = g_bcwLog2WeightBase;
+  const Pel *src        = other.buf;
+  const ptrdiff_t  srcStride = other.stride;
 
-  const Pel* src = other.buf;
-  const ptrdiff_t srcStride = other.stride;
-
-  Pel* dst = buf;
+  Pel *     dst        = buf;
   const ptrdiff_t dstStride = stride;
 
+  const Pel minVal = clampToNominalRange ? clpRng.min : 5 * clpRng.min - 4 * clpRng.max;
+  const Pel maxVal = clampToNominalRange ? clpRng.max : 5 * clpRng.max - 4 * clpRng.min;
+
 #if ENABLE_SIMD_OPT_BCW
-  if(!bClip)
+  if ((width & 7) == 0 && g_pelBufOP.removeWeightHighFreq8)
   {
-    if(!(width & 7))
-    {
-      g_pelBufOP.removeWeightHighFreq8(dst, dstStride, src, srcStride, width, height, 16, bcwWeight);
-    }
-    else if(!(width & 3))
-    {
-      g_pelBufOP.removeWeightHighFreq4(dst, dstStride, src, srcStride, width, height, 16, bcwWeight);
-    }
-    else
-    {
-      THROW("Not supported");
-    }
+    g_pelBufOP.removeWeightHighFreq8(dst, dstStride, src, srcStride, width, height, bcwWeight, minVal, maxVal);
   }
-  else
+  else if (width == 4 && g_pelBufOP.removeWeightHighFreq4)   // the SSE W == 4 kernel handles 4 samples per row only
   {
+    g_pelBufOP.removeWeightHighFreq4(dst, dstStride, src, srcStride, width, height, bcwWeight, minVal, maxVal);
+  }
+  else
 #endif
-    Intermediate_Int normalizer = ((1 << 16) + (bcwWeight > 0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight;
-    Intermediate_Int weight0 = normalizer << log2WeightBase;
-    Intermediate_Int weight1 = bcwWeightOther * normalizer;
-#define REM_HF_INC  \
-  src += srcStride; \
-  dst += dstStride; \
-
-#define REM_HF_OP_CLIP( ADDR ) dst[ADDR] = ClipPel<T>( T((dst[ADDR]*weight0 - src[ADDR]*weight1 + (1<<15))>>16), clpRng )
-#define REM_HF_OP( ADDR )      dst[ADDR] =             T((dst[ADDR]*weight0 - src[ADDR]*weight1 + (1<<15))>>16)
+  {
+    const Intermediate_Int w =   // type used before this patch; NOTE(review): assumed 64-bit when Pel is 32-bit - confirm
+      ((BCW_WEIGHT_BASE << BCW_INV_BITS) + (bcwWeight > 0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight;
 
-    if(bClip)
-    {
-      SIZE_AWARE_PER_EL_OP(REM_HF_OP_CLIP, REM_HF_INC);
-    }
-    else
-    {
-      SIZE_AWARE_PER_EL_OP(REM_HF_OP, REM_HF_INC);
-    }
+#define REM_HF_INC                                                                                                     \
+  src += srcStride;                                                                                                    \
+  dst += dstStride;
 
-#undef REM_HF_INC
-#undef REM_HF_OP
+#define REM_HF_OP_CLIP(ADDR)                                                                                           \
+  dst[ADDR] =                                                                                                          \
+    Clip3<T>(minVal, maxVal, (((dst[ADDR] - src[ADDR]) * w + (1 << BCW_INV_BITS >> 1)) >> BCW_INV_BITS) + src[ADDR])
+    SIZE_AWARE_PER_EL_OP(REM_HF_OP_CLIP, REM_HF_INC);
 #undef REM_HF_OP_CLIP
-#if ENABLE_SIMD_OPT_BCW
+#undef REM_HF_INC
   }
-#endif
 }
 
 template<typename T>
-void AreaBuf<T>::removeHighFreq( const AreaBuf<T>& other, const bool bClip, const ClpRng& clpRng )
+void AreaBuf<T>::removeHighFreq(const AreaBuf<T> &other, const bool clampToNominalRange, const ClpRng &clpRng)
 {
-  const T*  src       = other.buf;
+  const T * src        = other.buf;
   const ptrdiff_t srcStride = other.stride;
 
-  T              *dst       = buf;
+  T *       dst        = buf;
   const ptrdiff_t dstStride = stride;
 
-#if ENABLE_SIMD_OPT_BCW
-  if (!bClip)
+#define REM_HF_INC                                                                                                     \
+  src += srcStride;                                                                                                    \
+  dst += dstStride;
+
+  if (!clampToNominalRange)
   {
-    if(!(width & 7))
+#if ENABLE_SIMD_OPT_BCW
+    if (!(width & 7) && g_pelBufOP.removeHighFreq8)
     {
       g_pelBufOP.removeHighFreq8(dst, dstStride, src, srcStride, width, height);
+      return;
     }
-    else if (!(width & 3))
+    else if (!(width & 3) && g_pelBufOP.removeHighFreq4)
     {
       g_pelBufOP.removeHighFreq4(dst, dstStride, src, srcStride, width, height);
+      return;
     }
-    else
-    {
-      THROW("Not supported");
-    }
-  }
-  else
-  {
 #endif
-
-#define REM_HF_INC  \
-  src += srcStride; \
-  dst += dstStride; \
-
-#define REM_HF_OP_CLIP( ADDR ) dst[ADDR] = ClipPel<T>( 2 * dst[ADDR] - src[ADDR], clpRng )
-#define REM_HF_OP( ADDR )      dst[ADDR] =             2 * dst[ADDR] - src[ADDR]
-
-  if( bClip )
-  {
-    SIZE_AWARE_PER_EL_OP( REM_HF_OP_CLIP, REM_HF_INC );
+#define REM_HF_OP(ADDR) dst[ADDR] = 2 * dst[ADDR] - src[ADDR]
+    SIZE_AWARE_PER_EL_OP(REM_HF_OP, REM_HF_INC);
   }
   else
   {
-    SIZE_AWARE_PER_EL_OP( REM_HF_OP,      REM_HF_INC );
+#define REM_HF_OP_CLIP(ADDR) dst[ADDR] = ClipPel<T>(2 * dst[ADDR] - src[ADDR], clpRng)
+    SIZE_AWARE_PER_EL_OP(REM_HF_OP_CLIP, REM_HF_INC);
   }
 
 #undef REM_HF_INC
 #undef REM_HF_OP
 #undef REM_HF_OP_CLIP
-
-#if ENABLE_SIMD_OPT_BCW
-  }
-#endif
 }
 
 
diff --git a/source/Lib/CommonLib/CommonDef.h b/source/Lib/CommonLib/CommonDef.h
index b33afef68..31bd60473 100644
--- a/source/Lib/CommonLib/CommonDef.h
+++ b/source/Lib/CommonLib/CommonDef.h
@@ -384,9 +384,12 @@ static constexpr int BIO_TEMP_BUFFER_SIZE         =                     (MAX_CU_
 static constexpr int PROF_BORDER_EXT_W            =                     1;
 static constexpr int PROF_BORDER_EXT_H            =                     1;
 
-static constexpr int    BCW_NUM             = 5;             // the number of weight options
-static constexpr int    BCW_DEFAULT         = BCW_NUM / 2;   // Default weighting index representing for w=0.5
-static constexpr int    BCW_SIZE_CONSTRAINT = 256;           // disabling BCW if cu size is smaller than 256
+static constexpr int BCW_LOG2_WEIGHT_BASE = 3;             // log2 of BCW_WEIGHT_BASE
+static constexpr int BCW_WEIGHT_BASE      = 1 << BCW_LOG2_WEIGHT_BASE;   // the list 0 and list 1 weights sum to this value
+static constexpr int BCW_NUM              = 5;             // the number of weight options
+static constexpr int BCW_DEFAULT          = BCW_NUM / 2;   // default weighting index, representing w = 0.5
+static constexpr int BCW_SIZE_CONSTRAINT  = 256;           // disabling BCW if CU size is smaller than 256
+static constexpr int BCW_INV_BITS         = 16;            // fractional bits of the inverse weight in removeWeightHighFreq
 static constexpr double BCW_COST_TH         = 1.05;
 
 static constexpr double AMVR_FAST_4PEL_TH = 1.06;
diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp
index 01ee8947a..31b6d1a59 100644
--- a/source/Lib/CommonLib/Rom.cpp
+++ b/source/Lib/CommonLib/Rom.cpp
@@ -201,8 +201,7 @@ public:
     return rtn;
   }
 };
-const int8_t g_bcwLog2WeightBase       = 3;
-const int8_t g_bcwWeightBase           = (1 << g_bcwLog2WeightBase);
+
 const int8_t g_BcwWeights[BCW_NUM] = { -2, 3, 4, 5, 10 };
 const int8_t g_BcwSearchOrder[BCW_NUM] = { BCW_DEFAULT, BCW_DEFAULT - 2, BCW_DEFAULT + 2, BCW_DEFAULT - 1, BCW_DEFAULT + 1 };
 int8_t g_BcwCodingOrder[BCW_NUM];
@@ -212,7 +211,7 @@ int8_t getBcwWeight(uint8_t bcwIdx, uint8_t refFrameList)
 {
   // Weights for the model: p0 + w * (p1 - p0) = (1-w) * p0 + w * p1
   // Retuning  1-w for p0 or w for p1
-  return (refFrameList == REF_PIC_LIST_0 ? g_bcwWeightBase - g_BcwWeights[bcwIdx] : g_BcwWeights[bcwIdx]);
+  return (refFrameList == REF_PIC_LIST_0 ? BCW_WEIGHT_BASE - g_BcwWeights[bcwIdx] : g_BcwWeights[bcwIdx]);
 }
 
 void resetBcwCodingOrder(bool runDecoding, const CodingStructure &cs)
diff --git a/source/Lib/CommonLib/Rom.h b/source/Lib/CommonLib/Rom.h
index a00e94ac8..eb8968913 100644
--- a/source/Lib/CommonLib/Rom.h
+++ b/source/Lib/CommonLib/Rom.h
@@ -185,8 +185,6 @@ extern const uint32_t g_scalingListId[SCALING_LIST_SIZE_NUM][SCALING_LIST_NUM];
 
 extern MsgLevel g_verbosity;
 
-extern const int8_t g_bcwLog2WeightBase;
-extern const int8_t g_bcwWeightBase;
 extern const int8_t g_BcwWeights[BCW_NUM];
 extern const int8_t g_BcwSearchOrder[BCW_NUM];
 extern       int8_t g_BcwCodingOrder[BCW_NUM];
diff --git a/source/Lib/CommonLib/x86/BufferX86.h b/source/Lib/CommonLib/x86/BufferX86.h
index 92556c3a1..5c7f11057 100644
--- a/source/Lib/CommonLib/x86/BufferX86.h
+++ b/source/Lib/CommonLib/x86/BufferX86.h
@@ -1204,46 +1204,41 @@ void reco_SSE(const int16_t *src0, ptrdiff_t src0Stride, const int16_t *src1, pt
 
 #if ENABLE_SIMD_OPT_BCW
 template<X86_VEXT vext, int W>
-void removeWeightHighFreq_SSE(int16_t *src0, ptrdiff_t src0Stride, const int16_t *src1, ptrdiff_t src1Stride, int width,
-                              int height, int shift, int bcwWeight)
+void removeWeightHighFreq_SSE(int16_t *src0, ptrdiff_t src0Stride, const int16_t *src1, ptrdiff_t src1Stride, int width, int height,
+                              int bcwWeight, const Pel minVal, const Pel maxVal)
 {
-  int normalizer = ((1 << 16) + (bcwWeight>0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight;
-  int weight0    = normalizer * (1 << g_bcwLog2WeightBase);
-  int weight1    = (g_bcwWeightBase - bcwWeight) * normalizer;
-  int offset = 1 << (shift - 1);
+  static_assert(W == 4 || W == 8, "W must be 4 or 8");
+
+  const int32_t w =
+    ((BCW_WEIGHT_BASE << BCW_INV_BITS) + (bcwWeight > 0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight;
+
   if (W == 8)
   {
-    __m128i vzero = _mm_setzero_si128();
-    __m128i voffset = _mm_set1_epi32(offset);
-    __m128i vw0 = _mm_set1_epi32(weight0);
-    __m128i vw1 = _mm_set1_epi32(weight1);
-
     for (int row = 0; row < height; row++)
     {
       for (int col = 0; col < width; col += 8)
       {
-        __m128i vsrc0 = _mm_loadu_si128( (const __m128i *)&src0[col] );
-        __m128i vsrc1 = _mm_loadu_si128( (const __m128i *)&src1[col] );
-
-        __m128i vtmp, vdst, vsrc;
-        vdst = _mm_cvtepi16_epi32(vsrc0);
-        vsrc = _mm_cvtepi16_epi32(vsrc1);
-        vdst = _mm_mullo_epi32(vdst, vw0);
-        vsrc = _mm_mullo_epi32(vsrc, vw1);
-        vtmp = _mm_add_epi32(_mm_sub_epi32(vdst, vsrc), voffset);
-        vtmp = _mm_srai_epi32(vtmp, shift);
-
-        vsrc0 = _mm_unpackhi_epi64(vsrc0, vzero);
-        vsrc1 = _mm_unpackhi_epi64(vsrc1, vzero);
-        vdst = _mm_cvtepi16_epi32(vsrc0);
-        vsrc = _mm_cvtepi16_epi32(vsrc1);
-        vdst = _mm_mullo_epi32(vdst, vw0);
-        vsrc = _mm_mullo_epi32(vsrc, vw1);
-        vdst = _mm_add_epi32(_mm_sub_epi32(vdst, vsrc), voffset);
-        vdst = _mm_srai_epi32(vdst, shift);
-        vdst = _mm_packs_epi32(vtmp, vdst);
-
-        _mm_store_si128((__m128i *)&src0[col], vdst);
+        const __m128i vsrc0 = _mm_loadu_si128((const __m128i *) &src0[col]);
+        const __m128i vsrc1 = _mm_loadu_si128((const __m128i *) &src1[col]);
+
+        const __m128i diff = _mm_sub_epi16(vsrc0, vsrc1);
+
+        __m128i lo = _mm_cvtepi16_epi32(diff);
+        lo         = _mm_mullo_epi32(lo, _mm_set1_epi32(w));
+        lo         = _mm_add_epi32(lo, _mm_set1_epi32(1 << BCW_INV_BITS >> 1));
+        lo         = _mm_srai_epi32(lo, BCW_INV_BITS);
+
+        __m128i hi = _mm_cvtepi16_epi32(_mm_unpackhi_epi64(diff, diff));
+        hi         = _mm_mullo_epi32(hi, _mm_set1_epi32(w));
+        hi         = _mm_add_epi32(hi, _mm_set1_epi32(1 << BCW_INV_BITS >> 1));
+        hi         = _mm_srai_epi32(hi, BCW_INV_BITS);
+
+        __m128i res = _mm_packs_epi32(lo, hi);
+        res         = _mm_add_epi16(res, vsrc1);
+        res         = _mm_max_epi16(res, _mm_set1_epi16(minVal));
+        res         = _mm_min_epi16(res, _mm_set1_epi16(maxVal));
+
+        _mm_storeu_si128((__m128i *) &src0[col], res);   // unaligned store, like the load from the same address
       }
       src0 += src0Stride;
       src1 += src1Stride;
@@ -1251,34 +1246,29 @@ void removeWeightHighFreq_SSE(int16_t *src0, ptrdiff_t src0Stride, const int16_t
   }
   else if (W == 4)
   {
-    __m128i vzero = _mm_setzero_si128();
-    __m128i voffset = _mm_set1_epi32(offset);
-    __m128i vw0 = _mm_set1_epi32(weight0);
-    __m128i vw1 = _mm_set1_epi32(weight1);
-
     for (int row = 0; row < height; row++)
     {
-      __m128i vsum = _mm_loadl_epi64((const __m128i *)src0);
-      __m128i vdst = _mm_loadl_epi64((const __m128i *)src1);
-
-      vsum = _mm_cvtepi16_epi32(vsum);
-      vdst = _mm_cvtepi16_epi32(vdst);
-      vsum = _mm_mullo_epi32(vsum, vw0);
-      vdst = _mm_mullo_epi32(vdst, vw1);
-      vsum = _mm_add_epi32(_mm_sub_epi32(vsum, vdst), voffset);
-      vsum = _mm_srai_epi32(vsum, shift);
-      vsum = _mm_packs_epi32(vsum, vzero);
+      const __m128i vsrc0 = _mm_loadl_epi64((const __m128i *) src0);
+      const __m128i vsrc1 = _mm_loadl_epi64((const __m128i *) src1);
 
-      _mm_storel_epi64((__m128i *)src0, vsum);
+      const __m128i diff = _mm_sub_epi16(vsrc0, vsrc1);
+
+      __m128i lo = _mm_cvtepi16_epi32(diff);
+      lo         = _mm_mullo_epi32(lo, _mm_set1_epi32(w));
+      lo         = _mm_add_epi32(lo, _mm_set1_epi32(1 << BCW_INV_BITS >> 1));
+      lo         = _mm_srai_epi32(lo, BCW_INV_BITS);
+
+      __m128i res = _mm_packs_epi32(lo, lo);
+      res         = _mm_add_epi16(res, vsrc1);
+      res         = _mm_max_epi16(res, _mm_set1_epi16(minVal));
+      res         = _mm_min_epi16(res, _mm_set1_epi16(maxVal));
+
+      _mm_storel_epi64((__m128i *) src0, res);
 
       src0 += src0Stride;
       src1 += src1Stride;
     }
   }
-  else
-  {
-    THROW("Unsupported size");
-  }
 }
 
 template<X86_VEXT vext, int W>
@@ -1307,6 +1297,8 @@ void removeHighFreq_SSE(int16_t *src0, ptrdiff_t src0Stride, const int16_t *src1
   }
   else if (W == 4)
   {
+    CHECK(width != 4, "width must be 4");
+
     for (int row = 0; row < height; row += 2)
     {
       __m128i vsrc0 = _mm_loadl_epi64((const __m128i *)src0);
@@ -1321,8 +1313,8 @@ void removeHighFreq_SSE(int16_t *src0, ptrdiff_t src0Stride, const int16_t *src1
       _mm_storel_epi64((__m128i *)src0, vsrc0);
       _mm_storel_epi64((__m128i *)(src0 + src0Stride), _mm_unpackhi_epi64(vsrc0, vsrc0));
 
-      src0 += (src0Stride << 1);
-      src1 += (src1Stride << 1);
+      src0 += 2 * src0Stride;
+      src1 += 2 * src1Stride;
     }
   }
   else
@@ -1546,16 +1538,18 @@ template<X86_VEXT vext, int W>
 void removeHighFreq_HBD_SIMD(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, ptrdiff_t src1Stride, int width,
                              int height)
 {
-  CHECK((width & 3), "the function only supports width multiple of 4");
+  CHECK((width & 3), "width must be a multiple of 4");
+
   for (int row = 0; row < height; row++)
   {
     int col = 0;
 #ifdef USE_AVX2
     if (vext >= AVX2)
     {
-      __m256i mm256_vsrc0, mm256_vsrc1;
-      for (; col < ((width >> 3) << 3); col += 8)
+      for (; col < (width & ~7); col += 8)
       {
+        __m256i mm256_vsrc0, mm256_vsrc1;
+
         mm256_vsrc0 = _mm256_lddqu_si256((const __m256i *)&src0[col]);
         mm256_vsrc1 = _mm256_lddqu_si256((const __m256i *)&src1[col]);
 
@@ -1564,9 +1558,10 @@ void removeHighFreq_HBD_SIMD(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, p
       }
     }
 #endif
-    __m128i vsrc0, vsrc1;
     for (; col < width; col += 4)
     {
+      __m128i vsrc0, vsrc1;
+
       vsrc0 = _mm_lddqu_si128((const __m128i *)&src0[col]);
       vsrc1 = _mm_lddqu_si128((const __m128i *)&src1[col]);
 
@@ -1579,41 +1574,40 @@ void removeHighFreq_HBD_SIMD(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, p
 }
 
 template<X86_VEXT vext, int W>
-void removeWeightHighFreq_HBD_SIMD(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, ptrdiff_t src1Stride, int width,
-                                   int height, int shift, int bcwWeight)
+void removeWeightHighFreq_HBD_SIMD(Pel *src0, ptrdiff_t src0Stride, const Pel *src1, ptrdiff_t src1Stride, int width, int height,
+                                   int bcwWeight, const Pel minVal, const Pel maxVal)
 {
   CHECK((width & 3), "the function only supports width multiple of 4");
 
-  int normalizer = ((1 << 16) + (bcwWeight > 0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight;
-  int              weight0    = normalizer << g_bcwLog2WeightBase;
-  int              weight1    = (g_bcwWeightBase - bcwWeight) * normalizer;
-  Intermediate_Int offset = Intermediate_Int(1) << (shift - 1);
+  constexpr int s1 = (32 - BCW_INV_BITS) / 2;   // extra scaling applied to the weight
+  constexpr int s2 = 32 - BCW_INV_BITS - s1;    // extra scaling applied to the sample difference
+  // s1 + s2 + BCW_INV_BITS == 32: the rounded result is the upper 32 bits of each 64-bit product
+  const int32_t w =   // scaled by multiplication: left-shifting a negative weight is undefined before C++20
+    ((BCW_WEIGHT_BASE << BCW_INV_BITS) + (bcwWeight > 0 ? (bcwWeight >> 1) : -(bcwWeight >> 1))) / bcwWeight * (1 << s1);
 
 #ifdef USE_AVX2
   if (vext >= AVX2)
   {
-    __m256i voffset = _mm256_set1_epi64x(offset);
-    __m256i vw0 = _mm256_set1_epi32(weight0);
-    __m256i vw1 = _mm256_set1_epi32(weight1);
-
-    __m256i vdst, vsrc;
     for (int row = 0; row < height; row++)
     {
       for (int col = 0; col < width; col += 4)
       {
-        __m256i vsrc0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_lddqu_si128((__m128i *)&src0[col])), _mm_lddqu_si128((__m128i *)&src0[col + 2]), 1);
-        __m256i vsrc1 = _mm256_inserti128_si256(_mm256_castsi128_si256(_mm_lddqu_si128((__m128i *)&src1[col])), _mm_lddqu_si128((__m128i *)&src1[col + 2]), 1);
-        vsrc0 = _mm256_shuffle_epi32(vsrc0, 0x50);
-        vsrc1 = _mm256_shuffle_epi32(vsrc1, 0x50);
-
-        vdst = _mm256_mul_epi32(vsrc0, vw0);
-        vsrc = _mm256_mul_epi32(vsrc1, vw1);
-        vdst = _mm256_add_epi64(_mm256_sub_epi64(vdst, vsrc), voffset);
-
-        *(src0 + col) = (Pel)(_mm256_extract_epi64(vdst, 0) >> shift);
-        *(src0 + col + 1) = (Pel)(_mm256_extract_epi64(vdst, 1) >> shift);
-        *(src0 + col + 2) = (Pel)(_mm256_extract_epi64(vdst, 2) >> shift);
-        *(src0 + col + 3) = (Pel)(_mm256_extract_epi64(vdst, 3) >> shift);
+        const __m128i vsrc0 = _mm_loadu_si128((const __m128i *) &src0[col]);
+        const __m128i vsrc1 = _mm_loadu_si128((const __m128i *) &src1[col]);
+
+        const __m128i diff = _mm_slli_epi32(_mm_sub_epi32(vsrc0, vsrc1), s2);
+
+        __m256i tmp = _mm256_cvtepi32_epi64(diff);
+        tmp         = _mm256_mul_epi32(tmp, _mm256_set1_epi32(w));
+        tmp         = _mm256_add_epi64(tmp, _mm256_set1_epi64x(1u << 31));
+        tmp         = _mm256_permutevar8x32_epi32(tmp, _mm256_setr_epi32(1, 3, 5, 7, 0, 2, 4, 6));
+
+        __m128i res = _mm256_castsi256_si128(tmp);
+        res         = _mm_add_epi32(res, vsrc1);
+        res         = _mm_min_epi32(res, _mm_set1_epi32(maxVal));
+        res         = _mm_max_epi32(res, _mm_set1_epi32(minVal));
+
+        _mm_storeu_si128((__m128i *) &src0[col], res);
       }
       src0 += src0Stride;
       src1 += src1Stride;
@@ -1622,34 +1616,28 @@ void removeWeightHighFreq_HBD_SIMD(Pel *src0, ptrdiff_t src0Stride, const Pel *s
   else
 #endif
   {
-    __m128i voffset = _mm_set_epi64x(offset, offset);
-    __m128i vw0 = _mm_set1_epi32(weight0);
-    __m128i vw1 = _mm_set1_epi32(weight1);
-
-    __m128i vdst, vsrc;
     for (int row = 0; row < height; row++)
     {
       for (int col = 0; col < width; col += 4)
       {
-        __m128i vsrc0 = _mm_lddqu_si128((__m128i *)&src0[col]);
-        __m128i vsrc1 = _mm_lddqu_si128((__m128i *)&src1[col]);
+        const __m128i vsrc0 = _mm_loadu_si128((const __m128i *) &src0[col]);
+        const __m128i vsrc1 = _mm_loadu_si128((const __m128i *) &src1[col]);
 
-        vdst = _mm_mul_epi32(vsrc0, vw0);
-        vsrc = _mm_mul_epi32(vsrc1, vw1);
-        vdst = _mm_add_epi64(_mm_sub_epi64(vdst, vsrc), voffset);
+        const __m128i diff = _mm_slli_epi32(_mm_sub_epi32(vsrc0, vsrc1), s2);
 
-        *(src0 + col) = (Pel)(_mm_extract_epi64(vdst, 0) >> shift);
-        *(src0 + col + 2) = (Pel)(_mm_extract_epi64(vdst, 1) >> shift);
+        __m128i lo = _mm_mul_epi32(diff, _mm_set1_epi32(w));
+        lo         = _mm_add_epi64(lo, _mm_set1_epi64x(1u << 31));
+        lo         = _mm_srli_si128(lo, 4);
 
-        vsrc0 = _mm_srli_si128(vsrc0, 4);
-        vsrc1 = _mm_srli_si128(vsrc1, 4);
+        __m128i hi = _mm_mul_epi32(_mm_srli_si128(diff, 4), _mm_set1_epi32(w));
+        hi         = _mm_add_epi64(hi, _mm_set1_epi64x(1u << 31));
 
-        vdst = _mm_mul_epi32(vsrc0, vw0);
-        vsrc = _mm_mul_epi32(vsrc1, vw1);
-        vdst = _mm_add_epi64(_mm_sub_epi64(vdst, vsrc), voffset);
+        __m128i res = _mm_blend_epi16(lo, hi, 0xcc);
+        res         = _mm_add_epi32(res, vsrc1);
+        res         = _mm_min_epi32(res, _mm_set1_epi32(maxVal));
+        res         = _mm_max_epi32(res, _mm_set1_epi32(minVal));
 
-        *(src0 + col + 1) = (Pel)(_mm_extract_epi64(vdst, 0) >> shift);
-        *(src0 + col + 3) = (Pel)(_mm_extract_epi64(vdst, 1) >> shift);
+        _mm_storeu_si128((__m128i *) &src0[col], res);
       }
       src0 += src0Stride;
       src1 += src1Stride;
diff --git a/source/Lib/EncoderLib/InterSearch.cpp b/source/Lib/EncoderLib/InterSearch.cpp
index 22ed2e20c..266287068 100644
--- a/source/Lib/EncoderLib/InterSearch.cpp
+++ b/source/Lib/EncoderLib/InterSearch.cpp
@@ -11046,7 +11046,7 @@ double InterSearch::xGetMEDistortionWeight(uint8_t bcwIdx, RefPicList eRefPicLis
 {
   if( bcwIdx != BCW_DEFAULT )
   {
-    return fabs((double) getBcwWeight(bcwIdx, eRefPicList) / (double) g_bcwWeightBase);
+    return (double) abs(getBcwWeight(bcwIdx, eRefPicList)) / BCW_WEIGHT_BASE;
   }
   else
   {
-- 
GitLab