From 42826ff87e2507d9483fb2a647a2e06b1de894e4 Mon Sep 17 00:00:00 2001
From: Franck Galpin <franck.galpin@interdigital.com>
Date: Fri, 24 May 2024 07:27:17 +0000
Subject: [PATCH] HOP4 training support

---
 sadl                                          |  2 +-
 .../NN_Filtering/HOP4/model/model.py          | 90 +++++++++++++++----
 .../NN_Filtering/HOP4/models/HOP_float.I.sadl |  3 -
 .../HOP4/models/HOP_float.II.sadl             |  3 -
 .../HOP4/models/HOP_int16.II.sadl             |  3 -
 5 files changed, 74 insertions(+), 27 deletions(-)
 delete mode 100644 training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl
 delete mode 100644 training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl
 delete mode 100644 training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl

diff --git a/sadl b/sadl
index ed737f9b94..e068b52b28 160000
--- a/sadl
+++ b/sadl
@@ -1 +1 @@
-Subproject commit ed737f9b945954c806be8077e43441d7a4a056ee
+Subproject commit e068b52b28d73e70010af57a8df54d8e17f143eb
diff --git a/training/training_scripts/NN_Filtering/HOP4/model/model.py b/training/training_scripts/NN_Filtering/HOP4/model/model.py
index 70c41c4591..4e484407e2 100644
--- a/training/training_scripts/NN_Filtering/HOP4/model/model.py
+++ b/training/training_scripts/NN_Filtering/HOP4/model/model.py
@@ -37,6 +37,9 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+# ugly global to avoid putting the export model inside the model
+model_for_export = None
+
 
 class Conv(nn.Sequential):
     def __init__(
@@ -159,17 +162,24 @@ class MultiBranchModule(nn.Module):
 
 
 class Attn(nn.Module):
-    def __init__(self, c):
+    def __init__(self, to_train, c, n_h):
         super(Attn, self).__init__()
-        self.n_h = 2
-
+        self.n_h = n_h
+        self.c = c
+        self.to_train = to_train
         self.conv1_1 = nn.Conv2d(c, c, kernel_size=1, bias=False)
         self.conv1_2 = nn.Conv2d(c, c, kernel_size=1, bias=False)
         self.conv1_3 = nn.Conv2d(c, c, kernel_size=1, bias=False)
 
-        self.conv3_1 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
-        self.conv3_2 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
-        self.conv3_3 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
+        self.conv3_1 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
+        self.conv3_2 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
+        self.conv3_3 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
         self.conv2 = nn.Conv2d(c, c, kernel_size=1, bias=False)
 
     def forward(self, x):
@@ -177,9 +187,14 @@ class Attn(nn.Module):
         k = self.conv3_2(self.conv1_2(x))
         v = self.conv3_3(self.conv1_3(x))
         s = q.shape
-        q = q.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
-        k = k.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
-        v = v.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
+        if self.to_train:
+            q = q.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+            k = k.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+            v = v.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+        else:
+            q = q.reshape(1, self.n_h, self.c // self.n_h, -1)
+            k = k.reshape(1, self.n_h, self.c // self.n_h, -1)
+            v = v.reshape(1, self.n_h, self.c // self.n_h, -1)
         map = torch.matmul(q, k.transpose(-2, -1))
         p = torch.matmul(map, v)
         p = p.reshape(s)
@@ -188,10 +203,10 @@ class Attn(nn.Module):
 
 
 class TFBlock(nn.Module):
-    def __init__(self, c=64, n_h=2, ext=2.66):
+    def __init__(self, to_train, c=64, n_h=2, ext=2.66):
         super(TFBlock, self).__init__()
 
-        self.attention = Attn(c)
+        self.attention = Attn(to_train, c, n_h)
 
     def forward(self, x):
         x = x + self.attention(x)
@@ -241,6 +256,7 @@ class ResidualBlock(nn.Sequential):
 class ResidualBlock_TF(nn.Sequential):
     def __init__(
         self,
+        to_train: bool,
         C: int = 64,
         C1: int = 160,
         C21: int = 32,
@@ -272,7 +288,7 @@ class ResidualBlock_TF(nn.Sequential):
                 index=index,
                 groups=2,
             ),
-            TFBlock(C, 2, 2.66),
+            TFBlock(to_train, C, 2, 2.66),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -284,6 +300,7 @@ class SADLNet(nn.Sequential):
 
     def __init__(
         self,
+        to_train: bool = True,
         input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1],
         input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
         D1: int = 192,
@@ -321,13 +338,34 @@ class SADLNet(nn.Sequential):
             Conv(sum(self.input_features), D6, kernel_size=1),
             Conv(D6, C, kernel_size=3, stride=2),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(7)],
-            ResidualBlock_TF(C, C1, C21, C22, C31, index=7),
+            ResidualBlock_TF(to_train, C, C1, C21, C22, C31, index=7),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(8, 14)],
-            ResidualBlock_TF(C, C1, C21, C22, C31, index=14),
+            ResidualBlock_TF(to_train, C, C1, C21, C22, C31, index=14),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(15, N)],
             Conv(C, C, kernel_size=3),
             Conv(C, output_channels, kernel_size=3, post_activation=None),
         )
+        # model for export: batch size=1 and no dynamic axis on reshape
+        if to_train:
+            global model_for_export
+            model_for_export = SADLNet(
+                False,
+                input_channels,
+                input_kernels,
+                D1,
+                D2,
+                D3,
+                D4,
+                D5,
+                D6,
+                N,
+                C,
+                C1,
+                C21,
+                C22,
+                C31,
+                output_channels,
+            )
 
     def get_example_inputs(
         self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1
@@ -352,12 +390,15 @@ class SADLNet(nn.Sequential):
     ) -> None:
         mode = self.training
         self.eval()
+        global model_for_export
+        model_for_export.load_state_dict(self.state_dict())
+        model_for_export.eval()
         torch.onnx.export(
-            self,
+            model_for_export,
             self.get_example_inputs(patch_size, batch_size),
             filename,
             input_names=["in"],
-            dynamic_axes={'in': {2: 'h', 3: 'w'}},
+            dynamic_axes={"in": {2: "h", 3: "w"}},
             opset_version=opset,
             **kwargs,
         )
@@ -399,7 +440,22 @@ class Net(nn.Module):
         self.input_channels = input_channels
         sizes = [len(a) for a in input_channels]
         self.SADL_model = SADLNet(
-            sizes, input_kernels, D1, D2, D3, D4, D5, D6, N, C, C1, C21, C22, C31, 6
+            True,
+            sizes,
+            input_kernels,
+            D1,
+            D2,
+            D3,
+            D4,
+            D5,
+            D6,
+            N,
+            C,
+            C1,
+            C21,
+            C22,
+            C31,
+            6,
         )
         self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest")
 
diff --git a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl b/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl
deleted file mode 100644
index 603e1a23c2..0000000000
--- a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.I.sadl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9145ff03b5252071f06584a4107216e09787fdfe70dfb8753231c64552ad7906
-size 5846439
diff --git a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl b/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl
deleted file mode 100644
index a83fcd7e92..0000000000
--- a/training/training_scripts/NN_Filtering/HOP4/models/HOP_float.II.sadl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6d5f3ed52b88db4d1cdf03433ac3a7b3b275213a340f19d09a140d7cd6239251
-size 5858226
diff --git a/training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl b/training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl
deleted file mode 100644
index d4edd8f6cf..0000000000
--- a/training/training_scripts/NN_Filtering/HOP4/models/HOP_int16.II.sadl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:182fe25e010a7c1767859c76b9b13586e043f2d1b831774358a927ba09e406bb
-size 2959590
--
GitLab
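
The export pattern this patch introduces can be illustrated in isolation. The sketch below is a hypothetical toy module (TinyAttn), not the actual HOP4 network: the training instance keeps the dynamic batch dimension s[0] in its attention reshapes, while a mirror instance built with to_train=False hard-codes batch size 1; the trained weights are copied over with load_state_dict, and the mirror is the model handed to torch.onnx.export, so the exported graph keeps only the spatial axes dynamic.

import torch
from torch import nn


class TinyAttn(nn.Module):
    def __init__(self, to_train: bool, c: int, n_h: int):
        super().__init__()
        self.to_train = to_train
        self.c = c
        self.n_h = n_h
        # single 1x1 conv producing q, k, v stacked along the channel axis
        self.qkv = nn.Conv2d(c, 3 * c, kernel_size=1, bias=False)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=1)
        s = q.shape
        # training keeps the batch dimension dynamic; export freezes it to 1
        # so the ONNX Reshape nodes carry no batch-dependent shape computation
        b = s[0] if self.to_train else 1
        q = q.reshape(b, self.n_h, self.c // self.n_h, -1)
        k = k.reshape(b, self.n_h, self.c // self.n_h, -1)
        v = v.reshape(b, self.n_h, self.c // self.n_h, -1)
        p = torch.matmul(torch.matmul(q, k.transpose(-2, -1)), v)
        return x + p.reshape(s)


train_model = TinyAttn(True, c=8, n_h=2)
export_model = TinyAttn(False, c=8, n_h=2)
export_model.load_state_dict(train_model.state_dict())  # sync trained weights
export_model.eval()
torch.onnx.export(
    export_model,
    torch.randn(1, 8, 16, 16),
    "tiny_attn.onnx",
    input_names=["in"],
    dynamic_axes={"in": {2: "h", 3: "w"}},  # only H and W stay dynamic
    opset_version=13,
)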