From 42826ff87e2507d9483fb2a647a2e06b1de894e4 Mon Sep 17 00:00:00 2001
From: Franck Galpin <franck.galpin@interdigital.com>
Date: Fri, 24 May 2024 07:27:17 +0000
Subject: [PATCH] HOP4 training support

---
 sadl                                          |  2 +-
 .../NN_Filtering/HOP4/model/model.py          | 90 +++++++++++++++----
 .../NN_Filtering/HOP4/models/HOP_float.I.sadl |  3 -
 .../HOP4/models/HOP_float.II.sadl             |  3 -
 .../HOP4/models/HOP_int16.II.sadl             |  3 -
 5 files changed, 74 insertions(+), 27 deletions(-)
 delete mode 100644 training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl
 delete mode 100644 training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl
 delete mode 100644 training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl

diff --git a/sadl b/sadl
index ed737f9b94..e068b52b28 160000
--- a/sadl
+++ b/sadl
@@ -1 +1 @@
-Subproject commit ed737f9b945954c806be8077e43441d7a4a056ee
+Subproject commit e068b52b28d73e70010af57a8df54d8e17f143eb
diff --git a/training/training_scripts/NN_Filtering/HOP4/model/model.py b/training/training_scripts/NN_Filtering/HOP4/model/model.py
index 70c41c4591..4e484407e2 100644
--- a/training/training_scripts/NN_Filtering/HOP4/model/model.py
+++ b/training/training_scripts/NN_Filtering/HOP4/model/model.py
@@ -37,6 +37,9 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+# Global holding the export-only SADLNet copy (set in SADLNet.__init__ when to_train is true), kept outside the model itself.
+model_for_export = None
+
 
 class Conv(nn.Sequential):
     def __init__(
@@ -159,17 +162,24 @@ class MultiBranchModule(nn.Module):
 
 
 class Attn(nn.Module):
-    def __init__(self, c):
+    def __init__(self, to_train, c, n_h):
         super(Attn, self).__init__()
-        self.n_h = 2
-
+        self.n_h = n_h
+        self.c = c
+        self.to_train = to_train
         self.conv1_1 = nn.Conv2d(c, c, kernel_size=1, bias=False)
         self.conv1_2 = nn.Conv2d(c, c, kernel_size=1, bias=False)
         self.conv1_3 = nn.Conv2d(c, c, kernel_size=1, bias=False)
 
-        self.conv3_1 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
-        self.conv3_2 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
-        self.conv3_3 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
+        self.conv3_1 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
+        self.conv3_2 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
+        self.conv3_3 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
         self.conv2 = nn.Conv2d(c, c, kernel_size=1, bias=False)
 
     def forward(self, x):
@@ -177,9 +187,14 @@ class Attn(nn.Module):
         k = self.conv3_2(self.conv1_2(x))
         v = self.conv3_3(self.conv1_3(x))
         s = q.shape
-        q = q.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
-        k = k.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
-        v = v.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
+        if self.to_train:
+            q = q.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+            k = k.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+            v = v.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+        else:
+            q = q.reshape(1, self.n_h, self.c // self.n_h, -1)
+            k = k.reshape(1, self.n_h, self.c // self.n_h, -1)
+            v = v.reshape(1, self.n_h, self.c // self.n_h, -1)
         map = torch.matmul(q, k.transpose(-2, -1))
         p = torch.matmul(map, v)
         p = p.reshape(s)
@@ -188,10 +203,10 @@ class Attn(nn.Module):
 
 
 class TFBlock(nn.Module):
-    def __init__(self, c=64, n_h=2, ext=2.66):
+    def __init__(self, to_train, c=64, n_h=2, ext=2.66):
         super(TFBlock, self).__init__()
 
-        self.attention = Attn(c)
+        self.attention = Attn(to_train, c, n_h)
 
     def forward(self, x):
         x = x + self.attention(x)
@@ -241,6 +256,7 @@ class ResidualBlock(nn.Sequential):
 class ResidualBlock_TF(nn.Sequential):
     def __init__(
         self,
+        to_train: bool,
         C: int = 64,
         C1: int = 160,
         C21: int = 32,
@@ -272,7 +288,7 @@ class ResidualBlock_TF(nn.Sequential):
                 index=index,
                 groups=2,
             ),
-            TFBlock(C, 2, 2.66),
+            TFBlock(to_train, C, 2, 2.66),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -284,6 +300,7 @@ class SADLNet(nn.Sequential):
 
     def __init__(
         self,
+        to_train: bool = True,
         input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1],
         input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
         D1: int = 192,
@@ -321,13 +338,34 @@ class SADLNet(nn.Sequential):
             Conv(sum(self.input_features), D6, kernel_size=1),
             Conv(D6, C, kernel_size=3, stride=2),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(7)],
-            ResidualBlock_TF(C, C1, C21, C22, C31, index=7),
+            ResidualBlock_TF(to_train, C, C1, C21, C22, C31, index=7),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(8, 14)],
-            ResidualBlock_TF(C, C1, C21, C22, C31, index=14),
+            ResidualBlock_TF(to_train, C, C1, C21, C22, C31, index=14),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(15, N)],
             Conv(C, C, kernel_size=3),
             Conv(C, output_channels, kernel_size=3, post_activation=None),
         )
+        # Model for export (to_train=False): Attn reshapes use a fixed batch size of 1, avoiding a dynamic axis on the reshape.
+        if to_train:
+            global model_for_export
+            model_for_export = SADLNet(
+                False,
+                input_channels,
+                input_kernels,
+                D1,
+                D2,
+                D3,
+                D4,
+                D5,
+                D6,
+                N,
+                C,
+                C1,
+                C21,
+                C22,
+                C31,
+                output_channels,
+            )
 
     def get_example_inputs(
         self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1
@@ -352,12 +390,15 @@ class SADLNet(nn.Sequential):
     ) -> None:
         mode = self.training
         self.eval()
+        global model_for_export
+        model_for_export.load_state_dict(self.state_dict())
+        model_for_export.eval()
         torch.onnx.export(
-            self,
+            model_for_export,
             self.get_example_inputs(patch_size, batch_size),
             filename,
             input_names=["in"],
-            dynamic_axes={'in': {2: 'h', 3: 'w'}},
+            dynamic_axes={"in": {2: "h", 3: "w"}},
             opset_version=opset,
             **kwargs,
         )
@@ -399,7 +440,22 @@ class Net(nn.Module):
         self.input_channels = input_channels
         sizes = [len(a) for a in input_channels]
         self.SADL_model = SADLNet(
-            sizes, input_kernels, D1, D2, D3, D4, D5, D6, N, C, C1, C21, C22, C31, 6
+            True,
+            sizes,
+            input_kernels,
+            D1,
+            D2,
+            D3,
+            D4,
+            D5,
+            D6,
+            N,
+            C,
+            C1,
+            C21,
+            C22,
+            C31,
+            6,
         )
         self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest")
 
diff --git a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl b/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl
deleted file mode 100644
index 603e1a23c2..0000000000
--- a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9145ff03b5252071f06584a4107216e09787fdfe70dfb8753231c64552ad7906
-size 5846439
diff --git a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl b/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl
deleted file mode 100644
index a83fcd7e92..0000000000
--- a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6d5f3ed52b88db4d1cdf03433ac3a7b3b275213a340f19d09a140d7cd6239251
-size 5858226
diff --git a/training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl b/training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl
deleted file mode 100644
index d4edd8f6cf..0000000000
--- a/training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:182fe25e010a7c1767859c76b9b13586e043f2d1b831774358a927ba09e406bb
-size 2959590
-- 
GitLab