Commit fe06fe41 authored by Franck Galpin

Merge branch 'AH0080_training_scripts' into 'VTM-11.0_nnvc'

JVET-AH0080 training scripts

See merge request !211
parents f56b868a f4aaa401
{
"stage1": {
"training": {
"mse_epoch": 126,
"max_epochs": 130,
"component_loss_weightings": [
6,
2
],
"dataloader": {
"batch_size": 128
},
"optimizer": {
"lr": 0.0004
},
"lr_scheduler": {
"milestones": [
120,
123,
126
]
},
"dct_size": 2
}
},
"stage2": {
"encdec_bvi": {
"vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]"
},
"encdec_bvi_valid": {
"vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]"
},
"encdec_tvd": {
"vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]"
},
"encdec_tvd_valid": {
"vtm_option": "--NnlfDebugOption=1 --NnlfOption=1 --NnlfModelName=[stage1/conversion/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage1/conversion/full_path_filename]"
},
"training": {
"mse_epoch": 58,
"max_epochs": 60,
"component_loss_weightings": [
6,
2
],
"dataloader": {
"batch_size": 128
},
"optimizer": {
"lr": 0.0004
},
"lr_scheduler": {
"milestones": [
55,
57,
58
]
},
"dct_size": 2
}
},
"stage3": {
"encdec_bvi": {
"vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]"
},
"encdec_bvi_valid": {
"vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]"
},
"encdec_tvd": {
"vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]"
},
"encdec_tvd_valid": {
"vtm_option": "--NnlfOption=1 --NnlfModelName=[stage2/quantize/full_path_filename]",
"vtm_dec_option": "--NnlfModelName=[stage2/quantize/full_path_filename]"
},
"training": {
"mse_epoch": 58,
"max_epochs": 60,
"component_loss_weightings": [
12,
2
],
"dataloader": {
"batch_size": 64
},
"optimizer": {
"lr": 0.0002
},
"lr_scheduler": {
"milestones": [
41,
51,
55
]
},
"dct_size": 2
}
}
}
{ "model" : {
"path": "../../LOP3/model/model.py",
"class" : "Net",
"input_channels" : [
[
"rec_before_dbf_Y",
"rec_before_dbf_U",
"rec_before_dbf_V"
],
[
"pred_Y",
"pred_U",
"pred_V"
],
[
"bs_Y",
"bs_U",
"bs_V"
],
[ "qp_base" ],
[ "qp_slice" ],
[ "ipb_Y" ]
],
"input_kernels" : [
3,
3,
1,
1,
1,
1
],
"D1": 16,
"D2": 8,
"D3": 4,
"D4": 2,
"D5": 2,
"D6": 64,
"N_Y": 14,
"N_UV": 4,
"C": 32,
"C1_Y": 144,
"C1_UV": 128,
"C21": 32,
"dct_ch": 4
} }
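The "path"/"class" pair above tells the training scripts where to find the network definition and which class to build; the remaining keys are hyperparameters for its constructor. A minimal sketch of one plausible way to consume such a config (the file name `model.json` and module name `lop3_model` are illustrative; the actual loader lives in the common training scripts):

```
import importlib.util
import json

with open("model.json") as f:
    cfg = json.load(f)["model"]

# Load the module referenced by "path" and instantiate the class named by "class",
# passing the remaining entries as constructor keyword arguments.
spec = importlib.util.spec_from_file_location("lop3_model", cfg["path"])
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
kwargs = {k: v for k, v in cfg.items() if k != "class"}
net = getattr(module, cfg["class"])(**kwargs)
```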
"""
/* The copyright in this software is being made available under the BSD
* License, included below. This software may be subject to other third party
* and contributor rights, including patent rights, and no such rights are
* granted under this license.
*
* Copyright (c) 2010-2024, ITU/ISO/IEC
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
* be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
* THE POSSIBILITY OF SUCH DAMAGE.
"""
from typing import Union, Tuple, Optional, Type, Iterable, List, Dict
import torch
from torch import nn
from torch.nn import functional as F
class Conv(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Optional[Union[int, Tuple[int, int]]] = None,
is_separable: bool = False,
is_horizontal: bool = True,
hidden_separable_channels: Optional[int] = None,
post_activation: Optional[Type] = nn.PReLU,
**kwargs,
):
"""
Args:
in_channels: the number of input channels
out_channels: the number of output channels
kernel_size: the convolution's kernel size
stride: the convolution's stride(s)
padding: the convolution's padding
is_separable: whether this instance implements one half of a separable convolution (two 1-D convolutions)
is_horizontal: if is_separable, whether this instance is the first half ((k, 1) kernel, in_channels -> hidden channels) or the second half ((1, k) kernel, hidden channels -> out_channels)
hidden_separable_channels: if is_separable, the number of hidden channels between the two 1-D convolutions. If None, use out_channels
post_activation: activation function to use after convolution. If None, no activation after convolution
**kwargs: additional kwargs to pass to nn.Conv2d
"""
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = (
(kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
)
self.stride = (stride, stride) if isinstance(stride, int) else stride
if padding is not None:
self.padding = (padding, padding) if isinstance(padding, int) else padding
else:
self.padding = tuple([k // 2 for k in self.kernel_size])
self.is_separable = is_separable
self.is_horizontal = is_horizontal
self.post_activation = post_activation
if self.is_separable:
self.hidden_separable_channels = hidden_separable_channels or out_channels
if self.is_horizontal:
modules = [
nn.Conv2d(
self.in_channels,
self.hidden_separable_channels,
(self.kernel_size[0], 1),
(self.stride[0], 1),
(self.padding[0], 0),
groups=self.hidden_separable_channels,
**kwargs,
)
]
else:
modules = [
nn.Conv2d(
self.hidden_separable_channels,
self.out_channels,
(1, self.kernel_size[1]),
(1, self.stride[1]),
(0, self.padding[1]),
groups=self.hidden_separable_channels,
**kwargs,
)
]
else:
modules = [
nn.Conv2d(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.padding,
**kwargs,
)
]
if self.post_activation is not None:
modules.append(self.post_activation())
super(Conv, self).__init__(*modules)
class MultiBranchModule(nn.Module):
"""A module representing multple, parallel branches. If the input is a list, each element in the list is fed into the corresponding branch,
otherwise the input is fed into every branch. The outputs of each branch are then merged."""
def __init__(self, *branch_modules, merge_dimension: int = -3):
"""
Args:
branch_modules: modules to run in parallel
merge_dimension: the dimension to merge outputs from each branch
"""
super().__init__()
self.branches = nn.ModuleList(branch_modules)
self.merge_dimension = merge_dimension
def forward(self, args: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
inputs = args if isinstance(args, list) else len(self.branches) * [args]
branch_outputs = [branch(input) for branch, input in zip(self.branches, inputs)]
return torch.cat(branch_outputs, dim=self.merge_dimension)
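# Illustrative usage (not part of the network definition): two parallel 1x1
# convolutions, each fed its own tensor, merged along the channel dimension.
#   branches = MultiBranchModule(nn.Conv2d(3, 8, 1), nn.Conv2d(1, 4, 1))
#   out = branches([torch.rand(1, 3, 16, 16), torch.rand(1, 1, 16, 16)])
#   # out.shape == torch.Size([1, 12, 16, 16])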
class NewResBlock_separate_prelu(nn.Sequential):
def __init__(self, C: int = 64, C1: int = 160, C21: int = 32):
super().__init__()
self.prelu = nn.PReLU()
self.conv1_11 = Conv(C1, C, kernel_size=1, post_activation=None)
self.conv2_13 = Conv(C, C, kernel_size=3, post_activation=None, is_separable=True, is_horizontal=True, hidden_separable_channels=C21)
self.conv3_31 = Conv(C, C, kernel_size=3, post_activation=None, is_separable=True, is_horizontal=False, hidden_separable_channels=C21)
self.conv4_11 = Conv(C, C1, kernel_size=1, post_activation=None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
temp = x
x1 = self.prelu(x)
x2 = self.conv1_11(x1)
x3 = self.conv2_13(x2)
x4 = self.conv3_31(x3)
x5 = self.conv4_11(x4)
return x5 + temp
class SplitLumaChromaBlocks(nn.Sequential):
def __init__(
self,
N_Y: int = 12,
N_UV: int = 6,
C: int = 16,
C1_Y: int = 64,
C1_UV: int = 48,
C21: int = 16,
output_channels_y: int = 4,
output_channels_uv: int = 2,
):
super().__init__()
self.split_y_path = nn.Sequential(
Conv(C, C1_Y, kernel_size=1, post_activation=None),
*[NewResBlock_separate_prelu(C, C1_Y, C21) for _ in range(N_Y)],
Conv(C1_Y, C, kernel_size=1, post_activation=None),
Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=True, hidden_separable_channels=C21, post_activation=None),
Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=False, hidden_separable_channels=C21, post_activation=None),
Conv(C, C, kernel_size=1),
Conv(C, output_channels_y, kernel_size=3, post_activation=None)
)
self.split_uv_path = nn.Sequential(
Conv(C, C1_UV, kernel_size=1, post_activation=None),
*[NewResBlock_separate_prelu(C, C1_UV, C21) for _ in range(N_UV)],
Conv(C1_UV, C, kernel_size=1, post_activation=None),
Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=True, hidden_separable_channels=C21, post_activation=None),
Conv(C, C, kernel_size=3, is_separable=True, is_horizontal=False, hidden_separable_channels=C21, post_activation=None),
Conv(C, C, kernel_size=1),
Conv(C, output_channels_uv, kernel_size=3, post_activation=None)
)
self.Cy = C
self.Cuv = C
def forward(self, x: torch.Tensor) -> torch.Tensor:
split_y_input = x[:, : self.Cy, :, :]
split_uv_input = x[:, self.Cy : self.Cy + self.Cuv, :, :]
y_output = self.split_y_path.forward(split_y_input)
uv_output = self.split_uv_path.forward(split_uv_input)
return torch.cat((y_output, uv_output), dim=1)
class SADLNet(nn.Sequential):
"""The network used during SADL inference"""
def __init__(
self,
input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1],
input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3],
D1: int = 12,
D2: int = 8,
D3: int = 4,
D4: int = 2,
D5: int = 2,
D6: int = 24,
N_Y: int = 12,
N_UV: int = 6,
C: int = 16,
C1_Y: int = 64,
C1_UV: int = 48,
C21: int = 16,
output_channels_y: int = 4,
output_channels_uv: int = 2,
):
"""
Args:
input_channels: the number of channels expected for each input
input_kernels: the kernel size for each input convolution
output_channels_y: the number of luma output channels
output_channels_uv: the number of chroma output channels
"""
self.input_channels = input_channels
self.input_kernels = input_kernels
self.input_features = [D1, D2, D3, D4, D4, D5]
super(SADLNet, self).__init__(
MultiBranchModule(
*[
Conv(c, d, kernel_size=k, post_activation=None)
for c, d, k in zip(
self.input_channels, self.input_features, self.input_kernels
)
]
),
Conv(sum(self.input_features), D6, kernel_size=1),
Conv(D6, D6, kernel_size=3, stride=2, post_activation=None, is_separable=True, is_horizontal=True, hidden_separable_channels=D6),
Conv(D6, D6, kernel_size=3, stride=2, post_activation=None, is_separable=True, is_horizontal=False, hidden_separable_channels=D6),
Conv(D6, C + C, kernel_size=1),
SplitLumaChromaBlocks(N_Y, N_UV, C, C1_Y, C1_UV, C21, output_channels_y, output_channels_uv),
)
def get_example_inputs(
self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1
):
patch_size = (
(patch_size, patch_size) if isinstance(patch_size, int) else patch_size
)
return [
torch.rand(
batch_size, conv.in_channels, *patch_size, device=conv[0].weight.device
)
for conv in self[0].branches
]
def to_onnx(
self,
filename: str,
patch_size: int = 144,
batch_size: int = 1,
opset: int = 10,
**kwargs,
) -> None:
mode = self.training
self.eval()
torch.onnx.export(
self,
self.get_example_inputs(patch_size, batch_size),
filename,
opset_version=opset,
**kwargs,
)
self.train(mode)
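# Illustrative sketch (default hyper-parameters, random inputs): get_example_inputs
# follows the per-branch channel counts, and the two stride-2 separable convolutions
# each halve one spatial dimension, so a 144x144 patch comes out at 72x72.
#   net = SADLNet()
#   out = net(net.get_example_inputs(patch_size=144))
#   # out.shape == torch.Size([1, 6, 72, 72])  (4 luma + 2 chroma output channels)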
class Net(nn.Module):
"""Wrapper for SADL model that implements input pre- and post-processing for training."""
def __init__(
self,
input_channels: Iterable[Iterable[str]] = [
["rec_before_dbf_Y", "rec_before_dbf_U", "rec_before_dbf_V"],
["pred_Y", "pred_U", "pred_V"],
["bs_Y", "bs_U", "bs_V"],
["qp_base"],
["qp_slice"],
["ipb_Y"],
],
input_kernels: Iterable[int] = [3, 3, 1, 1, 1, 1],
D1: int = 16,
D2: int = 8,
D3: int = 4,
D4: int = 2,
D5: int = 2,
D6: int = 64,
N_Y: int = 14,
N_UV: int = 4,
C: int = 32,
C1_Y: int = 144,
C1_UV: int = 128,
C21: int = 32,
dct_ch: int = 4,
path: str = None,
):
super(Net, self).__init__()
assert len(input_channels) == len(
input_kernels
), "[ERROR] input size and kernels size not equal"
self.input_channels = input_channels
sizes = [dct_ch + dct_ch // 2, dct_ch + dct_ch // 2, dct_ch + dct_ch // 2, 1, 1, dct_ch]
self.SADL_model = SADLNet(
sizes,
input_kernels,
D1,
D2,
D3,
D4,
D5,
D6,
N_Y,
N_UV,
C,
C1_Y,
C1_UV,
C21,
4 * dct_ch,
2 * dct_ch
)
self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest")
self.dct_ch = dct_ch
def preprocess_args(
self, batch: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
return [
torch.cat([batch[name] for name in input_], dim=1)
for input_ in self.input_channels
]
def postprocess_outputs(
self, batch: Dict[str, torch.Tensor], out: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
Y_res, UV_res = out.split([4 * self.dct_ch, 2 * self.dct_ch], dim=1)
return (
F.pixel_shuffle(Y_res, 2),
UV_res,
)
def forward(
self, batch: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
args = self.preprocess_args(batch)
out = self.SADL_model(args)
return self.postprocess_outputs(batch, out)
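# Rough shape sketch for the DCT-domain training wrapper above (hypothetical random
# batch; names follow input_channels, channel counts follow dct_ch = 4, i.e. a
# 144x144 luma patch pixel-unshuffled to 4 channels of 72x72):
#   model = Net()
#   batch = {
#       "rec_before_dbf_Y": torch.rand(1, 4, 72, 72),
#       "rec_before_dbf_U": torch.rand(1, 1, 72, 72),
#       "rec_before_dbf_V": torch.rand(1, 1, 72, 72),
#       "pred_Y": torch.rand(1, 4, 72, 72),
#       "pred_U": torch.rand(1, 1, 72, 72),
#       "pred_V": torch.rand(1, 1, 72, 72),
#       "bs_Y": torch.rand(1, 4, 72, 72),
#       "bs_U": torch.rand(1, 1, 72, 72),
#       "bs_V": torch.rand(1, 1, 72, 72),
#       "qp_base": torch.rand(1, 1, 72, 72),
#       "qp_slice": torch.rand(1, 1, 72, 72),
#       "ipb_Y": torch.rand(1, 4, 72, 72),
#   }
#   Y_res, UV_res = model(batch)
#   # Y_res.shape  == torch.Size([1, 4, 72, 72])   (pixel-shuffled luma DCT residual)
#   # UV_res.shape == torch.Size([1, 8, 36, 36])   (chroma DCT residual)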
{
"binaries": {
"sadl_path": "/path/to/src/sadl"
},
"model" : {
"path": "/path/to/src/training/training_scripts/NN_Filtering/LOP3/model/model.py"
},
"dataset": {
"div2k_train": {
"//": "dataset of png to convert",
"path": "/path/to/DIV2K/DIV2K_train_HR"
},
"div2k_valid": {
"path": "/path/to/DIV2K/DIV2K_valid_HR"
},
"bvi": {
"//": "dataset of yuv",
"path": "/path/to/bviorg",
"dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/bvi.json"
},
"bvi_valid": {
"//": "dataset of yuv",
"path": "/path/to/bviorg",
"dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/bvi_valid.json"
},
"tvd": {
"//": "dataset of yuv",
"path": "/path/to/tvdorg",
"dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/tvd.json"
},
"tvd_valid": {
"//": "dataset of yuv",
"path": "/path/to/tvdorg",
"dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/tvd_valid.json"
}
},
"stage1": {
"yuv": {
"//": "path to store yuv files dataset",
"path": "/path/to/stage1/yuv"
},
"yuv_valid": {
"//": "path to store yuv files dataset",
"path": "/path/to/stage1/yuv"
},
"encdec": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage1/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_intra_vtm.cfg"
},
"encdec_valid": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage1/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_intra_vtm.cfg"
},
"dataset": {
"//": "path to store the full dataset which will be used by the training",
"path": "/path/to/stage1/dataset"
},
"dataset_valid": {
"//": "path to store the full dataset which will be used by the training",
"path": "/path/to/stage1/dataset"
},
"training": {
"path": "/path/to/stage1/train"
},
"conversion": {
"//": "full path to output the model. input model is taken in training/ckpt_dir",
"full_path_filename": "/path/to/stage1/train/model_float.sadl"
}
},
"stage2": {
"yuv_tvd": {
"//": "path to store yuv files dataset",
"path": "/path/to/stage2/yuv"
},
"yuv_tvd_valid": {
"//": "path to store yuv files dataset",
"path": "/path/to/stage2/yuv"
},
"yuv_bvi": {
"//": "path to store yuv files dataset",
"path": "/path/to/stage2/yuv"
},
"yuv_bvi_valid": {
"//": "path to store yuv files dataset",
"path": "/path/to/stage2/yuv"
},
"encdec_bvi": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage2/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"encdec_bvi_valid": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage2/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"encdec_tvd": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage2/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"encdec_tvd_valid": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage2/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"dataset": {
"//": "path to store the full dataset which will be used by the training",
"path": "/path/to/stage2/dataset"
},
"dataset_valid": {
"//": "path to store the full dataset which will be used by the training",
"path": "/path/to/stage2/dataset"
},
"training": {
"path": "/path/to/stage2/train"
},
"conversion": {
"//": "full path to output the model. input model is taken in training/ckpt_dir",
"full_path_filename": "/path/to/stage2/train/model_float.sadl"
},
"quantize": {
"//": "full path to output the quantized model",
"full_path_filename": "/path/to/stage2/train/model_int16.sadl"
}
},
"stage3": {
"encdec_bvi": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage3/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"encdec_bvi_valid": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage3/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"encdec_tvd": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage3/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"encdec_tvd_valid": {
"//": "path to store the shell script and all the generated files by the encoder/decoder",
"path": "/path/to/stage3/encdec",
"vtm_enc": "/path/to/src/bin/EncoderAppStatic",
"vtm_dec": "/path/to/src/bin/DecoderAppStatic",
"vtm_cfg": "/path/to/src/cfg/encoder_randomaccess_vtm.cfg"
},
"dataset": {
"//": "path to store the full dataset which will be used by the training",
"path": "/path/to/stage3/dataset"
},
"dataset_valid": {
"//": "path to store the full dataset which will be used by the training",
"path": "/path/to/stage3/dataset"
},
"training": {
"path": "/path/to/stage3/train"
},
"conversion": {
"//": "full path to output the model. input model is taken in training/ckpt_dir",
"full_path_filename": "/path/to/stage3/train/model_float.sadl"
},
"quantize": {
"//": "full path to output the quantized model",
"full_path_filename": "/path/to/stage3/train/model_int16.sadl"
}
}
}
"""
/* The copyright in this software is being made available under the BSD
* License, included below. This software may be subject to other third party
* and contributor rights, including patent rights, and no such rights are
* granted under this license.
*
* Copyright (c) 2010-2024, ITU/ISO/IEC
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
* be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
* THE POSSIBILITY OF SUCH DAMAGE.
"""
# Build SADL first to get sadl/sample_test/naive_quantization.
# The SADL directory and the float/int16 model file names are taken from
# my_config.json.
# Example usage:
# python quantize.py -c my_config.json --stage 3
import os
import argparse
import json
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", "-c", type=str, help="config file"
)
parser.add_argument(
"--stage",
type=int,
choices=range(1, 4),
help="if specified use a model from a particular stage [1-3]",
)
args = parser.parse_args()
return vars(args)
def qfactor_gen(content):
output_q_str = ''
linenumber = 0
q0 = '13'
q_prev = q0
for line in content:
linenumber += 1
if line[:11] == '[INFO] id: ':
layer_str = line.split(' ')
layer_id = layer_str[2]
layer_name = layer_str[-1][:-1]
if layer_name == 'Placeholder':
qq = q0
elif layer_name == 'Const':
next_layer = content[linenumber + 5]
layer_str2 = next_layer.split(' ')
next_layer_name = layer_str2[-1][:-1]
if next_layer_name == 'Concat':
qq = '0'
else:
node_name_str0 = content[linenumber]
if 'split_y_path.' in node_name_str0:
node_name_str = node_name_str0.split('split_y_path')
node_module_str = node_name_str[1].split('.')[1]
node_module_int = int(node_module_str)
if node_module_int == 2 :
qq = '12'
elif node_module_int <= 4 :
qq = '11'
elif node_module_int > 4 and node_module_int <= 7:
qq = '10'
elif node_module_int > 7 and node_module_int <= 9:
qq = '9'
elif node_module_int > 9 and node_module_int <= 11:
qq = '8'
elif node_module_int > 11 and node_module_int <= 14:
qq = '7'
elif node_module_int == 15:
qq = '9'
elif node_module_int == 16:
qq = '11'
else :
qq = q0
elif 'branches' in node_name_str0 or \
'split_uv_path.5' in node_name_str0 or \
'split_uv_path.6' in node_name_str0 or \
'split_uv_path.7' in node_name_str0 or \
'split_uv_path.8' in node_name_str0 or \
'split_uv_path.9' in node_name_str0 :
qq = q0
elif 'split_uv_path.0' in node_name_str0 or \
'split_uv_path.1' in node_name_str0 :
qq = '12'
elif 'split_uv' in node_name_str0 :
qq = '11'
elif next_layer_name == 'PReLU':
qq = q_prev
else:
# fuse layers
qq = '12'
elif layer_name == 'Conv2D':
qq = '0'
elif layer_name == 'Mul':
qq = '0'
elif layer_name == 'BiasAdd' or \
layer_name == 'LeakyRelu' or \
layer_name == 'Concat' or \
layer_name == 'Shape' or \
layer_name == 'Expand' or \
layer_name == 'Add' or \
layer_name == 'Transpose' or \
layer_name == 'Reshape' or \
layer_name == 'PReLU' or \
layer_name == 'Slice' :
continue
else:
print('Undecided layer name', layer_name)
continue
q_prev = qq
output_q_str = output_q_str + layer_id + ' ' + qq + ' '
return output_q_str
if __name__ == "__main__":
args = parse_arguments()
json_config = args["config"]
stage = args["stage"]
try:
with open(json_config) as file:
config = json.load(file)
except Exception:
quit("[ERROR] unable to open json config")
main_dir = config['binaries']['sadl_path']
sadl_build_dir = os.path.join(main_dir, 'sample_test')
debug_cpp_dir = os.path.join(sadl_build_dir, 'debug_model')
float_sadl_dir = config[f"stage{stage}"]['conversion']['full_path_filename']
int_sadl_dir = config[f"stage{stage}"]['quantize']['full_path_filename']
debug_log_dir = float_sadl_dir[:-5] + '_debug.txt'
print(f"[INFO] SADL build dir {sadl_build_dir}")
print(f"[INFO] input model {float_sadl_dir}")
# debug float SADL model and get the log
cmd_debug = [debug_cpp_dir, float_sadl_dir, '>', debug_log_dir]
cmd_debug = " ".join(cmd_debug)
os.system(cmd_debug)
# read layer info and set quantizer
file = open(debug_log_dir)
content = file.readlines()
output_q_str = qfactor_gen(content)
file.close()
# apply quantization
naive_quantizer_cpp_dir = os.path.join(sadl_build_dir, 'naive_quantization')
cmd_quantize = ["echo '", output_q_str , " ' | ",
naive_quantizer_cpp_dir,
float_sadl_dir,
int_sadl_dir, ";"]
cmd_quantize = " ".join(cmd_quantize)
os.system(cmd_quantize)
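# Note: the string produced by qfactor_gen is a flat list of "layer_id quantizer"
# pairs, e.g. "0 13 1 13 4 0 ..." (hypothetical ids), which is echoed on stdin to
# SADL's naive_quantization tool together with the float and int16 model paths.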
# Training Stage 3
Use the JSON config files in LOP3:
```
python3 src/training/tools/create_config.py src/training/training_scripts/NN_Filtering/common/common_config.json src/training/training_scripts/NN_Filtering/LOP3/lop3_config.json src/training/training_scripts/NN_Filtering/LOP3/model/model.json src/training/training_scripts/NN_Filtering/LOP3/paths.json > my_config.json
python3 training_scripts/NN_Filtering/common/training/main.py --json_config my_config.json --stage 3
```
## Converting to a SADL model
```
python3 training_scripts/NN_Filtering/common/convert/to_sadl.py --json_config my_config.json --input_model stage3/training --output_model stage3/conversion/full_path_filename
```
## Quantization of the model
```
python3 training_scripts/NN_Filtering/LOP3/quantize/quantize.py -c my_config.json --stage 3
```
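The quantization step writes the int16 model to the path configured under `stage3/quantize/full_path_filename`. A small sketch (assuming the merged `my_config.json` produced by `create_config.py` above) to resolve that path and print the encoder option in the same form as the `vtm_option` entries in the stage configuration:

```
import json

with open("my_config.json") as f:
    cfg = json.load(f)

model_path = cfg["stage3"]["quantize"]["full_path_filename"]
print(f"--NnlfOption=1 --NnlfModelName={model_path}")
```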
"""
/* The copyright in this software is being made available under the BSD
* License, included below. This software may be subject to other third party
* and contributor rights, including patent rights, and no such rights are
* granted under this license.
*
* Copyright (c) 2010-2024, ITU/ISO/IEC
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
* be used to endorse or promote products derived from this software without
* specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
* THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
import torch.nn as nn
class FastDCT():
def __init__(self, dct_size: int, device: str):
super().__init__()
self.dct_size = dct_size
self.device = device
def blockDCT_ps(self, xx):
if self.dct_size == 2:
yy = self.blockDCT2x2_ps2(xx)
else:
raise ValueError('DCT size unknown', self.dct_size)
return yy
def iblockDCT_ps(self, xx):
if self.dct_size == 2:
yy = self.iblockDCT2x2_ps2(xx)
else:
raise ValueError('DCT size unknown', self.dct_size)
return yy
def blockDCT_ps_halfsize(self, xx):
if self.dct_size == 2:
yy = xx
else:
raise ValueError('DCT size unknown', self.dct_size)
return yy
def iblockDCT_per_chUV(self, xx):
dct_ch = self.dct_size * self.dct_size
ch_n = xx.shape[1] // dct_ch
yy = torch.Tensor(0).to(self.device)
for i in range(ch_n):
yi = self.iblockDCT_ps(xx[:, i * dct_ch:(i + 1) * dct_ch, :, :])
yy = torch.cat((yy, yi), dim=1)
return yy
def blockDCT2x2_ps2(self, xx):
pix_unshuff = nn.PixelUnshuffle(self.dct_size)
yy = pix_unshuff(xx)
rec_dct_reshape = torch.zeros((1, 1, 1, 1), dtype=torch.float32).to(self.device).expand(yy.shape).clone()
a = yy[:, 0, :, :]
b = yy[:, 1, :, :]
c = yy[:, 2, :, :]
d = yy[:, 3, :, :]
t1 = a + b
t2 = c + d
t3 = a - b
t4 = c - d
c1 = t1 + t2
c2 = t3 + t4
c3 = t1 - t2
c4 = t3 - t4
rec_dct_reshape[:, 0, :, :] = c1
rec_dct_reshape[:, 1, :, :] = c2
rec_dct_reshape[:, 2, :, :] = c3
rec_dct_reshape[:, 3, :, :] = c4
return rec_dct_reshape / 4
def iblockDCT2x2_ps2(self, yy):
xx = torch.zeros((1, 1, 1, 1), dtype=torch.float32).to(self.device).expand(yy.shape).clone()
c1 = yy[:, 0, :, :]
c2 = yy[:, 1, :, :]
c3 = yy[:, 2, :, :]
c4 = yy[:, 3, :, :]
t1 = c1 + c2
t2 = c3 + c4
t3 = c1 - c2
t4 = c3 - c4
a = t1 + t2
b = t3 + t4
c = t1 - t2
d = t3 - t4
xx[:, 0, :, :] = a
xx[:, 1, :, :] = b
xx[:, 2, :, :] = c
xx[:, 3, :, :] = d
pix_shuff = nn.PixelShuffle(self.dct_size)
xx = pix_shuff(xx)
return xx
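# Round-trip sanity check (illustrative, single-channel input with even height
# and width): the forward 2x2 transform followed by its inverse returns the
# original tensor up to floating-point rounding.
#   dct = FastDCT(dct_size=2, device="cpu")
#   x = torch.rand(1, 1, 8, 8)
#   torch.testing.assert_close(dct.iblockDCT_ps(dct.blockDCT_ps(x)), x)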
@@ -44,6 +44,7 @@ import dataset
import logger
import importlib.util
from fastDCT import FastDCT
class Trainer:
@@ -190,6 +191,18 @@ class Trainer:
)
self.loss_function = nn.L1Loss()
if "dct_size" in self.config_training or "dct_ch" in self.config['model']:
if "dct_size" not in self.config_training:
raise KeyError('Training config missing dct_size')
if "dct_ch" not in self.config['model']:
raise KeyError('Model config missing dct_ch')
self.dct_size = self.config_training["dct_size"]
self.fast_dct = FastDCT(self.dct_size, self.device)
model_dct_ch = self.config['model']['dct_ch']
print(f"[INFO] DCT transformed inputs used, size {self.dct_size}")
assert model_dct_ch == self.dct_size * self.dct_size, 'Unequal DCT size in training config and model config.'
else:
print("[INFO] DCT transformed inputs not used")
@staticmethod
def instantiate_from_dict(
@@ -312,9 +325,21 @@ class Trainer:
return {"lossY": lossY, "lossUV": lossUV, "lossYUV": lossYUV}
def iteration(self, sample):
if "dct_size" in self.config_training and self.dct_size >= 2:
sample_dct = self.applyDCT2(sample)
Ycoeff_res, UVcoeff_res = self.model(
{name: tensor.to(self.device) for name, tensor in sample_dct.items()}
)
Y_res, UV_res = self.applyIDCT2(Ycoeff_res, UVcoeff_res)
Y = sample["rec_before_dbf_Y"].to(self.device) + Y_res
UV = UV_res + torch.cat((sample["rec_before_dbf_U"], sample["rec_before_dbf_V"]), dim=1)[
..., ::2, ::2
].to(self.device)
else:
Y, UV = self.model(
{name: tensor.to(self.device) for name, tensor in sample.items()}
)
target_Y = sample["org_Y"][..., 8:-8, 8:-8].to(self.device)
target_UV = torch.cat((sample["org_U"], sample["org_V"]), dim=1)[
..., 8:-8:2, 8:-8:2
@@ -343,3 +368,25 @@ class Trainer:
)
val_metrics[tag] = val_metric
return val_metrics
def applyDCT2(self, sample):
sample_dct = {}
for name, tensor in sample.items():
if 'qp' in name:
yy = tensor[:, 0, 0, 0].to(self.device)
dim = tensor.shape
yy = yy.unsqueeze(1).unsqueeze(1).unsqueeze(1).expand((dim[0], 1, dim[2] // self.dct_size, dim[3] // self.dct_size))
elif '_Y' in name:
yy = self.fast_dct.blockDCT_ps(tensor.to(self.device))
elif '_U' in name or '_V' in name:
tensor = tensor.to(self.device)
yy = self.fast_dct.blockDCT_ps_halfsize(tensor[:, :, ::2, ::2])
else:
raise ValueError('Unknown input name', name, ' in applyDCT2().')
sample_dct[name] = yy
return sample_dct
def applyIDCT2(self, YCoeff, UVCoeff):
y = self.fast_dct.iblockDCT_ps(YCoeff)
uv = self.fast_dct.iblockDCT_per_chUV(UVCoeff)
return y, uv
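# Shape summary for applyDCT2 with dct_size = 2 (illustrative, 144x144 patches):
#   *_Y tensors    (N, 1, 144, 144) -> blockDCT_ps      -> (N, 4, 72, 72)
#   *_U/_V tensors (N, 1, 144, 144) -> ::2 subsampling  -> (N, 1, 72, 72)
#   qp_* tensors   (N, 1, 144, 144) -> scalar broadcast -> (N, 1, 72, 72)
# Every model input therefore lives on the 72x72 DCT grid; applyIDCT2 maps the
# predicted coefficients back to the pixel domain before the loss is computed.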