From 35500e8e8589037a51313d4d958215dd1ab3e3d0 Mon Sep 17 00:00:00 2001
From: Vadim Seregin <vseregin@qti.qualcomm.com>
Date: Wed, 22 Jun 2022 00:18:17 +0000
Subject: [PATCH] Fix: Add storing and loading CABAC windows for temporal CABAC

---
 source/Lib/CommonLib/Contexts.cpp | 21 ++++++++++++++++-----
 source/Lib/CommonLib/Contexts.h   | 26 +++++++++++++++++++++++++-
 2 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/source/Lib/CommonLib/Contexts.cpp b/source/Lib/CommonLib/Contexts.cpp
index 6753b760d..d571266a5 100644
--- a/source/Lib/CommonLib/Contexts.cpp
+++ b/source/Lib/CommonLib/Contexts.cpp
@@ -5452,18 +5452,29 @@ void CtxStore<BinProbModel>::init( int qp, int initId )
   }
 }
 
+#if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT
 template <class BinProbModel>
-void CtxStore<BinProbModel>::setWinSizes( const std::vector<uint8_t>& log2WindowSizes )
+void CtxStore<BinProbModel>::saveWinSizes( std::vector<uint8_t>& windows ) const
 {
-  CHECK( m_CtxBuffer.size() != log2WindowSizes.size(),
-        "Size of window size table (" << log2WindowSizes.size() << ") does not match size of context buffer (" << m_CtxBuffer.size() << ")." );
+  windows.resize( m_CtxBuffer.size(), uint8_t( 0 ) );
+
   for( std::size_t k = 0; k < m_CtxBuffer.size(); k++ )
   {
-    m_CtxBuffer[k].setLog2WindowSize( log2WindowSizes[k] );
+    windows[k] = m_CtxBuffer[k].getWinSizes();
+  }
+}
+
+template <class BinProbModel>
+void CtxStore<BinProbModel>::loadWinSizes( const std::vector<uint8_t>& windows )
+{
+  CHECK( m_CtxBuffer.size() != windows.size(),
+         "Size of prob states table (" << windows.size() << ") does not match size of context buffer (" << m_CtxBuffer.size() << ")." );
+  for( std::size_t k = 0; k < m_CtxBuffer.size(); k++ )
+  {
+    m_CtxBuffer[k].setWinSizes( windows[k] );
   }
 }
 
-#if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT
 template <class BinProbModel>
 void CtxStore<BinProbModel>::loadWeights( const std::vector<uint8_t>& weights )
 {
diff --git a/source/Lib/CommonLib/Contexts.h b/source/Lib/CommonLib/Contexts.h
index 51e70a89e..76157b737 100644
--- a/source/Lib/CommonLib/Contexts.h
+++ b/source/Lib/CommonLib/Contexts.h
@@ -167,6 +167,8 @@ public:
 #if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT
   void     setAdaptRateWeight( uint8_t weight ) { m_weight = weight; }
   uint8_t  getAdaptRateWeight() const           { return m_weight; }
+  void     setWinSizes( uint8_t rate )          { m_rate = rate; }
+  uint8_t  getWinSizes() const                  { return m_rate; }
   void     setAdaptRateOffset(uint8_t rateOffset, bool bin ) { m_rateOffset[bin] = rateOffset;}
 #endif
 
@@ -524,8 +526,9 @@ public:
   void copyFrom   ( const CtxStore<BinProbModel>& src )                        { checkInit(); ::memcpy( m_Ctx,               src.m_Ctx,               sizeof( BinProbModel ) * ContextSetCfg::NumberOfContexts ); }
   void copyFrom   ( const CtxStore<BinProbModel>& src, const CtxSet& ctxSet )  { checkInit(); ::memcpy( m_Ctx+ctxSet.Offset, src.m_Ctx+ctxSet.Offset, sizeof( BinProbModel ) * ctxSet.Size ); }
   void init       ( int qp, int initId );
-  void setWinSizes( const std::vector<uint8_t>&   log2WindowSizes );
 #if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT
+  void loadWinSizes( const std::vector<uint8_t>&   windows );
+  void saveWinSizes( std::vector<uint8_t>&         windows ) const;
   void loadWeights( const std::vector <uint8_t>&  weights );
   void saveWeights( std::vector<uint8_t>&         weights )  const;
   void loadPStates( const std::vector<std::pair<uint16_t, uint16_t>>&  probStates );
@@ -629,6 +632,24 @@ public:
     }
   }
 
+  void  loadWinSizes( const std::vector<uint8_t>& windows )
+  {
+    switch( m_BPMType )
+    {
+    case BPM_Std:   m_CtxStore_Std.loadWinSizes( windows );  break;
+    default:        break;
+    }
+  }
+
+  void  saveWinSizes( std::vector<uint8_t>& windows ) const
+  {
+    switch( m_BPMType )
+    {
+    case BPM_Std:   m_CtxStore_Std.saveWinSizes( windows );  break;
+    default:        break;
+    }
+  }
+
   void  loadPStates( const std::vector<std::pair<uint16_t, uint16_t>>& probStates )
   {
     switch( m_BPMType )
@@ -747,6 +768,7 @@ public:
     {
       ctx.loadPStates( m_states );
       ctx.loadWeights( m_weights );
+      ctx.loadWinSizes( m_rate );
       return true;
     }
     return false;
@@ -755,6 +777,7 @@ public:
   {
     ctx.savePStates( m_states );
     ctx.saveWeights( m_weights );
+    ctx.saveWinSizes( m_rate );
     m_valid = true;
   }
 
@@ -762,6 +785,7 @@ private:
   std::vector<std::pair<uint16_t, uint16_t>> m_states;
   bool                 m_valid;
   std::vector<uint8_t> m_weights;
+  std::vector<uint8_t> m_rate;
 };
 
 class CtxStateArray
-- 
GitLab