From e04ad0cf27a29bb2f84c8324e38f74d261ef6c80 Mon Sep 17 00:00:00 2001
From: jamesxxiu <xiaoyuxiu@kwai.com>
Date: Tue, 2 Apr 2019 02:08:48 -0700
Subject: [PATCH] JVET-N0335: MV Rounding Unification

---
 source/Lib/CommonLib/Mv.cpp           |  5 +++++
 source/Lib/CommonLib/Mv.h             | 28 +++++++++++++++++++++++++++
 source/Lib/CommonLib/TypeDef.h        |  2 ++
 source/Lib/EncoderLib/InterSearch.cpp | 12 ++++++++++++
 4 files changed, 47 insertions(+)

diff --git a/source/Lib/CommonLib/Mv.cpp b/source/Lib/CommonLib/Mv.cpp
index 732e756b5..73956d8f3 100644
--- a/source/Lib/CommonLib/Mv.cpp
+++ b/source/Lib/CommonLib/Mv.cpp
@@ -45,8 +45,13 @@ const MvPrecision Mv::m_amvrPrecision[3] = { MV_PRECISION_QUARTER, MV_PRECISION_
 void roundAffineMv( int& mvx, int& mvy, int nShift )
 {
   const int nOffset = 1 << (nShift - 1);
+#if JVET_N0335_N0085_MV_ROUNDING
+  mvx = (mvx + nOffset - (mvx >= 0)) >> nShift;
+  mvy = (mvy + nOffset - (mvy >= 0)) >> nShift;
+#else
   mvx = mvx >= 0 ? (mvx + nOffset) >> nShift : -((-mvx + nOffset) >> nShift);
   mvy = mvy >= 0 ? (mvy + nOffset) >> nShift : -((-mvy + nOffset) >> nShift);
+#endif
 }
 
 void clipMv( Mv& rcMv, const Position& pos,
diff --git a/source/Lib/CommonLib/Mv.h b/source/Lib/CommonLib/Mv.h
index 51d08d682..440f9d77a 100644
--- a/source/Lib/CommonLib/Mv.h
+++ b/source/Lib/CommonLib/Mv.h
@@ -121,6 +121,14 @@ public:
   //! shift right with rounding
   void divideByPowerOf2 (const int i)
   {
+#if JVET_N0335_N0085_MV_ROUNDING
+    if (i != 0)
+    {
+      const int offset = (1 << (i - 1));
+      hor = (hor + offset - (hor >= 0)) >> i;
+      ver = (ver + offset - (ver >= 0)) >> i;
+    }
+#else
 #if ME_ENABLE_ROUNDING_OF_MVS
     const int offset = (i == 0) ? 0 : 1 << (i - 1);
     hor += offset;
@@ -128,6 +136,7 @@ public:
 #endif
     hor >>= i;
     ver >>= i;
+#endif
   }
 
   const Mv& operator<<= (const int i)
@@ -139,8 +148,17 @@ public:
 
   const Mv& operator>>= ( const int i )
   {
+#if JVET_N0335_N0085_MV_ROUNDING
+    if (i != 0)
+    {
+      const int offset = (1 << (i - 1));
+      hor = (hor + offset - (hor >= 0)) >> i;
+      ver = (ver + offset - (ver >= 0)) >> i;
+  }
+#else
     hor >>= i;
     ver >>= i;
+#endif
     return  *this;
   }
 
@@ -166,8 +184,13 @@ public:
 
   const Mv scaleMv( int iScale ) const
   {
+#if JVET_N0335_N0085_MV_ROUNDING
+    const int mvx = Clip3(-131072, 131071, (iScale * getHor() + 128 - (iScale * getHor() >= 0)) >> 8);
+    const int mvy = Clip3(-131072, 131071, (iScale * getVer() + 128 - (iScale * getVer() >= 0)) >> 8);
+#else
     const int mvx = Clip3( -131072, 131071, (iScale * getHor() + 127 + (iScale * getHor() < 0)) >> 8 );
     const int mvy = Clip3( -131072, 131071, (iScale * getVer() + 127 + (iScale * getVer() < 0)) >> 8 );
+#endif
     return Mv( mvx, mvy );
   }
 
@@ -182,8 +205,13 @@ public:
     {
       const int rightShift = -shift;
       const int nOffset = 1 << (rightShift - 1);
+#if JVET_N0335_N0085_MV_ROUNDING
+      hor = hor >= 0 ? (hor + nOffset - 1) >> rightShift : (hor + nOffset) >> rightShift;
+      ver = ver >= 0 ? (ver + nOffset - 1) >> rightShift : (ver + nOffset) >> rightShift;
+#else
       hor = hor >= 0 ? (hor + nOffset) >> rightShift : -((-hor + nOffset) >> rightShift);
       ver = ver >= 0 ? (ver + nOffset) >> rightShift : -((-ver + nOffset) >> rightShift);
+#endif
     }
   }
 
diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 80270fd95..778d5c32c 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -50,6 +50,8 @@
 #include <assert.h>
 #include <cassert>
 
+#define JVET_N0335_N0085_MV_ROUNDING                      1  // MV rounding unification
+
 #define JVET_N0477_LMCS_CLEANUP                           1
 #define JVET_N0220_LMCS_SIMPLIFICATION                    1
 
diff --git a/source/Lib/EncoderLib/InterSearch.cpp b/source/Lib/EncoderLib/InterSearch.cpp
index dd85f3c41..182e93a7a 100644
--- a/source/Lib/EncoderLib/InterSearch.cpp
+++ b/source/Lib/EncoderLib/InterSearch.cpp
@@ -1527,7 +1527,13 @@ void InterSearch::xxIBCHashSearch(PredictionUnit& pu, Mv* mvPred, int numMvPred,
             int imvShift = 2;
             int offset = 1 << (imvShift - 1);
 
+#if JVET_N0335_N0085_MV_ROUNDING
+            int x = (mvPred[n].hor + offset - (mvPred[n].hor >= 0)) >> 2;
+            int y = (mvPred[n].ver + offset - (mvPred[n].ver >= 0)) >> 2;
+            mvPredQuadPel.set(x, y);
+#else            
             mvPredQuadPel.set(((mvPred[n].hor + offset) >> 2), ((mvPred[n].ver + offset) >> 2));
+#endif
 
             m_pcRdCost->setPredictor(mvPredQuadPel);
 
@@ -4264,8 +4270,14 @@ void InterSearch::xPredAffineInterSearch( PredictionUnit&       pu,
         int shift = MAX_CU_DEPTH;
         int vx2 = (mvFour[0].getHor() << shift) - ((mvFour[1].getVer() - mvFour[0].getVer()) << (shift + g_aucLog2[pu.lheight()] - g_aucLog2[pu.lwidth()]));
         int vy2 = (mvFour[0].getVer() << shift) + ((mvFour[1].getHor() - mvFour[0].getHor()) << (shift + g_aucLog2[pu.lheight()] - g_aucLog2[pu.lwidth()]));
+#if JVET_N0335_N0085_MV_ROUNDING
+        int offset = (1 << (shift - 1));
+        vx2 = (vx2 + offset - (vx2 >= 0)) >> shift;
+        vy2 = (vy2 + offset - (vy2 >= 0)) >> shift;
+#else
         vx2 >>= shift;
         vy2 >>= shift;
+#endif
         mvFour[2].hor = vx2;
         mvFour[2].ver = vy2;
         mvFour[2].clipToStorageBitDepth();
-- 
GitLab