diff --git a/models/nnlf_lop5_model_float.sadl b/models/nnlf_lop5_model_float.sadl new file mode 100644 index 0000000000000000000000000000000000000000..7338bb31c92fd27414e257783c9e41c2f9083687 --- /dev/null +++ b/models/nnlf_lop5_model_float.sadl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a830f8e22259a1b2f68e6817c03a9d0c3f1d9dbf22230f5f31186568f461f0e4 +size 1024289 diff --git a/models/nnlf_lop5_model_int16.sadl b/models/nnlf_lop5_model_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..02b4e5304a7c1037b244e2bcf3ac6d64090c34dd --- /dev/null +++ b/models/nnlf_lop5_model_int16.sadl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7a70d4923bcfb61cc71d9eaf9a2107bd6eb618eb2f7dcf03070a6c0f2264063 +size 531952 diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp index f9ce7af468204e580ac1ece1783bc8e5628618cc..b54300782908a5eb14406ecf814ad585869982a7 100644 --- a/source/App/DecoderApp/DecAppCfg.cpp +++ b/source/App/DecoderApp/DecAppCfg.cpp @@ -81,7 +81,7 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) ("DumpBasename", m_dumpBasename, string(""), "basename for data dumping\n") #endif #if NN_LF_UNIFIED - ("NnlfModelName", m_nnModel[NnModel::LOP_UNIFIED_FILTER], string("models/nnlf_lop4_model_int16.sadl"), "loop filter model name\n") + ("NnlfModelName", m_nnModel[NnModel::LOP_UNIFIED_FILTER], string("models/nnlf_lop5_model_int16.sadl"), "loop filter model name\n") #if NN_LF_FORCE_USE ( "NnlfUnifiedDebugOption", m_nnlfDebugOption, 0, "Option used to debug stage 1 model. 0: default, 1: apply only on I slice, 2: apply on all slices using I type as input" ) #endif diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp index 8c66a7d9887011770bca76cfd76ba4f419878a6d..1088b1fec508685761732a9899d09b2b3b6efb3e 100644 --- a/source/App/EncoderApp/EncAppCfg.cpp +++ b/source/App/EncoderApp/EncAppCfg.cpp @@ -1204,7 +1204,7 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) #else ("NnlfOption", m_nnlfOption, 1, "NN-based in-loop filter option (0:disable nnlf, 1: enable unified nnlf, [legacy: 10: enable nnlf-0, 11: enable nnlf-1, 12: enable nnlf-LC])") #endif - ("NnlfModelName", m_nnModel[NnModel::LOP_UNIFIED_FILTER], string("models/nnlf_lop4_model_int16.sadl"), "unified nnlf model name." ) + ("NnlfModelName", m_nnModel[NnModel::LOP_UNIFIED_FILTER], string("models/nnlf_lop5_model_int16.sadl"), "unified nnlf model name." 
) #endif #if NN_LF_UNIFIED ( "NnlfBlockSize", m_nnlfBlockSize, 128u, "Base inference size of NN-based in-loop filter") diff --git a/training/training_scripts/NN_Filtering/LOP5/config.json b/training/training_scripts/NN_Filtering/LOP5/config.json new file mode 100644 index 0000000000000000000000000000000000000000..440a7d76a72c6c62f6536c97938935d9c2c9d2ae --- /dev/null +++ b/training/training_scripts/NN_Filtering/LOP5/config.json @@ -0,0 +1,112 @@ +{ + "stage1": { + "training": { + "mse_epoch": 126, + "max_epochs": 130, + "component_loss_weightings": [ + 6, + 2 + ], + "dataloader": { + "batch_size": 128 + }, + "optimizer": { + "lr": 0.0004 + }, + "lr_scheduler": { + "milestones": [ + 120, + 123, + 126 + ] + }, + "dct_size": 2 + } + }, + + + "stage2": { + "encdec_bvi": { + "vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]" + }, + "encdec_bvi_valid": { + "vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]" + }, + "encdec_tvd": { + "vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]" + }, + "encdec_tvd_valid": { + "vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]" + }, + "training": { + "mse_epoch": 58, + "max_epochs": 60, + "component_loss_weightings": [ + 6, + 2 + ], + "dataloader": { + "batch_size": 128 + }, + "optimizer": { + "lr": 0.0004 + }, + "lr_scheduler": { + "milestones": [ + 55, + 57, + 58 + ] + }, + "dct_size": 2 + } + }, + + + "stage3": { + "encdec_bvi": { + "vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]" + }, + "encdec_bvi_valid": { + "vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]" + }, + "encdec_tvd": { + "vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]" + }, + "encdec_tvd_valid": { + "vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]", + "vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]" + }, + "training": { + "mse_epoch": 58, + "max_epochs": 60, + "component_loss_weightings": [ + 12, + 2 + ], + "dataloader": { + "batch_size": 64 + }, + "optimizer": { + "lr": 0.0002 + }, + "lr_scheduler": { + "milestones": [ + 41, + 51, + 55 + ] + }, + "dct_size": 2 + } + } + + +} diff --git a/training/training_scripts/NN_Filtering/LOP5/model/model.json b/training/training_scripts/NN_Filtering/LOP5/model/model.json new file mode 100644 index 0000000000000000000000000000000000000000..cbed570ca5813b71490f07cf7e2e10e9d3dd0d8a --- /dev/null +++ b/training/training_scripts/NN_Filtering/LOP5/model/model.json @@ -0,0 +1,45 @@ +{ "model" : { + "path": "../model/model.py", + "class" : "Net", + "input_channels" : [ + [ + "rec_before_dbf_Y", + "rec_before_dbf_U", + "rec_before_dbf_V" + ], + [ + "pred_Y", + "pred_U", + "pred_V" + ], + [ + "bs_Y", + "bs_U", + "bs_V" + ], + [ "qp_base" ], + [ "qp_slice" ], + [ 
"ipb_Y" ] + ], + "input_kernels" : [ + 3, + 3, + 1, + 1, + 1, + 1 + ], + "D1" : 16, + "D2" : 8, + "D3" : 4, + "D4" : 2, + "D5" : 2, + "D6" : 64, + "N_Y" : [2, 2, 2, 3], + "N_UV" :5, + "C" : 32, + "C1_Y" : 176, + "C1_UV" : 80, + "C21" : 32, + "dct_ch" : 4 + } } diff --git a/training/training_scripts/NN_Filtering/LOP5/model/model.py b/training/training_scripts/NN_Filtering/LOP5/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4136ebdbb9a11dd22d7e7f23ded036328de383 --- /dev/null +++ b/training/training_scripts/NN_Filtering/LOP5/model/model.py @@ -0,0 +1,730 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +from typing import Union, Tuple, Optional, Type, Iterable, List, Dict + +import torch +from torch import nn +from torch.nn import functional as F + + +class Conv_dw(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = None, + is_separable: bool = False, + hidden_separable_channels: Optional[int] = None, + post_activation: Optional[Type] = nn.PReLU, + **kwargs, + ): + """ + Args: + in_channels: the number of input channels + out_channels: the number of output channels + kernel_size: the convolution's kernel size + stride: the convolution's stride(s) + padding: the convolution's padding + is_separable: whether to implement convolution separably + hidden_separable_channels: If is_separable, the number of hidden channels between convolutions. If None, use out_channels + post_activation: activation function to use after convolution. 
If None, no activation after convolution + **kwargs: additional kwargs to pass to nn.Conv2d + """ + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + ) + self.stride = (stride, stride) if isinstance(stride, int) else stride + if padding is not None: + self.padding = (padding, padding) if isinstance(padding, int) else padding + else: + self.padding = tuple([k // 2 for k in self.kernel_size]) + self.is_separable = is_separable + self.post_activation = post_activation + + if self.is_separable: + self.hidden_separable_channels = hidden_separable_channels or out_channels + modules = [ + nn.Conv2d( + self.in_channels, + self.hidden_separable_channels, + (self.kernel_size[0], self.kernel_size[1]), + (self.stride[0], self.stride[1]), + (self.padding[0], self.padding[1]), + groups=self.hidden_separable_channels, + **kwargs, + ) + ] + else: + modules = [ + nn.Conv2d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + **kwargs, + ) + ] + + if self.post_activation is not None: + modules.append(self.post_activation()) + + super(Conv_dw, self).__init__(*modules) + + +class Conv(nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Optional[Union[int, Tuple[int, int]]] = None, + is_separable: bool = False, + is_horizontal: bool = True, + hidden_separable_channels: Optional[int] = None, + post_activation: Optional[Type] = nn.PReLU, + **kwargs, + ): + """ + Args: + in_channels: the number of input channels + out_channels: the number of output channels + kernel_size: the convolution's kernel size + stride: the convolution's stride(s) + padding: the convolution's padding + is_separable: whether to implement convolution separably + hidden_separable_channels: If is_separable, the number of hidden channels between convolutions. If None, use out_channels + post_activation: activation function to use after convolution. 
If None, no activation after convolution + **kwargs: additional kwargs to pass to nn.Conv2d + """ + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = ( + (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + ) + self.stride = (stride, stride) if isinstance(stride, int) else stride + if padding is not None: + self.padding = (padding, padding) if isinstance(padding, int) else padding + else: + self.padding = tuple([k // 2 for k in self.kernel_size]) + self.is_separable = is_separable + self.is_horizontal = is_horizontal + self.post_activation = post_activation + + if self.is_separable: + self.hidden_separable_channels = hidden_separable_channels or out_channels + if self.is_horizontal: + modules = [ + nn.Conv2d( + self.in_channels, + self.hidden_separable_channels, + (self.kernel_size[0], 1), + (self.stride[0], 1), + (self.padding[0], 0), + groups=self.hidden_separable_channels, + **kwargs, + ) + ] + else: + modules = [ + nn.Conv2d( + self.hidden_separable_channels, + self.out_channels, + (1, self.kernel_size[1]), + (1, self.stride[1]), + (0, self.padding[1]), + groups=self.hidden_separable_channels, + **kwargs, + ) + ] + else: + modules = [ + nn.Conv2d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + **kwargs, + ) + ] + + if self.post_activation is not None: + modules.append(self.post_activation()) + + super(Conv, self).__init__(*modules) + + +class LowerBoundFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs, bound): + bound = torch.ones(inputs.size(), device=inputs.device) * bound + bound = bound.to(inputs.device) + bound = bound.type(inputs.dtype) + ctx.save_for_backward(inputs, bound) + return torch.max(inputs, bound) + + @staticmethod + def backward(ctx, grad_output): + inputs, bound = ctx.saved_tensors + + pass_through_1 = inputs >= bound + pass_through_2 = grad_output < 0 + + pass_through = pass_through_1 | pass_through_2 + return pass_through.type(grad_output.dtype) * grad_output, None + + +def clamp(x, lower, upper): + return LowerBoundFunction.apply(-LowerBoundFunction.apply(-x, -upper), lower) + + +class HardSigmoid(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_parameter('_scalar', torch.nn.Parameter(torch.Tensor([1.0 / 6.0]), requires_grad=False)) + self.register_parameter('_bias', torch.nn.Parameter(torch.Tensor([0.5]), requires_grad=False)) + self.register_parameter('_min', torch.nn.Parameter(torch.zeros(1), requires_grad=False)) + self.register_parameter('_max', torch.nn.Parameter(torch.ones(1), requires_grad=False)) + + def forward(self, x): + if self.training: + return clamp(x / 6 + 0.5, 0, 1) + else: + return torch.max(torch.min(x * self._scalar + self._bias, self._max), self._min) + + +class ConvLayer(nn.Module): + def __init__( + self, + in_ch, + out_ch, + kernel_size, + stride, + groups=1, + padding=None, + bias=True, + *args, + **kwargs, + ): + super().__init__() + + self.conv2d = nn.Conv2d( + in_ch, + out_ch, + kernel_size, + stride, + padding=padding if padding is not None else kernel_size // 2, + groups=groups, + padding_mode="zeros", + bias=bias, + *args, + **kwargs, + ) + if self.conv2d.bias is not None: + self.conv2d.bias.data.fill_(0.0) + + def forward(self, x): + out = self.conv2d(x) + return out + + +class DepthConvBlock(nn.Module): + def __init__( + self, + in_ch, + out_ch, + kernel_size=3, + stride=1, + ffn_expansion=1, + act_first=False, + act_last=True, + *args, + 
**kwargs, + ): + super().__init__() + + self.in_ch = in_ch + self.out_ch = out_ch + + dw_ch = int(self.out_ch * ffn_expansion) + + self.conv1 = nn.Sequential( + ConvLayer( + dw_ch, + dw_ch, + kernel_size, + stride, + groups=dw_ch, + ), + ConvLayer(dw_ch, self.out_ch, 1, 1), + (nn.Identity() if not act_last else nn.PReLU(self.out_ch, init=0.2)), + ) + + if (self.in_ch != self.out_ch) and stride == 1: + self.skip = ConvLayer(self.in_ch, self.out_ch, 1, 1) + elif stride == 2: + self.skip = nn.Sequential( + ConvLayer(self.in_ch, self.out_ch, 1, 1), + ConvLayer( + self.out_ch, + self.out_ch, + 3, + 2, + groups=self.out_ch, + ), + ) + else: + self.skip = nn.Identity() + + def forward(self, x): + x = self.conv1(x) + self.skip(x) + return x + + +class SpatialAttention(nn.Module): + + def __init__( + self, + in_ch, + esa_ch, + ffn_expansion=None, + mask_type="sigmoid", + ): + super().__init__() + + assert mask_type in ["rescale", "sigmoid", "hard_sigmoid"] + + self.head_branch = nn.Sequential( + ConvLayer(in_ch, esa_ch, 1, 1), + nn.PReLU(esa_ch, init=0.2), + ) + self.pool_branch_one = DepthConvBlock( + esa_ch, + esa_ch, + 3, + 1, + ffn_expansion=ffn_expansion, + ) + self.tail_branch = nn.Sequential( + ConvLayer(esa_ch, esa_ch, 1, 1), + nn.PReLU(esa_ch, init=0.2), + ) + self.out = ConvLayer(esa_ch, in_ch, 1, 1) + + self.mask_type = mask_type + + if self.mask_type == "rescale": + self.mask_scale = nn.Parameter(torch.empty(1, in_ch, 1, 1)) + self.mask_shift = nn.Parameter(torch.empty(1, in_ch, 1, 1)) + nn.init.constant_(self.mask_scale, 0.1) + nn.init.constant_(self.mask_shift, 0.5) + elif self.mask_type == "hard_sigmoid": + self.hard_sigmoid = HardSigmoid() + self.mask_loss = 0.0 + + def forward(self, x): + main_branch = self.head_branch(x) + pool_branch = F.max_pool2d(main_branch, kernel_size=2, stride=2) + pool_branch = self.pool_branch_one(pool_branch) + for _ in range(1): + pool_branch = F.interpolate( + pool_branch, + size=None, + scale_factor=2, + mode="bilinear", + align_corners=False, + ) + main_branch = self.tail_branch(main_branch) + m = self.out(pool_branch + main_branch) + if self.mask_type == "sigmoid": + m = 0.5 * (1 + (m / (1 + torch.abs(x)))) + elif self.mask_type == "rescale": + m = self.mask_scale * m + self.mask_shift + self.mask_loss = self.calculate_mask_loss(m) + elif self.mask_type == "hard_sigmoid": + m = self.hard_sigmoid(m) + else: + raise NotImplementedError + return x * m + + def calculate_mask_loss(self, m): + m_valid = torch.ones_like(m) + m_valid[(m <= 1.0) & (m >= 0.0)] = 0 + m = m * m_valid + return torch.mean(m**2) + + +class MultiBranchModule(nn.Module): + """A module representing multple, parallel branches. If the input is a list, each element in the list is fed into the corresponding branch, + otherwise the input is fed into every branch. 
The outputs of each branch are then merged.""" + + def __init__(self, *branch_modules, merge_dimension: int = -3): + """ + Args: + branch_modules: modules to run in parallel + merge_dimension: the dimension to merge outputs from each branch + """ + super().__init__() + self.branches = nn.ModuleList(branch_modules) + self.merge_dimension = merge_dimension + + def forward(self, args: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor: + inputs = args if isinstance(args, list) else len(self.branches) * [args] + branch_outputs = [branch(input) for branch, input in zip(self.branches, inputs)] + return torch.cat(branch_outputs, dim=self.merge_dimension) + + +class NewResBlock_separate_prelu(nn.Sequential): + def __init__(self, C: int = 64, C1: int = 160, C21: int = 32): + super().__init__() + self.prelu = nn.PReLU() + self.conv1_11 = Conv(C1, C, kernel_size=1, post_activation=None) + self.conv2_13 = Conv(C, C, kernel_size=3, post_activation=None, is_separable=True, is_horizontal=True, hidden_separable_channels=C21) + self.conv3_31 = Conv(C, C, kernel_size=3, post_activation=None, is_separable=True, is_horizontal=False, hidden_separable_channels=C21) + self.conv4_11 = Conv(C, C1, kernel_size=1, post_activation=None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + temp = x + x1 = self.prelu(x) + x2 = self.conv1_11(x1) + x3 = self.conv2_13(x2) + x4 = self.conv3_31(x3) + x5 = self.conv4_11(x4) + return x5 + temp + + +class NewResBlock_separate_prelu_crop(nn.Sequential): + def __init__(self, C: int = 64, C1: int = 160, C21: int = 32): + super().__init__() + self.prelu = nn.PReLU() + self.conv1_11 = Conv(C1, C, kernel_size=1, post_activation=None) + self.conv2_13 = Conv(C, C, kernel_size=3, padding=(0, 0), post_activation=None, is_separable=True, is_horizontal=True, hidden_separable_channels=C21) + self.conv3_31 = Conv(C, C, kernel_size=3, padding=(0, 0), post_activation=None, is_separable=True, is_horizontal=False, hidden_separable_channels=C21) + self.conv4_11 = Conv(C, C1, kernel_size=1, post_activation=None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + temp = x[:, :, 1:-1, 1:-1] + x1 = self.prelu(x) + x2 = self.conv1_11(x1) + x3 = self.conv2_13(x2) + x4 = self.conv3_31(x3) + x5 = self.conv4_11(x4) + return x5 + temp + + +class NewResBlock_separate_prelu_chroma(nn.Sequential): + def __init__(self, C: int = 64, C1: int = 160, C21: int = 32): + super().__init__() + self.prelu = nn.PReLU() + self.conv1_11 = Conv(C1, C, kernel_size=1, post_activation=None) + self.conv2_13 = Conv_dw(C, C, kernel_size=3, post_activation=None, is_separable=True, hidden_separable_channels=C21) + self.conv4_11 = Conv(C, C1, kernel_size=1, post_activation=None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + temp = x + x1 = self.prelu(x) + x2 = self.conv1_11(x1) + x3 = self.conv2_13(x2) + x4 = self.conv4_11(x3) + return x4 + temp + + +class NewResBlock_separate_prelu_chroma_crop(nn.Sequential): + def __init__(self, C: int = 64, C1: int = 160, C21: int = 32): + super().__init__() + self.prelu = nn.PReLU() + self.conv1_11 = Conv(C1, C, kernel_size=1, post_activation=None) + self.conv2_13 = Conv_dw(C, C, kernel_size=3, padding=(0, 0), post_activation=None, is_separable=True, hidden_separable_channels=C21) + self.conv4_11 = Conv(C, C1, kernel_size=1, post_activation=None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + temp = x[:, :, 1:-1, 1:-1] + x1 = self.prelu(x) + x2 = self.conv1_11(x1) + x3 = self.conv2_13(x2) + x4 = self.conv4_11(x3) + return x4 + temp + + +class 
NewResBlock_separate_prelu_group(nn.Sequential): + def __init__(self, C: int = 64, C1: int = 160, C21: int = 32, cAtten: int = 28, n: int = 2): + super().__init__() + self.convBlock = nn.ModuleList() + for i in range(n): + self.convBlock.append(NewResBlock_separate_prelu(C, C1, C21)) + + self.atten = SpatialAttention(C1, cAtten, 1, "hard_sigmoid") + self.act = nn.PReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = x + for block in self.convBlock: + x1 = block(x1) + + x2 = self.act(x1) + x3 = x2 + x + x4 = self.atten(x3) + return x4 + + +class Slice_only_y(nn.Sequential): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[:, :, 1:-1, :] + + +class Slice_only_x(nn.Sequential): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[:, :, :, 1:-1] + + +class SplitLumaChromaBlocks(nn.Sequential): + def __init__( + self, + N_Y: int = 12, + N_UV: int = 6, + C: int = 16, + C1_Y: int = 64, + C1_UV: int = 48, + C21: int = 16, + output_channels_y: int = 4, + output_channels_uv: int = 2, + ): + super().__init__() + + self.split_y_path = nn.Sequential( + Conv(C, C1_Y, kernel_size=1, post_activation=None), + *[NewResBlock_separate_prelu_crop(C, C1_Y, C21) for n in range(2)], + *[NewResBlock_separate_prelu_group(C, C1_Y, C21, 32, n) for n in N_Y], + Conv(C1_Y, C, kernel_size=1, post_activation=None), + Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=True, hidden_separable_channels=C21, post_activation=None), + Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=False, hidden_separable_channels=C21, post_activation=None), + Conv(C, C, kernel_size=1), + Conv(C, output_channels_y, kernel_size=3, post_activation=None), + ) + + self.split_uv_path = nn.Sequential( + Conv(C, C1_UV, kernel_size=1, post_activation=None), + *[NewResBlock_separate_prelu_chroma_crop(C, C1_UV, C21) for n in range(2)], + *[NewResBlock_separate_prelu_chroma(C, C1_UV, C21) for n in range(N_UV - 2)], + Conv(C1_UV, C, kernel_size=1, post_activation=None), + Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=True, hidden_separable_channels=C21, post_activation=None), + Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=False, hidden_separable_channels=C21, post_activation=None), + Conv(C, C, kernel_size=1), + Conv(C, output_channels_uv, kernel_size=3, post_activation=None), + ) + + self.Cy = C + self.Cuv = C + + def forward(self, x: torch.Tensor) -> torch.Tensor: + split_y_input = x[:, : self.Cy, :, :] + split_uv_input = x[:, self.Cy : self.Cy + self.Cuv, :, :] + + y_output = self.split_y_path.forward(split_y_input) + uv_output = self.split_uv_path.forward(split_uv_input) + return torch.cat((y_output, uv_output), dim=1) + + +class SADLNet(nn.Sequential): + """The network used during SADL inference""" + + def __init__( + self, + input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1], + input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3], + D1: int = 12, + D2: int = 8, + D3: int = 4, + D4: int = 2, + D5: int = 2, + D6: int = 24, + N_Y: int = 12, + N_UV: int = 6, + C: int = 16, + C1_Y: int = 64, + C1_UV: int = 48, + C21: int = 16, + output_channels_y: int = 4, + output_channels_uv: int = 2, + ): + """ + Args: + input_channels: the number of channels expected for each input + input_kernels: the kernel size for each input convolution + output_channels: the number of output channels + """ + self.input_channels = input_channels + self.input_kernels = input_kernels + self.input_features = 
[D1, D2, D3, D4, D4, D5] + super(SADLNet, self).__init__( + MultiBranchModule( + *[ + Conv(c, d, kernel_size=k, post_activation=None) + for c, d, k in zip( + self.input_channels, self.input_features, self.input_kernels + ) + ] + ), + Conv(sum(self.input_features), D6, kernel_size=1), + Conv(D6, D6, kernel_size=3, stride=2, post_activation=None, is_separable=True, is_horizontal=True, hidden_separable_channels=D6), + Conv(D6, D6, kernel_size=3, stride=2, post_activation=None, is_separable=True, is_horizontal=False, hidden_separable_channels=D6), + Conv(D6, C + C, kernel_size=1), + SplitLumaChromaBlocks(N_Y, N_UV, C, C1_Y, C1_UV, C21, output_channels_y, output_channels_uv), + ) + + def get_example_inputs( + self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1 + ): + patch_size = ( + (patch_size, patch_size) if isinstance(patch_size, int) else patch_size + ) + return [ + torch.rand( + batch_size, conv.in_channels, *patch_size, device=conv[0].weight.device + ) + for conv in self[0].branches + ] + + def to_onnx( + self, + filename: str, + patch_size: int = 144, + batch_size: int = 1, + opset: int = 11, + **kwargs, + ) -> None: + mode = self.training + self.eval() + torch.onnx.export( + self, + self.get_example_inputs(patch_size, batch_size), + filename, + opset_version=opset, + **kwargs, + ) + self.train(mode) + + +class Net(nn.Module): + """Wrapper for SADL model that implements input pre- and post-processing for training.""" + + def __init__( + self, + input_channels: Iterable[Iterable[str]] = [ + ["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V"], + ["pred_Y", "pred_U", "pred_V"], + ["bs_Y", "bs_U", "bs_V"], + ["qp_base"], + ["qp_slice"], + ["ipb_Y"], + ], + input_kernels: Iterable[int] = [3, 3, 1, 1, 1, 1], + D1: int = 16, + D2: int = 8, + D3: int = 4, + D4: int = 2, + D5: int = 2, + D6: int = 64, + N_Y: int = [2, 2, 2, 2, 2], + N_UV: int = 5, + C: int = 32, + C1_Y: int = 144, + C1_UV: int = 80, + C21: int = 32, + dct_ch: int = 4, + path: str = None, + ): + super(Net, self).__init__() + assert len(input_channels) == len( + input_kernels + ), "[ERROR] input size and kernels size not equal" + self.input_channels = input_channels + sizes = [dct_ch + dct_ch // 2, dct_ch + dct_ch // 2, 3, 1, 1, 1] + self.SADL_model = SADLNet( + sizes, + input_kernels, + D1, + D2, + D3, + D4, + D5, + D6, + N_Y, + N_UV, + C, + C1_Y, + C1_UV, + C21, + 4 * dct_ch, + 2 * dct_ch + ) + self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest") + self.dct_ch = dct_ch + + def preprocess_args( + self, batch: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + return [ + torch.cat([batch[name] for name in input_], dim=1) + for input_ in self.input_channels + ] + + def postprocess_outputs( + self, batch: Dict[str, torch.Tensor], out: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + Y_res, UV_res = out.split([4 * self.dct_ch, 2 * self.dct_ch], dim=1) + return ( + F.pixel_shuffle(Y_res, 2), + UV_res, + ) + + def forward( + self, batch: Dict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + args = self.preprocess_args(batch) + out = self.SADL_model(args) + return self.postprocess_outputs(batch, out) diff --git a/training/training_scripts/NN_Filtering/LOP5/paths.json b/training/training_scripts/NN_Filtering/LOP5/paths.json new file mode 100644 index 0000000000000000000000000000000000000000..a3df863c28d1506fca78eaa37053902871358274 --- /dev/null +++ b/training/training_scripts/NN_Filtering/LOP5/paths.json @@ -0,0 +1,200 @@ +{ + "binaries": { + "sadl_path": 
"/path/to/src/sadl" + }, + + "model" : { + "path": "/path/to/src/training/training_scripts/NN_Filtering/LOP5/model/model.py" + }, + + "dataset": { + "div2k_train": { + "//": "dataset of png to convert", + "path": "/path/to/DIV2K/DIV2K_train_HR" + }, + "div2k_valid": { + "path": "/path/to/DIV2K/DIV2K_valid_HR" + }, + "bvi": { + "//": "dataset of yuv", + "path": "/path/to/bviorg", + "dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/bvi.json" + }, + "bvi_valid": { + "//": "dataset of yuv", + "path": "/path/to/bviorg", + "dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/bvi_valid.json" + }, + "tvd": { + "//": "dataset of yuv", + "path": "/path/to/tvdorg", + "dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/tvd.json" + }, + "tvd_valid": { + "//": "dataset of yuv", + "path": "/path/to/tvdorg", + "dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/tvd_valid.json" + } + }, + + + "stage1": { + "yuv": { + "//": "path to store yuv files dataset", + "path": "/path/to/stage1/yuv" + }, + "yuv_valid": { + "//": "path to store yuv files dataset", + "path": "/path/to/stage1/yuv" + }, + "encdec": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage1/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_intra_vtm.cfg" + }, + "encdec_valid": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage1/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_intra_vtm.cfg" + }, + "dataset": { + "//": "path to store the full dataset which will be used by the training", + "path": "/path/to/stage1/dataset" + }, + "dataset_valid": { + "//": "path to store the full dataset which will be used by the training", + "path": "/path/to/stage1/dataset" + }, + "training": { + "path": "/path/to/stage1/train" + }, + "conversion": { + "//": "full path to output the model. 
input model is taken in training/ckpt_dir", + "full_path_filename": "/path/to/stage1/train/model_float.sadl" + } + }, + + + "stage2": { + "yuv_tvd": { + "//": "path to store yuv files dataset", + "path": "/path/to/stage2/yuv" + }, + "yuv_tvd_valid": { + "//": "path to store yuv files dataset", + "path": "/path/to/stage2/yuv" + }, + "yuv_bvi": { + "//": "path to store yuv files dataset", + "path": "/path/to/stage2/yuv" + }, + "yuv_bvi_valid": { + "//": "path to store yuv files dataset", + "path": "/path/to/stage2/yuv" + }, + "encdec_bvi": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage2/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "encdec_bvi_valid": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage2/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "encdec_tvd": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage2/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "encdec_tvd_valid": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage2/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "dataset": { + "//": "path to store the full dataset which will be used by the training", + "path": "/path/to/stage2/dataset" + }, + "dataset_valid": { + "//": "path to store the full dataset which will be used by the training", + "path": "/path/to/stage2/dataset" + }, + "training": { + "path": "/path/to/stage2/train" + }, + "conversion": { + "//": "full path to output the model. 
input model is taken in training/ckpt_dir", + "full_path_filename": "/path/to/stage2/train/model_float.sadl" + }, + "quantize": { + "//": "full path to output the quantized model", + "full_path_filename": "/path/to/stage2/train/model_int16.sadl" + } + }, + + + "stage3": { + "encdec_bvi": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage3/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "encdec_bvi_valid": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage3/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "encdec_tvd": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage3/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "encdec_tvd_valid": { + "//": "path to store the shell script and all the generated files by the encoder/decoder", + "path": "/path/to/stage3/encdec", + "vtm_enc": "/path/to/src/bin/EncoderAppStatic", + "vtm_dec": "/path/to/src/bin/DecoderAppStatic", + "vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg" + }, + "dataset": { + "//": "path to store the full dataset which will be used by the training", + "path": "/path/to/stage3/dataset" + }, + "dataset_valid": { + "//": "path to store the full dataset which will be used by the training", + "path": "/path/to/stage3/dataset" + }, + "training": { + "path": "/path/to/stage3/train" + }, + "conversion": { + "//": "full path to output the model. 
input model is taken in training/ckpt_dir", + "full_path_filename": "/path/to/stage3/train/model_float.sadl" + }, + "quantize": { + "//": "full path to output the quantized model", + "full_path_filename": "/path/to/stage3/train/model_int16.sadl" + } + } + +} + diff --git a/training/training_scripts/NN_Filtering/LOP5/quantize/quantizer.txt b/training/training_scripts/NN_Filtering/LOP5/quantize/quantizer.txt new file mode 100644 index 0000000000000000000000000000000000000000..a566c68d7b4193993b113f396d9e61ae20d7fca5 --- /dev/null +++ b/training/training_scripts/NN_Filtering/LOP5/quantize/quantizer.txt @@ -0,0 +1 @@ +0 12 1 12 2 12 3 12 4 12 5 12 6 12 7 0 8 12 10 12 11 0 12 12 14 12 15 0 16 12 18 12 19 0 20 12 22 12 23 0 24 12 26 12 27 0 28 12 30 0 32 12 33 0 34 12 36 10 38 10 39 0 40 11 42 11 43 0 44 10 46 10 47 0 48 10 50 10 54 10 55 0 56 10 60 10 62 10 63 0 64 10 66 10 67 0 68 10 70 10 71 0 72 10 74 10 75 0 76 10 81 12 83 11 84 0 85 10 87 11 88 0 89 10 91 10 92 0 93 10 95 10 96 0 97 10 100 10 102 10 103 0 104 10 106 10 107 0 108 10 110 10 111 0 112 10 114 10 115 0 116 10 119 10 121 10 122 0 123 10 125 10 126 0 127 10 129 10 130 0 131 10 133 10 134 0 135 10 138 10 141 10 142 0 143 10 145 10 148 10 149 0 150 10 152 10 153 0 154 10 156 10 159 0 161 10 162 0 163 10 165 12 168 12 169 0 170 11 172 12 173 0 174 12 176 12 178 10 180 0 181 12 183 12 184 0 185 11 187 11 188 0 189 10 191 10 192 0 193 10 195 10 196 0 197 10 200 10 202 11 203 0 204 10 206 10 207 0 208 10 210 10 211 0 212 10 214 10 215 0 216 10 219 10 222 10 223 0 224 10 226 10 229 10 230 0 231 10 233 10 234 0 235 10 237 10 240 0 242 10 243 0 244 10 246 12 249 12 250 0 251 10 253 12 254 0 255 12 257 12 259 12 261 0 262 12 264 12 265 0 266 11 268 11 269 0 270 11 272 10 273 0 274 11 276 10 277 0 278 11 281 12 283 11 284 0 285 10 287 10 288 0 289 10 291 10 292 0 293 10 295 11 296 0 297 10 300 10 303 10 304 0 305 10 307 10 310 10 311 0 312 10 314 10 315 0 316 10 318 10 321 0 323 10 324 0 325 10 327 10 330 12 331 0 332 10 334 12 335 0 336 12 338 12 340 12 342 0 343 12 345 12 346 0 347 12 349 10 350 0 351 11 353 10 354 0 355 11 357 11 358 0 359 11 362 12 364 12 365 0 366 11 368 11 369 0 370 11 372 10 373 0 374 11 376 12 377 0 378 11 381 11 383 11 384 0 385 10 387 10 388 0 389 10 391 10 392 0 393 10 395 10 396 0 397 10 400 10 403 12 404 0 405 10 407 12 410 11 411 0 412 11 414 12 415 0 416 12 418 12 421 0 423 12 424 0 425 11 427 12 430 12 431 0 432 11 434 12 435 0 436 12 438 12 440 12 442 0 443 12 444 0 445 12 447 12 448 0 449 12 451 12 452 0 453 12 455 12 456 0 457 12 459 12 461 12 462 0 463 12 465 12 466 0 467 10 471 12 473 12 474 0 475 12 477 12 478 0 479 12 481 12 482 0 483 12 488 10 490 10 491 0 492 11 494 12 495 0 496 12 498 12 499 0 500 12 503 12 505 12 506 0 507 12 509 12 510 0 511 12 513 12 514 0 515 11 518 10 520 11 521 0 522 11 524 11 525 0 526 11 528 12 529 0 530 11 533 11 535 11 536 0 537 11 539 12 540 0 541 11 543 12 544 0 545 11 548 11 549 0 550 11 552 12 553 0 554 12 556 12 557 0 558 12 560 12 561 0 562 12 564 12 566 12 567 0 568 12 570 0 \ No newline at end of file diff --git a/training/training_scripts/NN_Filtering/LOP5/readme.md b/training/training_scripts/NN_Filtering/LOP5/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..5006d9b2fc13278bb94e45c8971c2a1194c83a61 --- /dev/null +++ b/training/training_scripts/NN_Filtering/LOP5/readme.md @@ -0,0 +1,16 @@ +# Training Stage 3 +Use the json files in LOP5 +``` +python3 src/training/tools/create_config.py 
src/training/training_scripts/NN_Filtering/common/common_config.json src/training/training_scripts/NN_Filtering/LOP5/config.json src/training/training_scripts/NN_Filtering/LOP5/model/model.json src/training/training_scripts/NN_Filtering/LOP5/paths.json > my_config.json + +python3 training_scripts/NN_Filtering/common/training/main.py --json_config my_config.json --stage 3 + +``` + +Convert the trained model to the SADL format +``` +python3 training_scripts/NN_Filtering/common/convert/to_sadl.py --json_config my_config.json --input_model stage3/training --output_model stage3/conversion/full_path_filename + +``` + +Quantize the model with SADL's naive_quantization tool; the quantization string to use is provided in LOP5/quantize/quantizer.txt (a parsing sketch of its format is given below). +
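
As a reference for how the hyperparameters in LOP5/model/model.json map onto the `Net` class in model.py, the following sketch (not part of the patch) instantiates the model with those values and runs a dummy forward pass. Only the per-branch channel totals [6, 6, 3, 1, 1, 1] are fixed by the code; the per-name channel split used below (4 DCT channels for Y, 1 each for U and V) and the 72x72 patch size are assumptions chosen for illustration.
```
# Sketch only: build the LOP5 Net with the model.json hyperparameters and
# check the tensor shapes of a dummy forward pass.
# Assumes training_scripts/NN_Filtering/LOP5/model/model.py is importable as `model`.
import torch
from model import Net

net = Net(
    input_kernels=[3, 3, 1, 1, 1, 1],
    D1=16, D2=8, D3=4, D4=2, D5=2, D6=64,
    N_Y=[2, 2, 2, 3], N_UV=5,
    C=32, C1_Y=176, C1_UV=80, C21=32,
    dct_ch=4,
).eval()

# Per-branch totals must be [6, 6, 3, 1, 1, 1]; the split per tensor name
# below is an assumption for this illustration only.
channels = {
    "rec_before_dbf_Y": 4, "rec_before_dbf_U": 1, "rec_before_dbf_V": 1,
    "pred_Y": 4, "pred_U": 1, "pred_V": 1,
    "bs_Y": 1, "bs_U": 1, "bs_V": 1,
    "qp_base": 1, "qp_slice": 1, "ipb_Y": 1,
}
h = w = 72  # chosen so spatial dims stay even after the stride-2 convs and border crops
batch = {name: torch.rand(1, c, h, w) for name, c in channels.items()}

with torch.no_grad():
    y_res, uv_res = net(batch)

# 4 * dct_ch = 16 luma channels are pixel-shuffled by 2; 2 * dct_ch = 8 chroma
# channels are returned as-is.  The two "crop" residual blocks in each path
# trim the borders of the downsampled feature maps.
print(y_res.shape)   # torch.Size([1, 4, 64, 64])
print(uv_res.shape)  # torch.Size([1, 8, 32, 32])
```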
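
One design point in model.py worth spelling out: in the luma path the separable `Conv` replaces a 3x3 depthwise convolution with a grouped 3x1 followed by a grouped 1x3 (conv2_13 / conv3_31), while the chroma path keeps a full 3x3 depthwise via `Conv_dw`. A quick parameter count for C = C21 = 32, ignoring biases:
```
# Per-channel weights: a depthwise 3x3 uses 9, the grouped 3x1 + 1x3 pair uses 6.
k, C = 3, 32
depthwise_3x3 = C * k * k        # chroma-style Conv_dw block: 288 weights
pair_3x1_1x3 = C * k + C * k     # luma-style conv2_13 + conv3_31: 192 weights
print(depthwise_3x3, pair_3x1_1x3)
```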
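
The stage-1 schedule in config.json (lr 4e-4, milestones 120/123/126, mse_epoch 126, max_epochs 130) is consumed by the common training script, which is not part of this patch. The sketch below shows one plausible reading, Adam plus torch's MultiStepLR with a placeholder decay factor of 0.1; the optimizer type, the decay factor, and the exact meaning of mse_epoch are all set by the common config, so treat them here as assumptions.
```
# Sketch (assumptions labelled above): stage-1 learning-rate schedule.
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=4e-4)        # stage1 optimizer.lr
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[120, 123, 126], gamma=0.1  # gamma is a placeholder
)

for epoch in range(130):                              # stage1 max_epochs
    # ... one training epoch would run here; per config.json, epoch 126
    # ("mse_epoch") presumably also switches the loss to plain MSE ...
    scheduler.step()
    if epoch + 1 in (120, 123, 126):
        print("from epoch", epoch + 1, "lr =", scheduler.get_last_lr()[0])
```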
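
The `[stage1/conversion/full_path_filename]`-style placeholders in the vtm_option strings of config.json refer to keys of the merged configuration (my_config.json), whose values come from paths.json. The snippet below mimics that substitution only to make the convention concrete; it is not the common tooling itself and assumes the merged JSON keeps the nested keys shown in paths.json.
```
# Sketch: resolve "[a/b/c]" placeholders against the merged my_config.json.
import json
import re
from pathlib import Path

cfg = json.loads(Path("my_config.json").read_text())

def resolve(text: str, cfg: dict) -> str:
    def lookup(match: re.Match) -> str:
        node = cfg
        for key in match.group(1).split("/"):
            node = node[key]
        return str(node)
    return re.sub(r"\[([^\]]+)\]", lookup, text)

print(resolve("--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]", cfg))
# e.g. --NnlfOption=1 --NnlfModelName=/path/to/stage2/train/model_int16.sadl
```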
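
Finally, the quantizer string in LOP5/quantize/quantizer.txt is a flat sequence of integer pairs. Reading each pair as (SADL operation index, quantizer value) is an interpretation of how naive_quantization consumes it, so treat the sketch below as an assumption and check the SADL documentation for the authoritative format.
```
# Sketch: parse the quantizer string into {op_id: quantizer_value} pairs.
from pathlib import Path

text = Path("training/training_scripts/NN_Filtering/LOP5/quantize/quantizer.txt").read_text()
tokens = text.split()
assert len(tokens) % 2 == 0, "expected a flat list of (op_id, value) integer pairs"

quantizers = {int(op): int(q) for op, q in zip(tokens[0::2], tokens[1::2])}
print(len(quantizers), "entries; values range",
      min(quantizers.values()), "to", max(quantizers.values()))
```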