Commit 42826ff8 authored by Franck Galpin

HOP4 training support

parent 4f2b3974
-Subproject commit ed737f9b945954c806be8077e43441d7a4a056ee
+Subproject commit e068b52b28d73e70010af57a8df54d8e17f143eb
@@ -37,6 +37,9 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+# ugly global to avoid putting the export model inside the model
+model_for_export = None
+
 
 class Conv(nn.Sequential):
     def __init__(
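
Note on the hunk above: storing the export twin as an attribute (e.g. self.model_for_export) would register it as a submodule, so it would be constructed recursively and its duplicate weights would appear in state_dict(). The module-level global avoids both. A minimal sketch of the pattern, with hypothetical names (Tiny, _export_twin):

from torch import nn

_export_twin = None  # hypothetical stand-in for model_for_export

class Tiny(nn.Module):
    def __init__(self, make_twin=True):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        if make_twin:
            # a global keeps the twin out of the module tree entirely
            global _export_twin
            _export_twin = Tiny(make_twin=False)  # recursion stops here

m = Tiny()
print(list(m.state_dict().keys()))  # ['fc.weight', 'fc.bias'] only; no twin keys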
@@ -159,17 +162,24 @@ class MultiBranchModule(nn.Module):
 
 class Attn(nn.Module):
-    def __init__(self, c):
+    def __init__(self, to_train, c, n_h):
         super(Attn, self).__init__()
-        self.n_h = 2
+        self.n_h = n_h
         self.c = c
+        self.to_train = to_train
         self.conv1_1 = nn.Conv2d(c, c, kernel_size=1, bias=False)
         self.conv1_2 = nn.Conv2d(c, c, kernel_size=1, bias=False)
         self.conv1_3 = nn.Conv2d(c, c, kernel_size=1, bias=False)
-        self.conv3_1 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
-        self.conv3_2 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
-        self.conv3_3 = nn.Conv2d(c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False)
+        self.conv3_1 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
+        self.conv3_2 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
+        self.conv3_3 = nn.Conv2d(
+            c, c, kernel_size=(3, 3), padding=(1, 1), groups=c, bias=False
+        )
         self.conv2 = nn.Conv2d(c, c, kernel_size=1, bias=False)
 
     def forward(self, x):
@@ -177,9 +187,14 @@ class Attn(nn.Module):
         k = self.conv3_2(self.conv1_2(x))
         v = self.conv3_3(self.conv1_3(x))
         s = q.shape
-        q = q.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
-        k = k.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
-        v = v.reshape(s[0], self.n_h, s[1] // self.n_h, -1)
+        if self.to_train:
+            q = q.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+            k = k.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+            v = v.reshape(s[0], self.n_h, self.c // self.n_h, -1)
+        else:
+            q = q.reshape(1, self.n_h, self.c // self.n_h, -1)
+            k = k.reshape(1, self.n_h, self.c // self.n_h, -1)
+            v = v.reshape(1, self.n_h, self.c // self.n_h, -1)
         map = torch.matmul(q, k.transpose(-2, -1))
         p = torch.matmul(map, v)
         p = p.reshape(s)
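
Why this branch exists: in the training path the leading reshape dimension comes from s[0] (the runtime batch size), so ONNX tracing would record a data-dependent Reshape; the export path hard-codes batch 1, leaving only H and W dynamic, matching the dynamic_axes passed to torch.onnx.export further down. Note the head size also switches from s[1] // self.n_h to self.c // self.n_h, replacing a tensor-derived value with a Python constant for the same reason. A small sketch of the equivalence, assuming c=8 channels and n_h=2 heads:

import torch

c, n_h = 8, 2                                 # assumed example values
x = torch.randn(1, c, 16, 16)                 # (N, C, H, W) with N == 1
s = x.shape
q_train = x.reshape(s[0], n_h, c // n_h, -1)  # batch read from the input shape
q_export = x.reshape(1, n_h, c // n_h, -1)    # constant batch: no dynamic axis
assert torch.equal(q_train, q_export)         # identical whenever N == 1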
@@ -188,10 +203,10 @@ class Attn(nn.Module):
 
 
 class TFBlock(nn.Module):
-    def __init__(self, c=64, n_h=2, ext=2.66):
+    def __init__(self, to_train, c=64, n_h=2, ext=2.66):
         super(TFBlock, self).__init__()
-        self.attention = Attn(c)
+        self.attention = Attn(to_train, c, n_h)
 
     def forward(self, x):
         x = x + self.attention(x)
@@ -241,6 +256,7 @@ class ResidualBlock(nn.Sequential):
 class ResidualBlock_TF(nn.Sequential):
     def __init__(
         self,
+        to_train: bool,
         C: int = 64,
         C1: int = 160,
         C21: int = 32,
@@ -272,7 +288,7 @@ class ResidualBlock_TF(nn.Sequential):
                 index=index,
                 groups=2,
             ),
-            TFBlock(C, 2, 2.66),
+            TFBlock(to_train, C, 2, 2.66),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -284,6 +300,7 @@ class SADLNet(nn.Sequential):
     def __init__(
         self,
+        to_train: bool = True,
         input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1],
         input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
         D1: int = 192,
@@ -321,13 +338,34 @@ class SADLNet(nn.Sequential):
             Conv(sum(self.input_features), D6, kernel_size=1),
             Conv(D6, C, kernel_size=3, stride=2),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(7)],
-            ResidualBlock_TF(C, C1, C21, C22, C31, index=7),
+            ResidualBlock_TF(to_train, C, C1, C21, C22, C31, index=7),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(8, 14)],
-            ResidualBlock_TF(C, C1, C21, C22, C31, index=14),
+            ResidualBlock_TF(to_train, C, C1, C21, C22, C31, index=14),
             *[ResidualBlock(C, C1, C21, C22, C31, index=i) for i in range(15, N)],
             Conv(C, C, kernel_size=3),
             Conv(C, output_channels, kernel_size=3, post_activation=None),
         )
+        # model for export: batch size=1 and no dynamic axis on reshape
+        if to_train:
+            global model_for_export
+            model_for_export = SADLNet(
+                False,
+                input_channels,
+                input_kernels,
+                D1,
+                D2,
+                D3,
+                D4,
+                D5,
+                D6,
+                N,
+                C,
+                C1,
+                C21,
+                C22,
+                C31,
+                output_channels,
+            )
 
     def get_example_inputs(
         self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1
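
A note on the guard in the hunk above: if to_train: doubles as a recursion stop, because the twin is constructed with to_train=False and therefore skips this block in its own __init__; exactly one extra instance is built, mirroring every architecture argument but taking the export-friendly reshape path. A hedged usage sketch, assuming the class lives in an importable module (the path hop_model here is hypothetical):

import hop_model  # hypothetical module containing SADLNet and model_for_export

net = hop_model.SADLNet()                      # to_train=True by default: builds the twin once
print(hop_model.model_for_export is not None)  # True: batch-1 twin is ready
twin = hop_model.SADLNet(False)                # to_train=False: no further twin is created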
@@ -352,12 +390,15 @@ class SADLNet(nn.Sequential):
     ) -> None:
         mode = self.training
         self.eval()
+        global model_for_export
+        model_for_export.load_state_dict(self.state_dict())
+        model_for_export.eval()
         torch.onnx.export(
-            self,
+            model_for_export,
             self.get_example_inputs(patch_size, batch_size),
             filename,
             input_names=["in"],
-            dynamic_axes={'in': {2: 'h', 3: 'w'}},
+            dynamic_axes={"in": {2: "h", 3: "w"}},
             opset_version=opset,
             **kwargs,
         )
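
The weight copy above works because the twin has an identical parameter tree: to_train only changes control flow in forward, not the parameters, so load_state_dict matches every name and shape. A standalone sketch of the same export pattern with stand-in modules:

import torch
from torch import nn

train_model = nn.Linear(4, 4)   # stand-in for the trained network
export_model = nn.Linear(4, 4)  # same architecture, export-friendly configuration
export_model.load_state_dict(train_model.state_dict())  # copy weights by name
export_model.eval()
torch.onnx.export(
    export_model,               # trace the twin, not the training model
    (torch.randn(1, 4),),       # example input with batch size 1
    "model.onnx",
    input_names=["in"],
    opset_version=13,           # the real code takes opset as a parameter
)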
@@ -399,7 +440,22 @@ class Net(nn.Module):
         self.input_channels = input_channels
         sizes = [len(a) for a in input_channels]
         self.SADL_model = SADLNet(
-            sizes, input_kernels, D1, D2, D3, D4, D5, D6, N, C, C1, C21, C22, C31, 6
+            True,
+            sizes,
+            input_kernels,
+            D1,
+            D2,
+            D3,
+            D4,
+            D5,
+            D6,
+            N,
+            C,
+            C1,
+            C21,
+            C22,
+            C31,
+            6,
         )
         self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest")
File deleted
File deleted
File deleted