From 35500e8e8589037a51313d4d958215dd1ab3e3d0 Mon Sep 17 00:00:00 2001 From: Vadim Seregin <vseregin@qti.qualcomm.com> Date: Wed, 22 Jun 2022 00:18:17 +0000 Subject: [PATCH] Fix: Add storing and loading CABAC windows for temporal CABAC --- source/Lib/CommonLib/Contexts.cpp | 21 ++++++++++++++++----- source/Lib/CommonLib/Contexts.h | 26 +++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/source/Lib/CommonLib/Contexts.cpp b/source/Lib/CommonLib/Contexts.cpp index 6753b760d..d571266a5 100644 --- a/source/Lib/CommonLib/Contexts.cpp +++ b/source/Lib/CommonLib/Contexts.cpp @@ -5452,18 +5452,29 @@ void CtxStore<BinProbModel>::init( int qp, int initId ) } } +#if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT template <class BinProbModel> -void CtxStore<BinProbModel>::setWinSizes( const std::vector<uint8_t>& log2WindowSizes ) +void CtxStore<BinProbModel>::saveWinSizes( std::vector<uint8_t>& windows ) const { - CHECK( m_CtxBuffer.size() != log2WindowSizes.size(), - "Size of window size table (" << log2WindowSizes.size() << ") does not match size of context buffer (" << m_CtxBuffer.size() << ")." ); + windows.resize( m_CtxBuffer.size(), uint8_t( 0 ) ); + for( std::size_t k = 0; k < m_CtxBuffer.size(); k++ ) { - m_CtxBuffer[k].setLog2WindowSize( log2WindowSizes[k] ); + windows[k] = m_CtxBuffer[k].getWinSizes(); + } +} + +template <class BinProbModel> +void CtxStore<BinProbModel>::loadWinSizes( const std::vector<uint8_t>& windows ) +{ + CHECK( m_CtxBuffer.size() != windows.size(), + "Size of prob states table (" << windows.size() << ") does not match size of context buffer (" << m_CtxBuffer.size() << ")." ); + for( std::size_t k = 0; k < m_CtxBuffer.size(); k++ ) + { + m_CtxBuffer[k].setWinSizes( windows[k] ); } } -#if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT template <class BinProbModel> void CtxStore<BinProbModel>::loadWeights( const std::vector<uint8_t>& weights ) { diff --git a/source/Lib/CommonLib/Contexts.h b/source/Lib/CommonLib/Contexts.h index 51e70a89e..76157b737 100644 --- a/source/Lib/CommonLib/Contexts.h +++ b/source/Lib/CommonLib/Contexts.h @@ -167,6 +167,8 @@ public: #if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT void setAdaptRateWeight( uint8_t weight ) { m_weight = weight; } uint8_t getAdaptRateWeight() const { return m_weight; } + void setWinSizes( uint8_t rate ) { m_rate = rate; } + uint8_t getWinSizes() const { return m_rate; } void setAdaptRateOffset(uint8_t rateOffset, bool bin ) { m_rateOffset[bin] = rateOffset;} #endif @@ -524,8 +526,9 @@ public: void copyFrom ( const CtxStore<BinProbModel>& src ) { checkInit(); ::memcpy( m_Ctx, src.m_Ctx, sizeof( BinProbModel ) * ContextSetCfg::NumberOfContexts ); } void copyFrom ( const CtxStore<BinProbModel>& src, const CtxSet& ctxSet ) { checkInit(); ::memcpy( m_Ctx+ctxSet.Offset, src.m_Ctx+ctxSet.Offset, sizeof( BinProbModel ) * ctxSet.Size ); } void init ( int qp, int initId ); - void setWinSizes( const std::vector<uint8_t>& log2WindowSizes ); #if JVET_Z0135_TEMP_CABAC_WIN_WEIGHT + void loadWinSizes( const std::vector<uint8_t>& windows ); + void saveWinSizes( std::vector<uint8_t>& windows ) const; void loadWeights( const std::vector <uint8_t>& weights ); void saveWeights( std::vector<uint8_t>& weights ) const; void loadPStates( const std::vector<std::pair<uint16_t, uint16_t>>& probStates ); @@ -629,6 +632,24 @@ public: } } + void loadWinSizes( const std::vector<uint8_t>& windows ) + { + switch( m_BPMType ) + { + case BPM_Std: m_CtxStore_Std.loadWinSizes( windows ); break; + default: break; + } + } + + void saveWinSizes( std::vector<uint8_t>& windows ) const + { + switch( m_BPMType ) + { + case BPM_Std: m_CtxStore_Std.saveWinSizes( windows ); break; + default: break; + } + } + void loadPStates( const std::vector<std::pair<uint16_t, uint16_t>>& probStates ) { switch( m_BPMType ) @@ -747,6 +768,7 @@ public: { ctx.loadPStates( m_states ); ctx.loadWeights( m_weights ); + ctx.loadWinSizes( m_rate ); return true; } return false; @@ -755,6 +777,7 @@ public: { ctx.savePStates( m_states ); ctx.saveWeights( m_weights ); + ctx.saveWinSizes( m_rate ); m_valid = true; } @@ -762,6 +785,7 @@ private: std::vector<std::pair<uint16_t, uint16_t>> m_states; bool m_valid; std::vector<uint8_t> m_weights; + std::vector<uint8_t> m_rate; }; class CtxStateArray -- GitLab