From 6d86c1d4c8dcb992d1db9d8c42a5324ccc971d9f Mon Sep 17 00:00:00 2001
From: Franck Galpin <franck.galpin@interdigital.com>
Date: Thu, 2 Feb 2023 12:13:30 +0000
Subject: [PATCH] isolate sadl code in cpp unit and use aggressive float
 optimization only on these units

---
 CMakeLists.txt                          |  2 +-
 Makefile                                |  2 +-
 sadl                                    |  2 +-
 source/Lib/CommonLib/CMakeLists.txt     | 10 ++--
 source/Lib/CommonLib/NNFilterSet0.cpp   | 26 ++++++----
 source/Lib/CommonLib/NNFilterSet0.h     | 18 ++++---
 source/Lib/CommonLib/NNInference.cpp    | 66 +++++++++++++++++++++++++
 source/Lib/CommonLib/NNInference.h      | 27 +++++-----
 source/Lib/EncoderLib/CMakeLists.txt    |  7 ++-
 source/Lib/EncoderLib/EncNNFilterSet1.h |  6 ++-
 10 files changed, 127 insertions(+), 39 deletions(-)
 create mode 100644 source/Lib/CommonLib/NNInference.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4f2b6f9d92..c8e195113e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -103,7 +103,7 @@ bb_enable_warnings( msvc warnings-as-errors "/wd4996" )
 # enable sse4.1 build for all source files for gcc and clang
 if( UNIX OR MINGW )
   add_compile_options( "-msse4.1" )
-  set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffast-math -fstrict-aliasing" )
+  set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}  -fstrict-aliasing" )
 endif()
 
 # enable parallel build for Visual Studio
diff --git a/Makefile b/Makefile
index ce1b38b89a..7252ce6ae2 100644
--- a/Makefile
+++ b/Makefile
@@ -18,7 +18,7 @@ BUILD_SCRIPT := $(CURDIR)/cmake/CMakeBuild/bin/cmake.py
 
 TARGETS := CommonLib DecoderAnalyserApp DecoderAnalyserLib DecoderApp DecoderLib 
 TARGETS += EncoderApp EncoderLib Utilities SEIRemovalApp StreamMergeApp
-SADL_HASH := "dedd0095bd3f8ca0a760eb0649b028a87b84dcfd" # hard coded because of windows $(shell git submodule status | grep sadl | cut -d' ' -f2)
+SADL_HASH := "v4rc2" # hard coded because of windows $(shell git submodule status | grep sadl | cut -d' ' -f2)
 
 ifeq ($(OS),Windows_NT)
   ifneq ($(MSYSTEM),)
