Commit f8c48bd6 by Yue Li (parent 8c25c30a)

Add converter scripts to the training scripts of JVET-AA0111
# Converter script: exports the chroma inter-slice model to NnlfSet1_ChromaCNNFilter_InterSlice.onnx
import torch
import torch.nn as nn
import numpy as np
from net import ConditionalNet

# dummy inputs: luma at full resolution, the remaining planes at chroma (half) resolution
luma = np.ones((1, 1, 100, 100), dtype=np.float32)
yuv = np.ones((1, 2, 50, 50), dtype=np.float32)
pred = np.ones((1, 2, 50, 50), dtype=np.float32)
bs = np.ones((1, 2, 50, 50), dtype=np.float32)
qp = np.ones((1, 1, 50, 50), dtype=np.float32)

# model = nn.DataParallel(ConditionalNet(96, 8))  # if the model was trained on multiple GPUs
model = ConditionalNet(96, 8)  # if the model was trained on a single GPU
state = torch.load('50.ckpt', map_location=torch.device('cpu'))
model.load_state_dict(state)
model.eval()  # export in eval mode

dummy_input = (torch.from_numpy(luma), torch.from_numpy(yuv), torch.from_numpy(pred),
               torch.from_numpy(bs), torch.from_numpy(qp))
# export `model` directly; use `model.module` only when the model is wrapped in nn.DataParallel
torch.onnx.export(model, dummy_input, "NnlfSet1_ChromaCNNFilter_InterSlice.onnx")
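After exporting, it is worth checking that the ONNX graph reproduces the PyTorch output. A minimal verification sketch using onnxruntime (not part of the original scripts; it assumes the arrays and `model` from the converter above are still in scope, and that the graph inputs keep the forward-argument order):

# Hedged sketch: compare the exported ONNX model against the PyTorch model.
import onnxruntime as ort

sess = ort.InferenceSession("NnlfSet1_ChromaCNNFilter_InterSlice.onnx")
# read the auto-assigned input names back from the session instead of guessing them
feed = {inp.name: arr for inp, arr in zip(sess.get_inputs(), [luma, yuv, pred, bs, qp])}
onnx_out = sess.run(None, feed)[0]

with torch.no_grad():
    torch_out = model(*dummy_input).numpy()

# small numerical drift between backends is normal; exact equality is not expected
print("max abs diff:", np.abs(onnx_out - torch_out).max())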
# net.py: network definition imported by the chroma inter-slice converter
import torch
import torch.nn as nn


def conv3x3(in_channels, out_channels, stride=1, padding=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, padding=padding)


# Conv3x3 + PReLU
class conv3x3_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv3x3_f, self).__init__()
        self.conv = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Conv1x1 + PReLU
class conv1x1_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv1x1_f, self).__init__()
        self.conv = conv1x1(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Residual Block: two 3x3 convs with a PReLU in between and an identity skip
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out += residual
        return out


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


# Channel Gate (defined here but not used by ConditionalNet below)
class ChannelGate(nn.Module):
    def __init__(self, channels):
        super(ChannelGate, self).__init__()
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(1, channels),
            nn.PReLU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


# Spatial Gate (defined here but not used by ConditionalNet below)
class SpatialGate(nn.Module):
    def __init__(self, in_channels, num_features):
        super(SpatialGate, self).__init__()
        self.conv1 = conv3x3(in_channels, num_features, stride=2)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(num_features, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out


# Chroma in-loop filter, inter slice: f feature maps, rbn residual blocks
class ConditionalNet(nn.Module):
    def __init__(self, f, rbn):
        super(ConditionalNet, self).__init__()
        self.rbn = rbn
        # luma is at twice the chroma resolution, so its feature conv is strided
        self.convLuma = conv3x3_f(1, f, 2)
        self.convRec = conv3x3_f(2, f)
        self.convPred = conv3x3_f(2, f)
        self.convBs = conv3x3_f(2, f)
        self.convQp = conv3x3_f(1, f)
        self.fuse = conv1x1_f(5 * f, f)
        self.transitionH = conv3x3_f(f, f, 2)
        self.backbone = nn.ModuleList([ResidualBlock(f, f)])
        for _ in range(self.rbn - 1):
            self.backbone.append(ResidualBlock(f, f))
        self.last_layer = nn.Sequential(
            nn.Conv2d(in_channels=f, out_channels=f, kernel_size=3,
                      stride=1, padding=1),
            nn.PReLU(),
            nn.Conv2d(in_channels=f, out_channels=8, kernel_size=3,
                      stride=1, padding=1),
        )

    def forward(self, luma, rec, pred, bs, qp):
        # per-input feature extraction, then fusion and a strided transition
        luma_f = self.convLuma(luma)
        rec_f = self.convRec(rec)
        pred_f = self.convPred(pred)
        bs_f = self.convBs(bs)
        qp_f = self.convQp(qp)
        xh = torch.cat((luma_f, rec_f, pred_f, bs_f, qp_f), 1)
        xh = self.fuse(xh)
        x = self.transitionH(xh)
        for i in range(self.rbn):
            x = self.backbone[i](x)
        # output
        x = self.last_layer(x)
        return x
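A quick shape check makes the geometry concrete: the strided luma conv brings luma down to chroma resolution, transitionH halves it again, and the final layer emits 8 channels at half chroma resolution, consistent with a 2x2 sub-grid per chroma component (the re-interleaving step itself is not shown in these scripts). A minimal sketch with randomly initialized weights:

# Hedged sanity check of ConditionalNet's tensor geometry (random weights).
import torch
from net import ConditionalNet

net = ConditionalNet(96, 8).eval()
luma = torch.ones(1, 1, 100, 100)   # luma plane, 2x chroma resolution
rec = torch.ones(1, 2, 50, 50)      # reconstructed chroma (Cb, Cr)
pred = torch.ones(1, 2, 50, 50)     # chroma prediction
bs = torch.ones(1, 2, 50, 50)       # boundary strength
qp = torch.ones(1, 1, 50, 50)       # quantization parameter plane

with torch.no_grad():
    out = net(luma, rec, pred, bs, qp)
print(out.shape)  # torch.Size([1, 8, 25, 25]): 8 channels at half chroma resolution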
# Converter script: exports the chroma intra-slice model (adds the partitioning input `split`)
import torch
import torch.nn as nn
import numpy as np
from net import ConditionalNet

luma = np.ones((1, 1, 100, 100), dtype=np.float32)
yuv = np.ones((1, 2, 50, 50), dtype=np.float32)
pred = np.ones((1, 2, 50, 50), dtype=np.float32)
split = np.ones((1, 2, 50, 50), dtype=np.float32)
bs = np.ones((1, 2, 50, 50), dtype=np.float32)
qp = np.ones((1, 1, 50, 50), dtype=np.float32)

# model = nn.DataParallel(ConditionalNet(96, 8))  # if the model was trained on multiple GPUs
model = ConditionalNet(96, 8)  # if the model was trained on a single GPU
state = torch.load('50.ckpt', map_location=torch.device('cpu'))
model.load_state_dict(state)
model.eval()  # export in eval mode

dummy_input = (torch.from_numpy(luma), torch.from_numpy(yuv), torch.from_numpy(pred),
               torch.from_numpy(split), torch.from_numpy(bs), torch.from_numpy(qp))
# export `model` directly; use `model.module` only when the model is wrapped in nn.DataParallel
torch.onnx.export(model, dummy_input, "NnlfSet1_ChromaCNNFilter_IntraSlice.onnx")
# net.py: network definition imported by the chroma intra-slice converter
import torch
import torch.nn as nn


def conv3x3(in_channels, out_channels, stride=1, padding=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, padding=padding)


# Conv3x3 + PReLU
class conv3x3_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv3x3_f, self).__init__()
        self.conv = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Conv1x1 + PReLU
class conv1x1_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv1x1_f, self).__init__()
        self.conv = conv1x1(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        out += residual
        return out


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


# Channel Gate
class ChannelGate(nn.Module):
    def __init__(self, channels):
        super(ChannelGate, self).__init__()
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(1, channels),
            nn.PReLU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


class SpatialGate(nn.Module):
    def __init__(self, in_channels, num_features):
        super(SpatialGate, self).__init__()
        self.conv1 = conv3x3(in_channels, num_features, stride=2)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(num_features, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out


# Chroma in-loop filter, intra slice: adds a partitioning (split) input branch
class ConditionalNet(nn.Module):
    def __init__(self, f, rbn):
        super(ConditionalNet, self).__init__()
        self.rbn = rbn
        self.convLuma = conv3x3_f(1, f, 2)
        self.convRec = conv3x3_f(2, f)
        self.convPred = conv3x3_f(2, f)
        self.convSplit = conv3x3_f(2, f)
        self.convBs = conv3x3_f(2, f)
        self.convQp = conv3x3_f(1, f)
        self.fuse = conv1x1_f(6 * f, f)
        self.transitionH = conv3x3_f(f, f, 2)
        self.backbone = nn.ModuleList([ResidualBlock(f, f)])
        for _ in range(self.rbn - 1):
            self.backbone.append(ResidualBlock(f, f))
        self.last_layer = nn.Sequential(
            nn.Conv2d(in_channels=f, out_channels=f, kernel_size=3,
                      stride=1, padding=1),
            nn.PReLU(),
            nn.Conv2d(in_channels=f, out_channels=8, kernel_size=3,
                      stride=1, padding=1),
        )

    def forward(self, luma, rec, pred, split, bs, qp):
        luma_f = self.convLuma(luma)
        rec_f = self.convRec(rec)
        pred_f = self.convPred(pred)
        split_f = self.convSplit(split)
        bs_f = self.convBs(bs)
        qp_f = self.convQp(qp)
        xh = torch.cat((luma_f, rec_f, pred_f, split_f, bs_f, qp_f), 1)
        xh = self.fuse(xh)
        x = self.transitionH(xh)
        for i in range(self.rbn):
            x = self.backbone[i](x)
        # output
        x = self.last_layer(x)
        return x
# Converter script: exports the luma inter-slice model to NnlfSet1_LumaCNNFilter_InterSlice.onnx
import torch
import torch.nn as nn
import numpy as np
from net import ConditionalNet

# input
yuv = np.ones((1, 1, 32, 32), dtype=np.float32)
pred = np.ones((1, 1, 32, 32), dtype=np.float32)
bs = np.ones((1, 1, 32, 32), dtype=np.float32)
qp = np.ones((1, 1, 32, 32), dtype=np.float32)

# model
# model = nn.DataParallel(ConditionalNet(96, 8))  # if the model was trained on multiple GPUs
model = ConditionalNet(96, 8)  # if the model was trained on a single GPU
state = torch.load('50.ckpt', map_location=torch.device('cpu'))
model.load_state_dict(state)
model.eval()  # export in eval mode

dummy_input = (torch.from_numpy(yuv), torch.from_numpy(pred),
               torch.from_numpy(bs), torch.from_numpy(qp))
# export `model` directly; use `model.module` only when the model is wrapped in nn.DataParallel
torch.onnx.export(model, dummy_input, "NnlfSet1_LumaCNNFilter_InterSlice.onnx")
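The two commented model lines above matter because nn.DataParallel registers every parameter under a "module." prefix, so a multi-GPU checkpoint will not load into a plain ConditionalNet. A common workaround (a hedged sketch, not part of the original scripts) is to strip the prefix from the checkpoint instead of re-wrapping the model:

# Hedged sketch: load a checkpoint saved from nn.DataParallel into a plain model
# by stripping the "module." prefix that DataParallel adds to every key.
import torch
from net import ConditionalNet

model = ConditionalNet(96, 8)
state = torch.load('50.ckpt', map_location=torch.device('cpu'))
state = {k[len('module.'):] if k.startswith('module.') else k: v
         for k, v in state.items()}
model.load_state_dict(state)  # now torch.onnx.export(model, ...) works directly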
# net.py: network definition imported by the luma inter-slice converter
import torch
import torch.nn as nn


def conv3x3(in_channels, out_channels, stride=1, padding=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, padding=padding)


# Conv3x3 + PReLU
class conv3x3_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv3x3_f, self).__init__()
        self.conv = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Conv1x1 + PReLU
class conv1x1_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv1x1_f, self).__init__()
        self.conv = conv1x1(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Residual Block: two 3x3 convs; unlike the chroma variant there is no internal
# skip here -- the skip connection is applied in ConditionalNet.forward together
# with the attention mask
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


# Channel Gate (defined here but not used by ConditionalNet below)
class ChannelGate(nn.Module):
    def __init__(self, channels):
        super(ChannelGate, self).__init__()
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(1, channels),
            nn.PReLU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


# Spatial Gate: predicts a one-channel attention map at half resolution
# (stride-2 conv), matching the backbone's feature-map size
class SpatialGate(nn.Module):
    def __init__(self, in_channels, num_features):
        super(SpatialGate, self).__init__()
        self.conv1 = conv3x3(in_channels, num_features, stride=2)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(num_features, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out


# Luma in-loop filter, inter slice: f feature maps, rbn residual blocks
class ConditionalNet(nn.Module):
    def __init__(self, f, rbn):
        super(ConditionalNet, self).__init__()
        self.rbn = rbn
        self.convRec = conv3x3_f(1, f)
        self.convPred = conv3x3_f(1, f)
        self.convBs = conv3x3_f(1, f)
        self.convQp = conv3x3_f(1, f)
        self.fuse = conv1x1_f(4 * f, f)
        self.transitionH = conv3x3_f(f, f, 2)
        self.backbone = nn.ModuleList([ResidualBlock(f, f)])
        for _ in range(self.rbn - 1):
            self.backbone.append(ResidualBlock(f, f))
        # one spatial attention mask per residual block, fed by the raw inputs
        self.mask = nn.ModuleList([SpatialGate(4, 32)])
        for _ in range(self.rbn - 1):
            self.mask.append(SpatialGate(4, 32))
        self.last_layer = nn.Sequential(
            nn.Conv2d(in_channels=f, out_channels=f, kernel_size=3,
                      stride=1, padding=1),
            nn.PReLU(),
            nn.Conv2d(in_channels=f, out_channels=4, kernel_size=3,
                      stride=1, padding=1),
        )

    def forward(self, rec, pred, bs, qp):
        rec_f = self.convRec(rec)
        pred_f = self.convPred(pred)
        bs_f = self.convBs(bs)
        qp_f = self.convQp(qp)
        xh = torch.cat((rec_f, pred_f, bs_f, qp_f), 1)
        xh = self.fuse(xh)
        x = self.transitionH(xh)
        for i in range(self.rbn):
            x_resi = self.backbone[i](x)
            attention = self.mask[i](torch.cat((rec, pred, bs, qp), 1))
            # attention-weighted residual + plain residual + skip connection
            x = attention.expand_as(x_resi) * x_resi + x_resi + x
        # output
        x = self.last_layer(x)
        return x
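The luma network returns 4 channels at half resolution (16x16 for a 32x32 input), which matches a 2x2 pixel-shuffle layout of a single full-resolution plane. If that is indeed how the output is re-interleaved (an assumption here, not something these scripts show), the reconstruction step would look like:

# Hedged sketch: re-interleave the 4 half-resolution output channels into one
# full-resolution luma plane via pixel shuffle (assumed layout, not shown above).
import torch
import torch.nn.functional as F
from net import ConditionalNet

net = ConditionalNet(96, 8).eval()
rec = pred = bs = qp = torch.ones(1, 1, 32, 32)
with torch.no_grad():
    out = net(rec, pred, bs, qp)   # (1, 4, 16, 16)
plane = F.pixel_shuffle(out, 2)    # (1, 1, 32, 32)
print(out.shape, plane.shape)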
# Converter script: exports the luma intra-slice model (adds the partitioning input `split`)
import torch
import torch.nn as nn
import numpy as np
from net import ConditionalNet

# input
yuv = np.ones((1, 1, 32, 32), dtype=np.float32)
pred = np.ones((1, 1, 32, 32), dtype=np.float32)
split = np.ones((1, 1, 32, 32), dtype=np.float32)
bs = np.ones((1, 1, 32, 32), dtype=np.float32)
qp = np.ones((1, 1, 32, 32), dtype=np.float32)

# model = nn.DataParallel(ConditionalNet(96, 8))  # if the model was trained on multiple GPUs
model = ConditionalNet(96, 8)  # if the model was trained on a single GPU
state = torch.load('50.ckpt', map_location=torch.device('cpu'))
model.load_state_dict(state)
model.eval()  # export in eval mode

dummy_input = (torch.from_numpy(yuv), torch.from_numpy(pred), torch.from_numpy(split),
               torch.from_numpy(bs), torch.from_numpy(qp))
# export `model` directly; use `model.module` only when the model is wrapped in nn.DataParallel
torch.onnx.export(model, dummy_input, "NnlfSet1_LumaCNNFilter_IntraSlice.onnx")
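All four converters call torch.onnx.export with positional defaults, so the input names and opset are whatever the installed PyTorch assigns. A hedged variant that pins these down (optional; the names chosen here are illustrative, not mandated by the original scripts):

# Hedged sketch: the same export with explicit (illustrative) input names and a
# pinned opset, so the ONNX graph stays stable across PyTorch versions.
torch.onnx.export(
    model,
    dummy_input,
    "NnlfSet1_LumaCNNFilter_IntraSlice.onnx",
    input_names=["rec", "pred", "split", "bs", "qp"],  # illustrative names
    output_names=["out"],                              # illustrative name
    opset_version=11,  # assumption; pick whatever the consuming tool supports
)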
# net.py: network definition imported by the luma intra-slice converter
import torch
import torch.nn as nn


def conv3x3(in_channels, out_channels, stride=1, padding=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, padding=padding)


# Conv3x3 + PReLU
class conv3x3_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv3x3_f, self).__init__()
        self.conv = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Conv1x1 + PReLU
class conv1x1_f(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(conv1x1_f, self).__init__()
        self.conv = conv1x1(in_channels, out_channels, stride)
        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# Residual Block (no internal skip; applied externally with the attention mask)
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


# Channel Gate
class ChannelGate(nn.Module):
    def __init__(self, channels):
        super(ChannelGate, self).__init__()
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(1, channels),
            nn.PReLU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


class SpatialGate(nn.Module):
    def __init__(self, in_channels, num_features):
        super(SpatialGate, self).__init__()
        self.conv1 = conv3x3(in_channels, num_features, stride=2)
        self.relu = nn.PReLU()
        self.conv2 = conv3x3(num_features, 1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out


# Luma in-loop filter, intra slice: adds a partitioning (split) input branch
class ConditionalNet(nn.Module):
    def __init__(self, f, rbn):
        super(ConditionalNet, self).__init__()
        self.rbn = rbn
        self.convRec = conv3x3_f(1, f)
        self.convPred = conv3x3_f(1, f)
        self.convSplit = conv3x3_f(1, f)
        self.convBs = conv3x3_f(1, f)
        self.convQp = conv3x3_f(1, f)
        self.fuse = conv1x1_f(5 * f, f)
        self.transitionH = conv3x3_f(f, f, 2)
        self.backbone = nn.ModuleList([ResidualBlock(f, f)])
        for _ in range(self.rbn - 1):
            self.backbone.append(ResidualBlock(f, f))
        self.mask = nn.ModuleList([SpatialGate(5, 32)])
        for _ in range(self.rbn - 1):
            self.mask.append(SpatialGate(5, 32))
        self.last_layer = nn.Sequential(
            nn.Conv2d(in_channels=f, out_channels=f, kernel_size=3,
                      stride=1, padding=1),
            nn.PReLU(),
            nn.Conv2d(in_channels=f, out_channels=4, kernel_size=3,
                      stride=1, padding=1),
        )

    def forward(self, rec, pred, split, bs, qp):
        rec_f = self.convRec(rec)
        pred_f = self.convPred(pred)
        split_f = self.convSplit(split)
        bs_f = self.convBs(bs)
        qp_f = self.convQp(qp)
        xh = torch.cat((rec_f, pred_f, split_f, bs_f, qp_f), 1)
        xh = self.fuse(xh)
        x = self.transitionH(xh)
        for i in range(self.rbn):
            x_resi = self.backbone[i](x)
            attention = self.mask[i](torch.cat((rec, pred, split, bs, qp), 1))
            x = attention.expand_as(x_resi) * x_resi + x_resi + x
        # output
        x = self.last_layer(x)
        return x
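Since the four variants differ only in their input branches (and the chroma networks drop the attention masks), a quick parameter count makes the differences concrete. A minimal sketch; the printed number depends on which net.py is on the import path, and nothing is hard-coded:

# Hedged sketch: count trainable parameters of whichever ConditionalNet variant
# is currently importable as `net` (each variant lives in its own net.py).
from net import ConditionalNet

net = ConditionalNet(96, 8)
total = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"ConditionalNet(96, 8): {total:,} trainable parameters")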