diff --git a/sadl b/sadl
index e562317e36..309d584ec9 160000
--- a/sadl
+++ b/sadl
@@ -1 +1 @@
-Subproject commit e562317e36575995c29ad2ed0a0d315ea17d2165
+Subproject commit 309d584ec9ca874092a515580b6f4ec23a399c79
diff --git a/source/Lib/CommonLib/CMakeLists.txt b/source/Lib/CommonLib/CMakeLists.txt
index 8894f6dd8e..12e468c597 100644
--- a/source/Lib/CommonLib/CMakeLists.txt
+++ b/source/Lib/CommonLib/CMakeLists.txt
@@ -110,11 +110,13 @@ if( MSVC )
   set_property( SOURCE NNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ")
 elseif( UNIX OR MINGW )
   if( NNLF_BUILD_WITH_AVX512 STREQUAL "1" )
-    set_property( SOURCE NNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw")
-    set_property( SOURCE NNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw")
+    set_property( SOURCE NNInference.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw -ffast-math")
+    set_property( SOURCE NNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw -ffast-math")
+    set_property( SOURCE NNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw -ffast-math")
   else()
-    set_property( SOURCE NNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2")
-    set_property( SOURCE NNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2")
+    set_property( SOURCE NNInference.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2 -ffast-math")
+    set_property( SOURCE NNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2 -ffast-math")
+    set_property( SOURCE NNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2 -ffast-math")
   endif()
 endif()
 
diff --git a/source/Lib/CommonLib/NNFilterSet0.cpp b/source/Lib/CommonLib/NNFilterSet0.cpp
index 1807fa37cd..287519fe45 100644
--- a/source/Lib/CommonLib/NNFilterSet0.cpp
+++ b/source/Lib/CommonLib/NNFilterSet0.cpp
@@ -36,12 +36,17 @@
 */
 
 #include "NNFilterSet0.h"
+#define HAVE_INTTYPES_H 1
+#define __STDC_FORMAT_MACROS
+#include <sadl/model.h>
+#include "NNInference.h"
 
 #if NN_FILTERING_SET_0
 #include "CodingStructure.h"
 #include "Picture.h"
+using namespace std;
 
-NNFilterSet0::NNFilterSet0()
+NNFilterSet0::NNFilterSet0():m_Module(std::make_unique<sadl::Model<TypeSadl>>())
 {
   for( int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++ )
   {
@@ -50,6 +55,9 @@ NNFilterSet0::NNFilterSet0()
   m_initFlag = false;
 }
 
+
+NNFilterSet0::~NNFilterSet0() {}
+
 void NNFilterSet0::PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam)
 {
   initCnnModel();
@@ -363,7 +371,7 @@ void NNFilterSet0::runCNNLF(Picture* pic, PelUnitBuf& cnnUnitBuf, const int base
       const UnitArea inferBlock(cs.area.chromaFormat, Area(st_w, st_h, actualPatchSizeW, actualPatchSizeH));
       initPatch(actualPatchSizeW, actualPatchSizeH);
       NNInference::prepareInputs<TypeSadl>(pic, inferBlock, m_Input, baseQPFinal, sliceQp, slice_type, listInputData);
-      NNInference::infer<TypeSadl>(m_Module, m_Input);
+      NNInference::infer<TypeSadl>(*m_Module, m_Input);
 
       // extract the results
 #if JVET_AB0083_QPADJ
@@ -423,7 +431,7 @@ void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_
         int id_y = pix_y + (yy << 1) + kk / 2;
 
 #if NN_FIXED_POINT_IMPLEMENTATION
-        cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w) >> 1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) << in_left_shift)) << out_left_shift));
+        cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int(m_Module->result(0)(0, (id_y - st_h) >> 1, (id_x - st_w) >> 1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) << in_left_shift)) << out_left_shift));
 #else
-        cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((m_Module.result(0)(0, (id_y - st_h) >> 1, (id_x - st_w)>>1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) / in_maxValue)) * out_maxValue + 0.5)));
+        cnnUnitBuf.get(COMPONENT_Y).at(id_x, id_y) = Pel(Clip3<int>(0, out_maxValue, int((m_Module->result(0)(0, (id_y - st_h) >> 1, (id_x - st_w)>>1, kk) + (recUnitBuf.get(COMPONENT_Y).at(id_x, id_y) / in_maxValue)) * out_maxValue + 0.5)));
 #endif
@@ -455,7 +463,7 @@ void NNFilterSet0::extractOutputs(Picture* pic, int pix_x, int pix_y, int pix_x_
         {
           int pos_x = ((pix_x - st_w) >> 1) + xx;
           int pos_y = ((pix_y - st_h) >> 1) + yy;
-          sample = sample + m_Module.result(0)(0, pos_y, pos_x, compID*4 + kk);
+          sample = sample + m_Module->result(0)(0, pos_y, pos_x, compID*4 + kk);
         }
         sample = sample / 4;
 
@@ -528,7 +536,7 @@ void NNFilterSet0::scaleResidue(CodingStructure& cs, PelUnitBuf recUnitBuf, PelU
 
 void NNFilterSet0::initPatch(const int PatchWidth, const int PatchHeight)
 {
-  m_Input                 = m_Module.getInputsTemplate();
+  m_Input                 = m_Module->getInputsTemplate();
   unsigned int m_Input_id = 0;
   for (auto &t: m_Input)
   {
@@ -542,7 +550,7 @@ void NNFilterSet0::initPatch(const int PatchWidth, const int PatchHeight)
     }
     m_Input_id++;
   }
-  if (!m_Module.init(m_Input))
+  if (!m_Module->init(m_Input))
   {
     cerr << "[ERROR] issue during initialization" << endl;
     exit(-1);
@@ -565,14 +573,14 @@ void NNFilterSet0::initCnnModel()
 #endif
 
   ifstream file(ModelPath, ios::binary);
-  if (!m_Module.load(file)) {
+  if (!m_Module->load(file)) {
     cerr << "[ERROR] Unable to read model " << ModelPath << endl;
     exit(-1);
   }
 
-  m_Input = m_Module.getInputsTemplate();
+  m_Input = m_Module->getInputsTemplate();
 
-  if (!m_Module.init(m_Input)) {
+  if (!m_Module->init(m_Input)) {
     cerr << "[ERROR] issue during initialization" << endl;
     exit(-1);
   }
diff --git a/source/Lib/CommonLib/NNFilterSet0.h b/source/Lib/CommonLib/NNFilterSet0.h
index b8837ce27e..935f294597 100644
--- a/source/Lib/CommonLib/NNFilterSet0.h
+++ b/source/Lib/CommonLib/NNFilterSet0.h
@@ -44,18 +44,20 @@
 
 #include "Unit.h"
 #include "Picture.h"
-#define HAVE_INTTYPES_H 1
-#define __STDC_FORMAT_MACROS
-#include <sadl/model.h>
-#include "NNInference.h"
 #include <fstream>
-using namespace std;
+
+// fwd
+namespace sadl {
+template<typename T> class Model;
+template<typename T> class Tensor;
+}
+
 
 class NNFilterSet0
 {
 public:
   NNFilterSet0();
-  virtual ~NNFilterSet0() {}
+  virtual ~NNFilterSet0();
   void create(const int picWidth, const int picHeight, const ChromaFormat format, const int maxCUWidth, const int maxCUHeight, const int maxCUDepth, const int inputBitDepth[MAX_NUM_CHANNEL_TYPE], std::string path);
   void destroy();
   void PreCNNLFProcess(Picture* pic, CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam);
@@ -88,8 +90,8 @@ protected:
   bool                         m_initFlag;
 
   std::string m_ModelPath;
-  sadl::Model<TypeSadl> m_Module;
-  vector<sadl::Tensor<TypeSadl>> m_Input;
+  std::unique_ptr<sadl::Model<TypeSadl>> m_Module;
+  std::vector<sadl::Tensor<TypeSadl>> m_Input;
 
   void filterPic(CodingStructure& cs, CnnlfSliceParam& cnnlfSliceParam);
   void filterBlk(PelUnitBuf &recDst, const CPelUnitBuf& recSrc, const Area& blk, const ComponentID compI, const ClpRng& clpRng);
diff --git a/source/Lib/CommonLib/NNInference.cpp b/source/Lib/CommonLib/NNInference.cpp
new file mode 100644
index 0000000000..7f1f144534
--- /dev/null
+++ b/source/Lib/CommonLib/NNInference.cpp
@@ -0,0 +1,66 @@
+/* The copyright in this software is being made available under the BSD
+ * License, included below. This software may be subject to other third party
+ * and contributor rights, including patent rights, and no such rights are
+ * granted under this license.
+ *
+ * Copyright (c) 2010-2020, ITU/ISO/IEC
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ *  * Redistributions of source code must retain the above copyright notice,
+ *    this list of conditions and the following disclaimer.
+ *  * Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *  * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+ *    be used to endorse or promote products derived from this software without
+ *    specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+ * THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/** \file     NNInference.cpp
+    \brief    neural network-based inference class (implementation)
+*/
+#include <sadl/model.h>
+#include "NNInference.h"
+//! \ingroup CommonLib
+//! \{
+
+using namespace std;
+
+#if NN_FIXED_POINT_IMPLEMENTATION
+template<>
+void NNInference::infer(sadl::Model<int16_t> &model, std::vector<sadl::Tensor<int16_t>> &inputs)
+{
+    if (!model.apply(inputs))
+    {
+        cerr << "[ERROR] issue during inference" << endl;
+        exit(-1);
+    }
+}
+
+#else
+template<>
+void NNInference::infer(sadl::Model<float> &model, std::vector<sadl::Tensor<float>> &inputs)
+{
+    if (!model.apply(inputs))
+    {
+        cerr << "[ERROR] issue during inference" << endl;
+        exit(-1);
+    }
+}
+#endif
+
diff --git a/source/Lib/CommonLib/NNInference.h b/source/Lib/CommonLib/NNInference.h
index 9c33262314..a40c44fff3 100644
--- a/source/Lib/CommonLib/NNInference.h
+++ b/source/Lib/CommonLib/NNInference.h
@@ -43,8 +43,7 @@
 #include "Unit.h"
 #include "Picture.h"
 #include "Reshape.h"
-#include <sadl/model.h>
-using namespace std;
+#include <sadl/tensor.h>
 //! \ingroup CommonLib
 //! \{
 
@@ -57,10 +56,13 @@ struct InputData {
   bool chroma;
 };
 
+namespace sadl {
+template<typename T> class Model;
+}
+
 class NNInference
 {
 public:
-  NNInference();
   template<typename T>
   static void fillInputFromBuf (Picture* pic, UnitArea inferArea, sadl::Tensor<T> &input, PelUnitBuf buf, bool luma, bool chroma, double scale, int shift)
   {
@@ -166,7 +168,7 @@ public:
 #endif
   }
   template<typename T>
-  static void prepareInputs (Picture* pic, UnitArea inferArea, vector<sadl::Tensor<T>> &inputs, int globalQp, int localQp, int sliceType, const std::vector<InputData> &listInputData)
+  static void prepareInputs (Picture* pic, UnitArea inferArea, std::vector<sadl::Tensor<T>> &inputs, int globalQp, int localQp, int sliceType, const std::vector<InputData> &listInputData)
   {
     for (auto inputData : listInputData)
     {
@@ -200,15 +202,16 @@ public:
     }
   }
   template<typename T>
-  static void infer(sadl::Model<T> &model, vector<sadl::Tensor<T>> &inputs)
-  {
-    if (!model.apply(inputs))
-    {
-      cerr << "[ERROR] issue during inference" << endl;
-      exit(-1);
-    }
-  }
+  static void infer(sadl::Model<T> &model, std::vector<sadl::Tensor<T>> &inputs);
 };
+
+#if NN_FIXED_POINT_IMPLEMENTATION
+template<>
+void NNInference::infer(sadl::Model<int16_t> &model, std::vector<sadl::Tensor<int16_t>> &inputs);
+#else
+template<>
+void NNInference::infer(sadl::Model<float> &model, std::vector<sadl::Tensor<float>> &inputs);
+#endif
 //! \}
 #endif
 #endif
diff --git a/source/Lib/EncoderLib/CMakeLists.txt b/source/Lib/EncoderLib/CMakeLists.txt
index b3d998427e..32013c747a 100644
--- a/source/Lib/EncoderLib/CMakeLists.txt
+++ b/source/Lib/EncoderLib/CMakeLists.txt
@@ -62,12 +62,15 @@ if( CMAKE_COMPILER_IS_GNUCC )
 endif()
 
 if( MSVC )
+  set_property( SOURCE EncNNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ")
   set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "/arch:AVX2 -DNDEBUG=1 ")
 elseif( UNIX OR MINGW )
   if( NNLF_BUILD_WITH_AVX512 STREQUAL "1" )
-    set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw")
+    set_property( SOURCE EncNNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw -ffast-math")
+    set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx512f -mavx512bw -ffast-math")
   else()
-    set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2")
+    set_property( SOURCE EncNNFilterSet0.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2 -ffast-math")
+    set_property( SOURCE EncNNFilterSet1.cpp APPEND PROPERTY COMPILE_FLAGS "-DNDEBUG=1 -mavx2 -ffast-math")
   endif()
 endif()
 
diff --git a/source/Lib/EncoderLib/EncNNFilterSet1.h b/source/Lib/EncoderLib/EncNNFilterSet1.h
index 3a5b4524c2..68c5d56c74 100644
--- a/source/Lib/EncoderLib/EncNNFilterSet1.h
+++ b/source/Lib/EncoderLib/EncNNFilterSet1.h
@@ -44,10 +44,14 @@
 #include "Reshape.h"
 #include "CABACWriter.h"
 #include "CommonLib/NNFilterSet1.h"
-#include <sadl/model.h>
 //! \ingroup CommonLib
 //! \{
 
+// fwd
+namespace sadl {
+template<typename T> class Model;
+template<typename T> class Tensor;
+}
 
 class EncNNFilterSet1 : public NNFilterSet1
 {
-- 
GitLab