diff --git a/cfg/nn-based/nnsr/nnsr.cfg b/cfg/nn-based/nnsr/nnsr.cfg index a24e836744872c68f5b36cc517a1cf371c486bfe..9de8e92f8adce4f873bf65a46305de7c1f4cf1fc 100644 --- a/cfg/nn-based/nnsr/nnsr.cfg +++ b/cfg/nn-based/nnsr/nnsr.cfg @@ -1,3 +1,3 @@ RPR : 1 NnsrOption : 1 -NnsrModelName : models/super_resolution/NNVC_SR_multiratio_int16.sadl +NnsrModelName : models/super_resolution/NNVC_SR_multiratio_wavelet_int16.sadl diff --git a/models/super_resolution/NNVC_SR_multiratio_wavelet_int16.sadl b/models/super_resolution/NNVC_SR_multiratio_wavelet_int16.sadl new file mode 100644 index 0000000000000000000000000000000000000000..b1a8a2491be3776b7368ccd73a6f9834b79e3a15 --- /dev/null +++ b/models/super_resolution/NNVC_SR_multiratio_wavelet_int16.sadl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6817994eab4ad3e358fbb1d1fb2dd5098014d5130dc59247b1742c84ac8eff2 +size 59516 diff --git a/training/training_scripts/NN_SR_Unified/model/model.json b/training/training_scripts/NN_SR_Unified/model/model.json deleted file mode 100644 index 293ebcc6528665f052c8fda145640dde2f92b2fe..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_SR_Unified/model/model.json +++ /dev/null @@ -1,39 +0,0 @@ -{ "model" : { - "path": "../../LOP/model/model.py", - "class" : "Net", - "input_channels" : [ - [ "rec_before_dbf_Y" ], - [ "rec_before_dbf_U" ], - [ "rec_before_dbf_V" ], - [ "pred_Y" ], - [ "pred_U" ], - [ "pred_V" ], - [ "bs_Y" ], - [ "bs_U" ], - [ "bs_V" ], - [ "ipb_Y" ], - [ "qp_base" ], - [ "qp_slice" ] - ], - "input_kernels" : [ - 3, - 3, - 1, - 1, - 1, - 1 - ], - "D1" : 32, - "D2" : 16, - "D3" : 16, - "D4" : 16, - "D5" : 16, - "D6" : 32, - "N_Y" : 20, - "N_UV" : 10, - "C" : 16, - "C1_Y" : 64, - "C1_UV" : 64, - "C21" : 16 - } -} diff --git a/training/training_scripts/NN_SR_Unified/model/model.py b/training/training_scripts/NN_SR_Unified/model/model.py deleted file mode 100644 index 97a552aeb38ae7830bb8d7f8b4b8b6cd0253792d..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_SR_Unified/model/model.py +++ /dev/null @@ -1,369 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2024, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -from typing import Union, Tuple, Optional, Type, Iterable, List, Dict - -import torch -from torch import nn -from torch.nn import functional as F - - -class Conv(nn.Sequential): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]] = 1, - padding: Optional[Union[int, Tuple[int, int]]] = None, - is_separable: bool = False, - hidden_separable_channels: Optional[int] = None, - post_activation: Optional[Type] = nn.PReLU, - **kwargs, - ): - """ - Args: - in_channels: the number of input channels - out_channels: the number of output channels - kernel_size: the convolution's kernel size - stride: the convolution's stride(s) - padding: the convolution's padding - is_separable: whether to implement convolution separably - hidden_separable_channels: If is_separable, the number of hidden channels between convolutions. If None, use out_channels - post_activation: activation function to use after convolution. If None, no activation after convolution - **kwargs: additional kwargs to pass to nn.Conv2d - """ - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = ( - (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size - ) - self.stride = (stride, stride) if isinstance(stride, int) else stride - if padding is not None: - self.padding = (padding, padding) if isinstance(padding, int) else padding - else: - self.padding = tuple([k // 2 for k in self.kernel_size]) - self.is_separable = is_separable - self.post_activation = post_activation - - if self.is_separable: - self.hidden_separable_channels = hidden_separable_channels or out_channels - modules = [ - nn.Conv2d( - self.in_channels, - self.hidden_separable_channels, - (self.kernel_size[0], 1), - (self.stride[0], 1), - (self.padding[0], 0), - groups=self.hidden_separable_channels, - **kwargs, - ), - nn.Conv2d( - self.hidden_separable_channels, - self.out_channels, - (1, self.kernel_size[1]), - (1, self.stride[1]), - (0, self.padding[1]), - groups=self.hidden_separable_channels, - **kwargs, - ), - nn.Conv2d( - self.out_channels, - self.out_channels, - (1, 1), - (1, 1), - (0, 0), - **kwargs, - ), - ] - else: - modules = [ - nn.Conv2d( - self.in_channels, - self.out_channels, - self.kernel_size, - self.stride, - self.padding, - **kwargs, - ) - ] - - if self.post_activation is not None: - modules.append(self.post_activation()) - - super(Conv, self).__init__(*modules) - - -class MultiBranchModule(nn.Module): - """A module representing multple, parallel branches. If the input is a list, each element in the list is fed into the corresponding branch, - otherwise the input is fed into every branch. 
The outputs of each branch are then merged.""" - - def __init__(self, *branch_modules, merge_dimension: int = -3): - """ - Args: - branch_modules: modules to run in parallel - merge_dimension: the dimension to merge outputs from each branch - """ - super().__init__() - self.branches = nn.ModuleList(branch_modules) - self.merge_dimension = merge_dimension - - def forward(self, args: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor: - inputs = args if isinstance(args, list) else len(self.branches) * [args] - branch_outputs = [branch(input) for branch, input in zip(self.branches, inputs)] - return torch.cat(branch_outputs, dim=self.merge_dimension) - - -class ResidualBlock(nn.Sequential): - def __init__(self, C: int = 64, C1: int = 160, C21: int = 32): - super(ResidualBlock, self).__init__( - Conv(C, C1, kernel_size=1), - Conv(C1, C, kernel_size=1, post_activation=None), - Conv( - C, - C, - kernel_size=3, - post_activation=None, - is_separable=True, - hidden_separable_channels=C21, - ), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x + super(ResidualBlock, self).forward(x) - - -class SplitLumaChromaBlocks(nn.Sequential): - def __init__( - self, - N_Y: int = 12, - N_UV: int = 6, - C: int = 16, - C1_Y: int = 64, - C1_UV: int = 48, - C21: int = 16, - output_channels_y: int = 4, - output_channels_uv: int = 2, - path: str = None, - ): - super().__init__() - - self.split_y_path = nn.Sequential( - *[ResidualBlock(C, C1_Y, C21) for _ in range(N_Y)], - Conv(C, C, kernel_size=3, is_separable=True, hidden_separable_channels=C21), - Conv(C, output_channels_y, kernel_size=3, post_activation=None), - ) - - self.split_uv_path = nn.Sequential( - *[ResidualBlock(C, C1_UV, C21) for _ in range(N_UV)], - Conv(C, C, kernel_size=3, is_separable=True, hidden_separable_channels=C21), - Conv(C, output_channels_uv, kernel_size=3, post_activation=None), - ) - - self.Cy = C - self.Cuv = C - - def forward(self, x: torch.Tensor) -> torch.Tensor: - split_y_input = x[:, : self.Cy, :, :] - split_uv_input = x[:, self.Cy : self.Cy + self.Cuv, :, :] - - y_output = self.split_y_path.forward(split_y_input) - uv_output = self.split_uv_path.forward(split_uv_input) - return torch.cat((y_output, uv_output), dim=1) - - -class SADLNet(nn.Sequential): - """The network used during SADL inference""" - - def __init__( - self, - input_channels: Iterable[int] = [3, 3, 3, 1, 1, 1], - input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3], - D1: int = 32, - D2: int = 16, - D3: int = 16, - D4: int = 16, - D5: int = 16, - D6: int = 32, - N_Y: int = 20, - N_UV: int = 10, - C: int = 16, - C1_Y: int = 64, - C1_UV: int = 64, - C21: int = 16, - output_channels_y: int = 4, - output_channels_uv: int = 2, - path: str = None, - ): - """ - Args: - input_channels: the number of channels expected for each input - input_kernels: the kernel size for each input convolution - output_channels: the number of output channels - """ - self.input_channels = input_channels - self.input_kernels = input_kernels - self.input_features = [D1, D2, D3, D4, D4, D5] - - super(SADLNet, self).__init__( - MultiBranchModule( - *[ - Conv(c, d, kernel_size=k, post_activation=None) - for c, d, k in zip( - self.input_channels, self.input_features, self.input_kernels - ) - ] - ), - Conv(sum(self.input_features), C + C, kernel_size=1), - SplitLumaChromaBlocks( - N_Y, N_UV, C, C1_Y, C1_UV, C21, output_channels_y, output_channels_uv - ), - ) - - def get_example_inputs( - self, patch_size: Union[int, Tuple[int, int]] = 144, batch_size: int = 1 - ): - patch_size = 
( - (patch_size, patch_size) if isinstance(patch_size, int) else patch_size - ) - return [ - torch.rand( - batch_size, conv.in_channels, *patch_size, device=conv[0].weight.device - ) - for conv in self[0].branches - ] - - def to_onnx( - self, - filename: str, - patch_size: int = 144, - batch_size: int = 1, - opset: int = 10, - **kwargs, - ) -> None: - mode = self.training - self.eval() - torch.onnx.export( - self, - self.get_example_inputs(patch_size, batch_size), - filename, - opset_version=opset, - **kwargs, - ) - self.train(mode) - - -class Net(nn.Module): - """Wrapper for SADL model that implements input pre- and post-processing for training.""" - - def __init__( - self, - input_channels: Iterable[Iterable[str]] = [ - ["rec_after_alf_Y"], - ["rec_after_alf_U"], - ["rec_after_alf_V"], - ["pred_Y"], - ["pred_U"], - ["pred_V"], - ["bs_Y"], - ["bs_U"], - ["bs_V"], - ["ipb_Y"], - ["qp_base"], - ["qp_slice"], - ], - input_kernels: Iterable[int] = [3, 3, 3, 1, 1, 3], - D1: int = 32, - D2: int = 16, - D3: int = 16, - D4: int = 16, - D5: int = 16, - D6: int = 32, - N_Y: int = 20, - N_UV: int = 10, - C: int = 16, - C1_Y: int = 64, - C1_UV: int = 64, - C21: int = 16, - path: str = None, - ): - super(Net, self).__init__() - assert len(input_channels) == len( - input_kernels - ), "[ERROR] input size and kernels size not equal" - self.input_channels = input_channels - sizes = [len(a) for a in input_channels] - self.SADL_model = SADLNet( - sizes, - input_kernels, - D1, - D2, - D3, - D4, - D5, - D6, - N_Y, - N_UV, - C, - C1_Y, - C1_UV, - C21, - 4, - 2, - ) - self.chroma_upsampler = nn.Upsample(scale_factor=2, mode="nearest") - - def preprocess_args( - self, batch: Dict[str, torch.Tensor] - ) -> Dict[str, torch.Tensor]: - return [ - torch.cat([batch[name] for name in input_], dim=1) - for input_ in self.input_channels - ] - - def postprocess_outputs( - self, batch: Dict[str, torch.Tensor], out: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - Y_res, UV_res = out.split([4, 2], dim=1) - return ( - F.pixel_shuffle(Y_res, 2), - UV_res, - ) - - def forward( - self, batch: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: - args = self.preprocess_args(batch) - out = self.SADL_model(args) - return self.postprocess_outputs(batch, out) diff --git a/training/training_scripts/NN_SR_Unified/paths.json b/training/training_scripts/NN_SR_Unified/paths.json deleted file mode 100644 index 16f9a2a5d364c28d38ae9563beeddcffd419342f..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_SR_Unified/paths.json +++ /dev/null @@ -1,81 +0,0 @@ -{ - "binaries": { - "sadl_path": "/path/to/src/sadl" - }, - - "model" : { - "path": "/path/to/src/training/training_scripts/NN_Filtering/LOP/model/model.py" - }, - - "dataset": { - "div2k_train": { - "//": "dataset of png to convert", - "path": "/path/to/DIV2K/DIV2K_train_HR" - }, - "div2k_valid": { - "path": "/path/to/DIV2K/DIV2K_valid_HR" - }, - "bvi": { - "//": "dataset of yuv", - "path": "/path/to/bviorg", - "dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/bvi.json" - }, - "bvi_valid": { - "//": "dataset of yuv", - "path": "/path/to/bviorg", - "dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/bvi_valid.json" - }, - "tvd": { - "//": "dataset of yuv", - "path": "/path/to/tvdorg", - "dataset_file": "/path/to/src/training/training_scripts/NN_Filtering/common/datasets/tvd.json" - }, - "tvd_valid": { - "//": "dataset of yuv", - "path": "/path/to/tvdorg", - "dataset_file": 
"/path/to/src/training/training_scripts/NN_Filtering/common/datasets/tvd_valid.json" - } - }, - - - "stage1": { - "yuv": { - "//": "path to store yuv files dataset", - "path": "/path/to/stage1/yuv" - }, - "yuv_valid": { - "//": "path to store yuv files dataset", - "path": "/path/to/stage1/yuv" - }, - "encdec": { - "//": "path to store the shell script and all the generated files by the encoder/decoder", - "path": "/path/to/stage1/encdec", - "vtm_enc": "/path/to/src/bin/EncoderAppStatic", - "vtm_dec": "/path/to/src/bin/DecoderAppStatic", - "vtm_cfg": "/path/to/src/cfg/encoder_intra_vtm.cfg" - }, - "encdec_valid": { - "//": "path to store the shell script and all the generated files by the encoder/decoder", - "path": "/path/to/stage1/encdec", - "vtm_enc": "/path/to/src/bin/EncoderAppStatic", - "vtm_dec": "/path/to/src/bin/DecoderAppStatic", - "vtm_cfg": "/path/to/src/cfg/encoder_intra_vtm.cfg" - }, - "dataset": { - "//": "path to store the full dataset which will be used by the training", - "path": "/path/to/stage1/dataset" - }, - "dataset_valid": { - "//": "path to store the full dataset which will be used by the training", - "path": "/path/to/stage1/dataset" - }, - "training": { - "path": "/path/to/stage1/train" - }, - "conversion": { - "//": "full path to output the model. input model is taken in training/ckpt_dir", - "full_path_filename": "/path/to/stage1/train/model_float.sadl" - } - } -} - diff --git a/training/training_scripts/NN_SR_Unified/quantize/quantize.py b/training/training_scripts/NN_SR_Unified/quantize/quantize.py deleted file mode 100644 index 8e7852272b248626bd74e13cd1fff717264d3fa5..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_SR_Unified/quantize/quantize.py +++ /dev/null @@ -1,188 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2024, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" -import argparse -import json -import os -import subprocess - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--json_config", action="store", type=str, help="global configuration file" - ) - parser.add_argument( - "--input_model", - type=str, - help="SADL float model name or training section name in the config.json", - required=True, - ) - parser.add_argument( - "--output_model", - type=str, - help="sadl int16 model name or conversion section name in the config.json", - required=True, - ) - parser.add_argument( - "--sadl_quantize", - type=str, - help="full path to sadl naive_quantizer", - required=False, - ) - args = parser.parse_args() - return vars(args) - - -def qfactor_gen(content): - word = "[INFO] id: " - stop_word = "[INFO] == end model loading ==" - - q = 11 # input - groups = 1 - - string = "" - i = 0 - q_wt = 11 - - for idx, line in enumerate(content): - i += 1 - if line.find(word) != -1: - line_split = line.split(" ") - id = line_split[2] - name = line_split[5] - if name == "Placeholder\n": - string = string + id + " " + str(q) + " " - elif name == "Conv2D\n": - if groups >= int(1): - layer_name = content[idx - 5] - layer_name_split = layer_name.split(" ")[4].split(".")[1] - q_upd = q_wt - groups = 0 - if int(q) > 9: - string = string[0:-3] + str(q_upd) + " " + id + " 0 " - else: - string = string[0:-2] + str(q_upd) + " " + id + " 0 " - else: - string = string + id + " 0 " - elif ( - name == "LeakyRelu\n" - or name == "Concat\n" - or name == "Add\n" - or name == "PReLU\n" - ): - id = id - elif name == "BiasAdd\n": - layer_name = content[idx - 5] - layer_name_split = layer_name.split(" ")[4].split(".")[1] - - chroma_condition = ( - "body_chroma.7" in layer_name or "body_chroma.6" in layer_name - ) - luma_condition = ( - "body_luma.15" in layer_name - or "body_luma.14" in layer_name - or "body_luma.13" in layer_name - or "body_luma.12" in layer_name - ) - if ( - "body_chroma.8" in layer_name - or "body_chroma.9" in layer_name - or "body_luma.16" in layer_name - or "body_luma.17" in layer_name - or "body_luma.18" in layer_name - or "body_luma.19" in layer_name - ): - q_upd = 8 - elif chroma_condition or luma_condition: - q_upd = 9 - elif "tail_luma" in layer_name or "tail_chroma" in layer_name: - q_upd = 11 - elif ( - "body_luma" == layer_name_split or "body_chroma" == layer_name_split - ): - q_upd = 10 - else: - q_upd = 11 - if int(q) > 9: - string = string[0:-3] + str(q_upd) + " " - else: - string = string[0:-2] + str(q_upd) + " " - else: - if (name == "Const\n") and content[idx + 6].find("Concat") != -1: - q_upd = 0 - else: - q_upd = q - string = string + id + " " + str(q_upd) + " " - if name != "Const\n": - print("Entered here for wrong name", name) - elif line.find(stop_word) != -1: - break - else: - grp_line = content[idx + 6] - if grp_line.find("groups:") != -1: - groups = int(grp_line.split(" ")[-1][:-1]) - return string - - -if __name__ == "__main__": - args = parse_arguments() - json_config = 
args["json_config"] - infile = args["input_model"] - debug_text = os.path.join( - os.path.dirname(args["output_model"]), "sadl_float_debug_dump.txt" - ) - outfile = args["output_model"] - sadl_quantize = args["sadl_quantize"] - - try: - with open(json_config) as file: - prm = json.load(file) - except Exception: - quit("[ERROR] unable to open json config") - debug_model = os.path.join(prm["binaries"]["sadl_path"], "sample", "debug_model") - - debug_txt_content = open(debug_text, "w") - process = subprocess.Popen([debug_model, infile], stdout=debug_txt_content) - process.wait() - - debug_txt_content = open(debug_text, "r") - q_log_string = qfactor_gen(debug_txt_content.readlines()) - os.remove(debug_text) - sadl_quantize = os.path.join( - prm["binaries"]["sadl_path"], "sample", "naive_quantization" - ) - - p = subprocess.Popen([sadl_quantize, infile, outfile], stdin=subprocess.PIPE) - p.communicate(input=q_log_string.encode()) - print(f"[INFO] output model {outfile}") diff --git a/training/training_scripts/NN_SR_Unified/readme.md b/training/training_scripts/NN_SR_Unified/readme.md deleted file mode 100644 index dca6d62e793983340e761870b6731c5e06d3edae..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_SR_Unified/readme.md +++ /dev/null @@ -1,234 +0,0 @@ -# Super Resolution Model Training - -## Overview -The training of this NNVC Super Resolution (SR) is basically the same as stage 1 of unified LOP loop filter and only one stage is performed for training SR. Following the procedure of firsdt stage of unified LOP loop filter could finish the training of SR. - -### Preparation of the directory -At minima, paths should first be set. The file ``training/training_scripts/NN_SR_Unified/paths.json`` should be copy and edited to match your environment. -All keys with the name ``path`` should be edited to fit your particular environement. -Additionally, you should also edit the variable ``vtm_xx`` to point to the VTM binaries and configuration files, the ``sadl_path`` to point to the sadl repository. - -Assuming that all data are on the same storage, then the following directory structure can be used: -``` - - src [ contains the NNVC repository] - - DIV2K [ contains the original div2k dataset] - - bviorg [ contains the original BVI dataset ] - - tvdorg [ contains the original TVD dataset ] - - - stage1 [ will be created by the scripts] - - stage2 [ will be created by the scripts] - - stage3 [ will be created by the scripts] -``` - -To create ``src`` the following commands can be used: -```sh -git clone https://vcgit.hhi.fraunhofer.de/jvet-ahg-nnvc/VVCSoftware_VTM.git src; -cd src; -git checkout VTM-11.0_nnvc -git submodule init -git submodule update -mkdir build -cd build -cmake -DCMAKE_BUILD_TYPE=Release .. -make -j -cd ../.. -``` - -To create the DIV2K directory: -```sh -mkdir DIV2K; -cd DIV2K; -wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip -unzip DIV2K_train_HR.zip -wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip -unzip DIV2K_valid_HR.zip -cd .. -``` - -To create the bviorg directory: -```sh -mkdir bviorg; -cd bviorg; -wget https://data.bris.ac.uk/datasets/tar/3h0hduxrq4awq2ffvhabjzbzi1.zip -unzip 3h0hduxrq4awq2ffvhabjzbzi1.zip -cd .. -``` - -To create the tvdorg directory: -```sh -mkdir tvdorg; -# download TVD dataset from https://multimedia.tencent.com/resources/tvd -``` - - -With this file structure, the ``paths.json`` is simply edited by just replacing ``/path/to`` by the absolute path of the root directory of the experiment. 
-If your datasets are on different storages, just edit the relevant lines in ``paths.json``. - -### Creation of the consolidated configuration file -The main configuration file is ``common/common_config.json``. -The LOP specific parameters are in ``LOP/lop_config.json``. -The variables specific to your environment is ``LOP/paths.json``. -The model is in ``LOP/model/model.json``. -A unique configuration is create by merging all the different files using the following command: -```sh -cp src/training_scripts/NN_SR_Unified/LOP/paths.json . -# edit paths.json -python3 src/training/tools/create_config.py src/training/training_scripts/NN_SR_Unified/common/common_config.json src/training/training_scripts/NN_SR_Unified/LOP/lop_config.json src/training/training_scripts/NN_SR_Unified/LOP/model/model.json paths.json > my_config.json -``` - -Then this file will be used in command line below, you should be able to run the process just by copy/pasting all lines of shell below. - -Other keys should not be edited except for testing reasons. - - -## Training Data - -### A- Data extraction for intra from vanilla VTM -#### 1. Dataset preparation - div2k conversion - -Convert div2k (4:4:4 RGB -> YUV420 10 bits): - -```sh - python3 tools/convert_dataset.py --json_config my_config.json --input_dataset dataset/div2k_train --augmentation --output_location stage1/yuv - python3 tools/convert_dataset.py --json_config my_config.json --input_dataset dataset/div2k_valid --output_location stage1/yuv_valid -``` -dataset files are placed in the target directory (as set in the config.json ["stage1"]["yuv"]["path"]), a json file named ["stage1"]["yuv"]["dataset_filename"] is updated with the new data. - - -#### 2. Prepare scripts for encoding/decoding of the dataset -Please note that a VTM without NN tools is used. NNVC-4.0 to NNVC-6.0 tags can be used to generate the binaries and cfg file. The configuration file is the vanilla VTM one (see config.json). -The macro for data dump should be: -``` -// which data are used for inference/dump -#define NNVC_USE_PRED 1 // prediction -#define NNVC_USE_BS 1 // BS of DBF -#define NNVC_USE_QP 1 // QP slice -#define JVET_AC0089_NNVC_USE_BPM_INFO 1 // JVET-AC0089: dump Block Prediction Mode - -#define SR_DATA_DUMP 1 // only used for dumping training data for super resolution -#if SR_DATA_DUMP -#define NNVC_USE_REC_AFTER_ALF 1 // reconstruction after ALF -#define NNVC_USE_REC_AFTER_UPSAMPLING 1 // reconstruction after upsampling -#endif -``` -Other macros can be set to 0. When SR_DATA_DUMP is enabled, the JVET_AG0130_UNIFIED_SR should be set to 0. - - -Extract cfg files and encoding/decoding script: -```sh - python3 tools/dataset_to_encoding_script.py --json_config my_config.json --input_dataset stage1/yuv --output_location stage1/encdec - python3 tools/dataset_to_encoding_script.py --json_config my_config.json --input_dataset stage1/yuv_valid --output_location stage1/encdec_valid -``` -It will generate the cfg files for the dataset and a shell script to encode and decode all sequences in the dataset in the directory ["stage1"]["encdec"]["path"]. - -#### 3. 
Encode/decode the sequences: - -Loop on all sequences to encode and make sure the following parameters are used for Encoder: ---ScalingRatioHor=2 --ScalingRatioVer=2 --UpscaledOutput=2 - -The following parameters are used for Decoder: ---UpscaledOutput=2 - -an example (may need to add the parameters described above for encoder and decoder) could be found in: -```sh -cd stage1/encdec; -N1=32000; -for((i=0;i<N1;i++)); do - ./encode_decode_dataset.sh $i; -done -N2=1000; -for((i=0;i<N2;i++)); do - ./encode_decode_dataset_valid.sh $i; -done -``` -or you can use the script to encode on your cluster. N is the number of sequences (run ./encode_decode_dataset.sh to get the value N). -The script will perform the following step for each sequence: - -1. encodes the yuv and produces an encoding log -2. decodes the bitstream, produces a decoding log and dump the data in ["stage1"]["encdec"]["path"]/["dump_dir"] - -#### 4. Create a consolidated dataset - -```sh -python3 tools/concatenate_dataset.py --json_config my_config.json --input_dir_json stage1/encdec --output_json stage1/encdec -python3 tools/concatenate_dataset.py --json_config my_config.json --input_dir_json stage1/encdec_valid --output_json stage1/encdec_valid -``` -It will generate a unique dataset in ["stage1"]["encdec"]["path"] from all individual datasets in ["stage1"]["encdec"]["path"]/["dump_dir"] and encoder logs in ["stage1"]["encdec"]["enc_dir"]. - - -#### 5a. Create an offline dataset with all batches - -```sh -python3 tools/create_unified_dataset.py --json_config my_config.json \ - --nb_patches -1 --patch_size 128 --border_size 8 --input_dataset stage1/encdec \ - --components org_Y,org_U,org_V,rec_after_upsampling_Y,rec_after_upsampling_U,rec_after_upsampling_V,pred_Y,pred_U,pred_V,rec_after_alf_Y,rec_after_alf_U,rec_after_alf_V,bs_Y,bs_U,bs_V,ipb_Y,qp_base,qp_slice \ - --output_location stage1/dataset -python3 tools/create_unified_dataset.py --json_config my_config.json \ - --nb_patches -1 --patch_size 128 --border_size 8 --input_dataset stage1/encdec_valid \ - --components org_Y,org_U,org_V,rec_after_upsampling_Y,rec_after_upsampling_U,rec_after_upsampling_V,pred_Y,pred_U,pred_V,rec_after_alf_Y,rec_after_alf_U,rec_after_alf_V,bs_Y,bs_U,bs_V,ipb_Y,qp_base,qp_slice \ - --output_location stage1/dataset_valid -``` -It will generate a unique dataset of patches ready for training in ["stage1"]["dataset"]["path"] from the dataset in ["stage1"]["encdec"]["path"]. - -### Training stage -#### 1. Train SR model -If you need to adapt the settings of your device for training, please edit the file ``my_config.json`` (default parameters). You can also change the loggers verbosity in these files. -When ready, simply run: - -```sh -python3 training_scripts/NN_SR_Unified/common/training/main.py --json_config my_config.json --stage 1 -``` - -#### 2. Convert model to SADL -The last ONNX model is converted into float SADL format. -```sh -python3 training_scripts/NN_SR_Unified/common/convert/to_sadl.py \ - --json_config my_config.json \ - --input_model stage1/training --output_model stage1/conversion/full_path_filename -``` - -The converter will use: - * the json config file used for training - * the file model_onnx_filename of the training as input (usually last.onnx) - * output the file model_filename of the conversion section - -**Note:** the directory in dataset can now be deleted if there is no need to retrain. - - -#### 3. 
Test model -To test the int model, the following macros should be enabled in ``TypeDef.h``: -``` -#define JVET_AG0130_UNIFIED_SR 1 -#if JVET_AG0130_UNIFIED_SR -#define ADAPTIVE_RPR 1 -#define SR_FIXED_POINT_IMPLEMENTATION 1 -#if SR_FIXED_POINT_IMPLEMENTATION -using TypeSadlSr = int16_t; -#else -using TypeSadlSr = float; -#endif -#endif -``` - -and the correct data macros should be set: -``` -// which data are used for inference/dump -#define NNVC_USE_REC_BEFORE_DBF 1 // reconstruction before DBF -#define NNVC_USE_PRED 1 // prediction -#define NNVC_USE_BS 1 // BS of DBF -#define NNVC_USE_QP 1 // QP slice -#define JVET_AC0089_NNVC_USE_BPM_INFO 1 // JVET-AC0089: dump Block Prediction Mode - -#define SR_DATA_DUMP 0 // only used for dumping training data for super resolution -#if SR_DATA_DUMP -#define NNVC_USE_REC_AFTER_ALF 1 // reconstruction after ALF -#define NNVC_USE_REC_AFTER_UPSAMPLING 1 // reconstruction after upsampling -#endif -``` - -##### 3.1 Inference test -The model is tested with the following parameters: -``` ---NnsrOption=1 --NnsrModelName=models/NNVC_SR_int16.sadl -``` -The configuration file is ``encoder_xxx_vtm.cfg`` and the anchor VTM-11.0\_NNVC. \ No newline at end of file diff --git a/training/training_scripts/NN_SR_Unified/sr_config.json b/training/training_scripts/NN_SR_Unified/sr_config.json deleted file mode 100644 index f04b2c21856e638b19d5c8b815c71bd8460576c0..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_SR_Unified/sr_config.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "stage1": { - "training": { - "mse_epoch": 43, - "max_epochs": 45, - "component_loss_weightings": [ - 12, - 2 - ], - "dataloader": { - "batch_size": 16 - }, - "optimizer": { - "lr": 0.0004 - }, - "lr_scheduler": { - "milestones": [ - 25, - 38, - 43 - ] - } - } - } - -} diff --git a/training/training_scripts/NN_SR_Unified/training/trainer.py b/training/training_scripts/NN_SR_Unified/training/trainer.py deleted file mode 100644 index 986866e8e884bdc56d88f9ca00b3ee676fdfe7d4..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_SR_Unified/training/trainer.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2024, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import os -import sys -import torch - -sys.path.append( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "../..", "NN_Filtering") -) -from common.training import trainer as TrainerLF # noqa: E402 - - -class Trainer(TrainerLF.Trainer): - """Training procedure.""" - - def iteration(self, sample): - Y, UV = self.model( - {name: tensor.to(self.device) for name, tensor in sample.items()} - ) - target_Y = sample["org_Y"][..., 8:-8, 8:-8].to(self.device) - target_UV = torch.cat((sample["org_U"], sample["org_V"]), dim=1)[ - ..., 8:-8:2, 8:-8:2 - ].to(self.device) - upsampled_Y = sample["rec_after_upsampling_Y"][..., 8:-8, 8:-8].to(self.device) - upsampled_UV = torch.cat( - (sample["rec_after_upsampling_U"], sample["rec_after_upsampling_V"]), dim=1 - )[..., 8:-8:2, 8:-8:2].to(self.device) - - lossY = self.loss_function(Y[..., 8:-8, 8:-8] + upsampled_Y, target_Y) - lossUV = self.loss_function(UV[..., 4:-4, 4:-4] + upsampled_UV, target_UV) - return lossY, lossUV diff --git a/training/training_scripts/NN_SR_WaveletLoss_Multiratio/.gitkeep b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/training/training_scripts/NN_SR_Multiratio/common_config.json b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/common_config.json similarity index 99% rename from training/training_scripts/NN_SR_Multiratio/common_config.json rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/common_config.json index 9bc41fa408b72ee11e6538fb05426623ddd0c3b8..70217b8cfdd41cf3f9a107a0aa8c636a9ec8570b 100644 --- a/training/training_scripts/NN_SR_Multiratio/common_config.json +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/common_config.json @@ -240,6 +240,7 @@ } }, "training": { + "use_dwt_loss": true, "ckpt_dir": "ckpt", "ckpt_reload": "last.ckpt", "model_onnx_filename": "last.onnx", diff --git a/training/training_scripts/NN_SR_Multiratio/dataset_to_nnsr.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/dataset_to_nnsr.py similarity index 90% rename from training/training_scripts/NN_SR_Multiratio/dataset_to_nnsr.py rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/dataset_to_nnsr.py index 283606cbe22a1fea6b32d24dfb4f7daf51335088..b034ca47c4da1a8ecfaa611bd10a8e8fa60d8225 100644 --- a/training/training_scripts/NN_SR_Multiratio/dataset_to_nnsr.py +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/dataset_to_nnsr.py @@ -24,7 +24,8 @@ def remove_files_from_json(filename, pattern): data = json.load(f) print(filename + ':', len(data)) - to_remove = [filename for filename in data.keys() if pattern.match(filename)] + to_remove = [filename for filename in data.keys() + if pattern.match(filename)] for file in to_remove: data.pop(file, None) @@ -44,7 +45,8 @@ parser.add_argument( "--input_dataset", action="store", type=str, - help="dataset to process: expressed as name from config.json. 
Eg: dataset_src/div2k_train, dataset_src/bvi", + help="dataset to process: expressed as name from config.json. \ + Eg: dataset_src/div2k_train, dataset_src/bvi", required=True, ) diff --git a/training/training_scripts/NN_SR_Multiratio/lop_config.json b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/lop_config.json similarity index 100% rename from training/training_scripts/NN_SR_Multiratio/lop_config.json rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/lop_config.json diff --git a/training/training_scripts/NN_SR_Multiratio/model/model.json b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/model/model.json similarity index 100% rename from training/training_scripts/NN_SR_Multiratio/model/model.json rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/model/model.json diff --git a/training/training_scripts/NN_SR_Multiratio/model/model.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/model/model.py similarity index 100% rename from training/training_scripts/NN_SR_Multiratio/model/model.py rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/model/model.py diff --git a/training/training_scripts/NN_SR_Multiratio/paths.json b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/paths.json similarity index 100% rename from training/training_scripts/NN_SR_Multiratio/paths.json rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/paths.json diff --git a/training/training_scripts/NN_SR_Multiratio/quantize/quantize.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/quantize/quantize.py similarity index 90% rename from training/training_scripts/NN_SR_Multiratio/quantize/quantize.py rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/quantize/quantize.py index 9ba6cdb6521af3ada7cbdfab113af3fa992a5d06..bdd63c2c7550a23d982f55d4c06d384c36c4efb4 100644 --- a/training/training_scripts/NN_SR_Multiratio/quantize/quantize.py +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/quantize/quantize.py @@ -85,6 +85,8 @@ def qfactor_gen(content): string = string + id + " " + str(q) + " " elif name == "Conv2D\n": if groups >= int(1): + layer_name = content[idx - 5] + layer_name_split = layer_name.split(" ")[4].split(".")[1] q_upd = q_wt groups = 0 if int(q) > 9: @@ -101,7 +103,14 @@ def qfactor_gen(content): ): id = id elif name == "BiasAdd\n": - q_upd = 11 + layer_name = content[idx - 5] + layer_name_split = layer_name.split(" ")[4].split(".")[1] + if "split_y_path" == layer_name_split: + q_upd = 10 + elif "split_uv_path" == layer_name_split: + q_upd = 11 + else: + q_upd = 11 if int(q) > 9: string = string[0:-3] + str(q_upd) + " " else: @@ -109,6 +118,8 @@ def qfactor_gen(content): else: if (name == "Const\n") and content[idx + 6].find("Concat") != -1: q_upd = 0 + elif (name == "Const\n") and content[idx + 6].find("Resize") != -1: + q_upd = 0 else: q_upd = q string = string + id + " " + str(q_upd) + " " diff --git a/training/training_scripts/NN_SR_Multiratio/readme.md b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/readme.md similarity index 95% rename from training/training_scripts/NN_SR_Multiratio/readme.md rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/readme.md index 63c3f46426f5eee5ceda78d4c6d87d4fc579b586..c99403b6db9dad22b4b6993271056d1eb4d30c49 100644 --- a/training/training_scripts/NN_SR_Multiratio/readme.md +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/readme.md @@ -286,7 +286,7 @@ If you need to adapt the settings of your device for training, please edit 
the f When ready, simply run: ```sh -python3 src/training/training_scripts/NN_Filtering/common/training/main.py --json_config my_config.json --stage 3 +python3 src/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/main.py --json_config my_config.json --stage 3 ``` #### 2. Convert model @@ -297,13 +297,10 @@ python3 src/training/training_scripts/NN_Filtering/common/convert/to_sadl.py -- ``` #### 3. Integerized model -The float model is integerized into int16 SADL format using a naive quantization (all quantizers to 2^11). First the ``naive_quantization`` software should be build. Please refer to SADL documentation to build the software. ```sh -python3 src/training/training_scripts/NN_Filtering/LOP/quantize/quantize.py --json_config my_config.json --input_model stage3/conversion/full_path_filename --output_model stage3/quantize/full_path_filename - -./src/sadl/sample/debug_model stage3/train/model_float.sadl >stage3/train/sadl_float_debug_dump.txt -python3 quantize-1.py --json_config my_config.json --input_model stage3/train/model_float.sadl --output_model stage3/train/EE1-4.2_int16.sadl +./src/sadl/sample/debug_model /path/to/model_float.sadl > /path/to/sadl_float_debug_dump.txt +python3 src/training/training_scripts/NN_SR_WaveletLoss_Multiratio/quantize/quantize.py --json_config my_config.json --input_model /path/to/model_float.sadl --output_model /path/to/model_int16.sadl ``` diff --git a/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/SWT.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/SWT.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fb69fb99018ab0ebb222bf5dd41f88d6463779 --- /dev/null +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/SWT.py @@ -0,0 +1,539 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2024, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import pywt
+import torch
+
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def reflect(x, minx, maxx):
+    """Reflect the values in matrix *x* about the scalar values *minx* and
+    *maxx*. Hence a vector *x* containing a long linearly increasing series is
+    converted into a waveform which ramps linearly up and down between *minx*
+    and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
+    0.5), the ramps will have repeated max and min samples.
+    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
+    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
+    """
+    x = np.asanyarray(x)
+    rng = maxx - minx
+    rng_by_2 = 2 * rng
+    mod = np.fmod(x - minx, rng_by_2)
+    normed_mod = np.where(mod < 0, mod + rng_by_2, mod)
+    out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
+    return np.array(out, dtype=x.dtype)
+
+
+def mypad(x, pad, mode='constant', value=0):
+    """ Function to do numpy-like padding on tensors. Only works for 2-D
+    padding.
+    Inputs:
+        x (tensor): tensor to pad
+        pad (tuple): tuple of (left, right, top, bottom) pad sizes
+        mode (str): 'symmetric', 'wrap', 'constant', 'reflect', 'replicate', or
+            'zero'. The padding technique.
+    """
+    if mode == 'symmetric':
+        # Vertical only
+        if pad[0] == 0 and pad[1] == 0:
+            m1, m2 = pad[2], pad[3]
+            g = x.shape[-2]
+            xe = reflect(np.arange(-m1, g + m2, dtype='int32'), -0.5, g - 0.5)
+            return x[:, :, xe]
+        # horizontal only
+        elif pad[2] == 0 and pad[3] == 0:
+            m1, m2 = pad[0], pad[1]
+            g = x.shape[-1]
+            xe = reflect(np.arange(-m1, g + m2, dtype='int32'), -0.5, g - 0.5)
+            return x[:, :, :, xe]
+        # Both
+        else:
+            m1, m2 = pad[0], pad[1]
+            l1 = x.shape[-1]
+            xe_row = reflect(np.arange(-m1, l1 + m2, dtype='int32'), -0.5,
+                             l1 - 0.5)
+            m1, m2 = pad[2], pad[3]
+            l2 = x.shape[-2]
+            xe_col = reflect(np.arange(-m1, l2 + m2, dtype='int32'), -0.5,
+                             l2 - 0.5)
+            i = np.outer(xe_col, np.ones(xe_row.shape[0]))
+            j = np.outer(np.ones(xe_col.shape[0]), xe_row)
+            return x[:, :, i, j]
+    elif mode == 'periodic':
+        # Vertical only
+        if pad[0] == 0 and pad[1] == 0:
+            xe = np.arange(x.shape[-2])
+            xe = np.pad(xe, (pad[2], pad[3]), mode='wrap')
+            return x[:, :, xe]
+        # Horizontal only
+        elif pad[2] == 0 and pad[3] == 0:
+            xe = np.arange(x.shape[-1])
+            xe = np.pad(xe, (pad[0], pad[1]), mode='wrap')
+            return x[:, :, :, xe]
+        # Both
+        else:
+            xe_col = np.arange(x.shape[-2])
+            xe_col = np.pad(xe_col, (pad[2], pad[3]), mode='wrap')
+            xe_row = np.arange(x.shape[-1])
+            xe_row = np.pad(xe_row, (pad[0], pad[1]), mode='wrap')
+            i = np.outer(xe_col, np.ones(xe_row.shape[0]))
+            j = np.outer(np.ones(xe_col.shape[0]), xe_row)
+            return x[:, :, i, j]
+
+    elif mode == 'constant' or mode == 'reflect' or mode == 'replicate':
+        return F.pad(x, pad, mode, value)
+    elif mode == 'zero':
+        return F.pad(x, pad)
+    else:
+        raise ValueError("Unknown pad type: {}".format(mode))
+
+
+def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=None):
+    """
+    Prepares the filters to be of the right form for the afb2d function. In
+    particular, makes the tensors the right shape. It takes mirror images of
+    them as afb2d uses conv2d which acts like normal correlation.
+    Inputs:
+        h0_col (array-like): low pass column filter bank
+        h1_col (array-like): high pass column filter bank
+        h0_row (array-like): low pass row filter bank. If none, will assume
+            the same as column filter
+        h1_row (array-like): high pass row filter bank. If none, will assume
+            the same as column filter
+        device: which device to put the tensors on to
+    Returns:
+        (h0_col, h1_col, h0_row, h1_row)
+    """
+    h0_col = np.array(h0_col[::-1]).ravel()
+    h1_col = np.array(h1_col[::-1]).ravel()
+    t = torch.get_default_dtype()
+    if h0_row is None:
+        h0_row = h0_col
+    else:
+        h0_row = np.array(h0_row[::-1]).ravel()
+    if h1_row is None:
+        h1_row = h1_col
+    else:
+        h1_row = np.array(h1_row[::-1]).ravel()
+    h0_col = torch.tensor(h0_col, device=device, dtype=t).reshape(
+        (1, 1, -1, 1))
+    h1_col = torch.tensor(h1_col, device=device, dtype=t).reshape(
+        (1, 1, -1, 1))
+    h0_row = torch.tensor(h0_row, device=device, dtype=t).reshape(
+        (1, 1, 1, -1))
+    h1_row = torch.tensor(h1_row, device=device, dtype=t).reshape(
+        (1, 1, 1, -1))
+
+    return h0_col, h1_col, h0_row, h1_row
+
+
+def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
+    """
+    Prepares the filters to be of the right form for the sfb2d function. In
+    particular, makes the tensors the right shape. It does not mirror image
+    them as sfb2d uses conv2d_transpose which acts like normal convolution.
+    Inputs:
+        g0_col (array-like): low pass column filter bank
+        g1_col (array-like): high pass column filter bank
+        g0_row (array-like): low pass row filter bank. If none, will assume
+            the same as column filter
+        g1_row (array-like): high pass row filter bank. If none, will assume
+            the same as column filter
+        device: which device to put the tensors on to
+    Returns:
+        (g0_col, g1_col, g0_row, g1_row)
+    """
+    g0_col = np.array(g0_col).ravel()
+    g1_col = np.array(g1_col).ravel()
+    t = torch.get_default_dtype()
+    if g0_row is None:
+        g0_row = g0_col
+    if g1_row is None:
+        g1_row = g1_col
+    g0_col = torch.tensor(g0_col, device=device, dtype=t).reshape(
+        (1, 1, -1, 1))
+    g1_col = torch.tensor(g1_col, device=device, dtype=t).reshape(
+        (1, 1, -1, 1))
+    g0_row = torch.tensor(g0_row, device=device, dtype=t).reshape(
+        (1, 1, 1, -1))
+    g1_row = torch.tensor(g1_row, device=device, dtype=t).reshape(
+        (1, 1, 1, -1))
+
+    return g0_col, g1_col, g0_row, g1_row
+
+
+def afb1d_atrous(x, h0, h1, mode='symmetric', dim=-1, dilation=1):
+    """ 1D analysis filter bank (along one dimension only) of an image without
+    downsampling. Implements the à trous algorithm.
+    Inputs:
+        x (tensor): 4D input with the last two dimensions the spatial input
+        h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
+            h, 1) or (1, 1, 1, w)
+        h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1,
+            h, 1) or (1, 1, 1, w)
+        mode (str): padding method
+        dim (int): dimension of filtering. d=2 is for a vertical filter (
+            called column filtering but filters across the rows). d=3 is for a
+            horizontal filter, (called row filtering but filters across the
+            columns).
+        dilation (int): dilation factor. Should be a power of 2.
+    Returns:
+        lohi: lowpass and highpass subbands concatenated along the channel
+            dimension
+    """
+    C = x.shape[1]
+    # Convert the dim to positive
+    d = dim % 4
+    # If h0, h1 are not tensors, make them. If they are, then assume that they
+    # are in the right order
+    if not isinstance(h0, torch.Tensor):
+        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
+                          dtype=torch.float, device=x.device)
+    if not isinstance(h1, torch.Tensor):
+        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
+                          dtype=torch.float, device=x.device)
+    L = h0.numel()
+    shape = [1, 1, 1, 1]
+    shape[d] = L
+    # If h aren't in the right shape, make them so
+    if h0.shape != tuple(shape):
+        h0 = h0.reshape(*shape)
+    if h1.shape != tuple(shape):
+        h1 = h1.reshape(*shape)
+    h = torch.cat([h0, h1] * C, dim=0)
+
+    # Calculate the pad size
+    L2 = (L * dilation) // 2
+    pad = (0, 0, L2 - dilation, L2) if d == 2 else (L2 - dilation, L2, 0, 0)
+    x = mypad(x, pad=pad, mode=mode)
+    lohi = F.conv2d(x, h, groups=C, dilation=dilation)
+
+    return lohi
+
+
+def afb2d_atrous(x, filts, mode='symmetric', dilation=1):
+    """ Does a single level 2d wavelet decomposition of an input. Does separate
+    row and column filtering by two calls to `afb1d_atrous`
+    Inputs:
+        x (torch.Tensor): Input to decompose
+        filts (list of ndarray or torch.Tensor): If a list of tensors has been
+            given, this function assumes they are in the right form (the form
+            returned by `prep_filt_afb2d`).
+            Otherwise, this function will prepare the filters to be of the
+            right form by calling `prep_filt_afb2d`.
+        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
+            padding to use. If periodization, the output size will be half the
+            input size.
+            Otherwise, the output size will be slightly larger than half.
+        dilation (int): dilation factor for the filters.
Should be 2**level + Returns: + y: Tensor of shape (N, C, 4, H, W) + """ + tensorize = [not isinstance(f, torch.Tensor) for f in filts] + if len(filts) == 2: + h0, h1 = filts + if True in tensorize: + h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d( + h0, h1, device=x.device) + else: + h0_col = h0 + h0_row = h0.transpose(2, 3) + h1_col = h1 + h1_row = h1.transpose(2, 3) + elif len(filts) == 4: + if True in tensorize: + h0_col, h1_col, h0_row, h1_row = prep_filt_afb2d( + *filts, device=x.device) + else: + h0_col, h1_col, h0_row, h1_row = filts + else: + raise ValueError("Unknown form for input filts") + + lohi = afb1d_atrous(x, h0_row, h1_row, mode=mode, dim=3, dilation=dilation) + y = afb1d_atrous(lohi, h0_col, h1_col, mode=mode, dim=2, dilation=dilation) + + return y + + +def sfb1d_atrous(lo, hi, g0, g1, mode='symmetric', dim=-1, dilation=1, + pad1=None, pad=None): + """ 1D synthesis filter bank of an image tensor with no upsampling. + Used for the stationary wavelet transform. + """ + C = lo.shape[1] + d = dim % 4 + # If g0, g1 are not tensors, make them. If they are, then assume that they + # are in the right order + if not isinstance(g0, torch.Tensor): + g0 = torch.tensor(np.copy(np.array(g0).ravel()), + dtype=torch.float, device=lo.device) + if not isinstance(g1, torch.Tensor): + g1 = torch.tensor(np.copy(np.array(g1).ravel()), + dtype=torch.float, device=lo.device) + L = g0.numel() + shape = [1, 1, 1, 1] + shape[d] = L + # If g aren't in the right shape, make them so + if g0.shape != tuple(shape): + g0 = g0.reshape(*shape) + if g1.shape != tuple(shape): + g1 = g1.reshape(*shape) + g0 = torch.cat([g0] * C, dim=0) + g1 = torch.cat([g1] * C, dim=0) + + # Calculate the padding size. + # With dilation, zeros are inserted between the filter taps but not after. + # that means a filter that is [a b c d] becomes [a 0 b 0 c 0 d]. + # centre = L / 2 + fsz = (L - 1) * dilation + 1 + # newcentre = fsz / 2 + # before = newcentre - dilation*centre + + # When conv_transpose2d is done, a filter with k taps expands an input with + # N samples to be N + k - 1 samples. The 'padding' is really the opposite + # of that, and is how many samples on the edges you want to cut out. + # In addition to this, we want the input to be extended before convolving. + # This means the final output size without the padding option will be + # N + k - 1 + k - 1 + # The final thing to worry about is making sure that the output is centred. + # short_offset = dilation - 1 + # centre_offset = fsz % 2 + a = fsz // 2 + b = fsz // 2 + (fsz + 1) % 2 + + # pad = (0, 0, a, b) if d == 2 else (a, b, 0, 0) + pad = (0, 0, b, a) if d == 2 else (b, a, 0, 0) + lo = mypad(lo, pad=pad, mode=mode) + hi = mypad(hi, pad=pad, mode=mode) + + # unpad = (fsz - 1, 0) if d == 2 else (0, fsz - 1) + unpad = (fsz, 0) if d == 2 else (0, fsz) + + y = F.conv_transpose2d(lo, g0, padding=unpad, + groups=C, dilation=dilation) + \ + F.conv_transpose2d(hi, g1, padding=unpad, + groups=C, dilation=dilation) + + return y / (2 * dilation) + + +def sfb2d_atrous(ll, lh, hl, hh, filts, mode='symmetric'): + """ Does a single level 2d wavelet reconstruction of wavelet coefficients. 
+    Does separate row and column filtering by two calls to `sfb1d_atrous`
+    Inputs:
+        ll (torch.Tensor): lowpass coefficients
+        lh (torch.Tensor): horizontal coefficients
+        hl (torch.Tensor): vertical coefficients
+        hh (torch.Tensor): diagonal coefficients
+        filts (list of ndarray or torch.Tensor): If a list of tensors has been
+            given, this function assumes they are in the right form (the form
+            returned by `prep_filt_sfb2d`).
+            Otherwise, this function will prepare the filters to be of the
+            right form by calling `prep_filt_sfb2d`.
+        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
+            padding to use. If periodization, the output size will be half the
+            input size.
+            Otherwise, the output size will be slightly larger than half.
+    """
+    tensorize = [not isinstance(x, torch.Tensor) for x in filts]
+    if len(filts) == 2:
+        g0, g1 = filts
+        if True in tensorize:
+            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(g0, g1)
+        else:
+            g0_col = g0
+            g0_row = g0.transpose(2, 3)
+            g1_col = g1
+            g1_row = g1.transpose(2, 3)
+    elif len(filts) == 4:
+        if True in tensorize:
+            g0_col, g1_col, g0_row, g1_row = prep_filt_sfb2d(*filts)
+        else:
+            g0_col, g1_col, g0_row, g1_row = filts
+    else:
+        raise ValueError("Unknown form for input filts")
+
+    lo = sfb1d_atrous(ll, lh, g0_col, g1_col, mode=mode, dim=2)
+    hi = sfb1d_atrous(hl, hh, g0_col, g1_col, mode=mode, dim=2)
+    y = sfb1d_atrous(lo, hi, g0_row, g1_row, mode=mode, dim=3)
+
+    return y
+
+
+class SWTForward(nn.Module):
+    """ Performs a 2d Stationary wavelet transform (or undecimated wavelet
+    transform) of an image
+    Args:
+        J (int): Number of levels of decomposition
+        wave (str or pywt.Wavelet): Which wavelet to use. Can be a string to
+            pass to pywt.Wavelet constructor, can also be a pywt.Wavelet class,
+            or can be a two tuple of array-like objects for the analysis low
+            and high pass filters.
+        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
+            padding scheme. PyWavelets itself supports only periodization for
+            the SWT; this implementation defaults to 'symmetric' padding.
+    """
+    def __init__(self, J=1, wave='db1', mode='symmetric'):
+        super().__init__()
+        if isinstance(wave, str):
+            wave = pywt.Wavelet(wave)
+        if isinstance(wave, pywt.Wavelet):
+            h0_col, h1_col = wave.dec_lo, wave.dec_hi
+            h0_row, h1_row = h0_col, h1_col
+        else:
+            if len(wave) == 2:
+                h0_col, h1_col = wave[0], wave[1]
+                h0_row, h1_row = h0_col, h1_col
+            elif len(wave) == 4:
+                h0_col, h1_col = wave[0], wave[1]
+                h0_row, h1_row = wave[2], wave[3]
+
+        # Prepare the filters
+        filts = prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
+        self.h0_col = nn.Parameter(filts[0], requires_grad=False)
+        self.h1_col = nn.Parameter(filts[1], requires_grad=False)
+        self.h0_row = nn.Parameter(filts[2], requires_grad=False)
+        self.h1_row = nn.Parameter(filts[3], requires_grad=False)
+
+        self.J = J
+        self.mode = mode
+
+    def forward(self, x):
+        """ Forward pass of the SWT.
+        Args:
+            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
+        Returns:
+            List of coefficients for each scale. Each coefficient has
+            shape :math:`(N, C_{in}, 4, H_{in}, W_{in})` where the extra
+            dimension stores the 4 subbands for each scale. The ordering in
+            these 4 coefficients is: (A, H, V, D) or (ll, lh, hl, hh).
+ """ + ll = x + coeffs = [] + # Do a multilevel transform + filts = (self.h0_col, self.h1_col, self.h0_row, self.h1_row) + for j in range(self.J): + # Do 1 level of the transform + y = afb2d_atrous(ll, filts, self.mode) + coeffs.append(y) + ll = y[:, 0:1, :, :] + + return coeffs + + +class SWTInverse(nn.Module): + """ Performs a 2d DWT Inverse reconstruction of an image + Args: + wave (str or pywt.Wavelet): Which wavelet to use + C: deprecated, will be removed in future + """ + def __init__(self, wave='db1', mode='symmetric'): + super().__init__() + if isinstance(wave, str): + wave = pywt.Wavelet(wave) + if isinstance(wave, pywt.Wavelet): + g0_col, g1_col = wave.rec_lo, wave.rec_hi + g0_row, g1_row = g0_col, g1_col + else: + if len(wave) == 2: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = g0_col, g1_col + elif len(wave) == 4: + g0_col, g1_col = wave[0], wave[1] + g0_row, g1_row = wave[2], wave[3] + # Prepare the filters + + filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row) + self.g0_col = nn.Parameter(filts[0], requires_grad=False) + self.g1_col = nn.Parameter(filts[1], requires_grad=False) + self.g0_row = nn.Parameter(filts[2], requires_grad=False) + self.g1_row = nn.Parameter(filts[3], requires_grad=False) + + self.mode = mode + + def forward(self, coeffs): + """ + Args: + coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: + yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', + W_{in}')` and yh is a list of bandpass tensors of shape + :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match + the format returned by DWTForward + Returns: + Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})` + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly + downsampled shapes of the DWT pyramid. + Note: + Can have None for any of the highpass scales and will treat the + values as zeros (not in an efficient way though). + """ + + yl = coeffs[-1][:, 0:1, :, :] + yh = [] + for lohi in coeffs: + yh.append(lohi[:, None, 1:4, :, :]) + + ll = yl + + # Do the synthesis filter banks + for h_ in yh[::-1]: + lh, hl, hh = torch.unbind(h_, dim=2) + filts = (self.g0_col, self.g1_col, self.g0_row, self.g1_row) + ll = sfb2d_atrous(ll, lh, hl, hh, filts, mode=self.mode) + + return ll + + +if __name__ == '__main__': + J = 2 + wave = 'db3' + mode = 'symmetric' + + img_1 = pywt.data.camera() + img_2 = pywt.data.ascent() + img = np.stack([img_1, img_2], 0) + + xx = torch.tensor(img).reshape(2, 1, 512, 512).float().cuda() + + sfm = SWTForward(J, wave, mode).cuda() + ifm = SWTInverse(wave, mode).cuda() + + coeffs = sfm(xx) + recon = ifm(coeffs) diff --git a/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/dataset.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a18443bdf58c308b9ba2e744e6c9f75e520e3366 --- /dev/null +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/dataset.py @@ -0,0 +1,236 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2023, ITU/ISO/IEC +* All rights reserved. 
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from typing import Dict, Iterable, Any, Union
+import json
+from functools import partial
+
+import numpy as np
+
+import torch
+from torch.utils.data import Dataset
+from torchvision.transforms import Compose
+
+
+class ToTensor:
+    """Converts all values in a dictionary from numpy to a Torch Tensor"""
+
+    def __call__(self, sample: Dict[Any, np.ndarray]) -> Dict[Any, torch.Tensor]:
+        return {comp: torch.from_numpy(block) for comp, block in sample.items()}
+
+
+class Augment:
+    """
+    Applies a sequence of augmentations randomly to all values in sample. All augmentations in the sequence are independently
+    applied, each with the same probability. All values in sample are augmented or not for each augmentation.
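+
+    Example (an illustrative sketch; the component names follow the dataset
+    convention used elsewhere in this module):
+        >>> aug = Augment(augs=["hflip", "vflip"], aug_prob=0.5)
+        >>> sample = {"org_Y": torch.rand(1, 16, 16),
+        ...           "pred_Y": torch.rand(1, 16, 16)}
+        >>> sample = aug(sample)  # each flip is applied to all values with p=0.5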
+ """ + + AUGMENTATIONS = { + "vflip": partial(torch.flip, dims=[-1]), + "hflip": partial(torch.flip, dims=[-2]), + "transpose": partial(torch.transpose, dim0=-1, dim1=-2), + } + + def __init__(self, augs: Iterable[str] = [], aug_prob: float = 0.2): + self.augs = augs + self.aug_prob = aug_prob + + def __call__(self, sample: Dict[Any, torch.Tensor]) -> Dict[Any, torch.Tensor]: + for aug in self.augs: + if torch.randn(1) < self.aug_prob: + for comp, block in sample.items(): + sample[comp] = self.AUGMENTATIONS[aug](block) + return sample + + +class FilterDataset(Dataset): + """Binary dataset created using the common data dumping framework""" + + def __init__( + self, + desc_file: str, + norm_value: Dict[str, Union[int, float]], + augs: Iterable[str] = [], + aug_prob: float = 0.2, + nb_worker: int = 1, + out_type: str = "float32", + ): + """ + Args: + desc_file: dataset description json file + norm_value: normalisation values for specific components to override default + augs: augmentations to apply + aug_prob: probability of applying each augmentation + """ + self.norm_value = norm_value + self.transform = Compose([ToTensor(), Augment(augs, aug_prob)]) + self.out_type = out_type + + with open(desc_file) as f: + self.description = json.load(f) + + self.luma_patch_size = ( + self.description["patch_size"] + 2 * self.description["border_size"] + ) + self.block_volume = ( + len(self.description["components"]) * self.luma_patch_size ** 2 + ) + self.shape_input = ( + 1, + self.luma_patch_size, + self.luma_patch_size, + len(self.description["components"]), + ) + self.files = [] + for i in range(nb_worker): + self.files.append(open(self.description["data"], "rb")) + + def __len__(self) -> int: + return self.description["nb_patches"] + + def __getitem__(self, idx: int) -> Dict[Any, np.ndarray]: + if torch.is_tensor(idx): + idx = idx.tolist() + + id_file = torch.utils.data.get_worker_info().id + offset = self.block_volume * idx * np.dtype(self.description["type"]).itemsize + self.files[id_file].seek(offset, 0) + v = self.files[id_file].read( + self.block_volume * np.dtype(self.description["type"]).itemsize + ) + block = np.frombuffer( + v, + dtype=self.description["type"], + count=self.block_volume, + ).reshape(self.shape_input) + + components = dict( + zip( + self.description["components"], + np.moveaxis(block, -1, 0).astype(self.out_type), + ) + ) + if self.description["type"].startswith("int"): + components = { + comp: block / self.norm_value.get(comp) + for comp, block in components.items() + } + return self.transform(components) if self.transform else components + + +# to todo later +class FilterDatasetOnTheFly(Dataset): + """Binary dataset created on the fly""" + + @staticmethod + def is_comp_chroma(component): + return component.endswith("_U") or component.endswith("_V") + + def __getitem__(self, idx: int) -> Dict[Any, np.ndarray]: + if torch.is_tensor(idx): + idx = idx.tolist() + + with open(self.description["data"]) as f: + block = np.fromfile( + f, + dtype=self.description["type"], + count=self.block_volume, + offset=self.block_volume + * idx + * np.dtype(self.description["type"]).itemsize, + ) + comp_blocks = np.array_split(block, np.cumsum(self.component_volumes)) + + components = {} + for comp, block, is_chroma in zip( + self.description["components"], comp_blocks, self.are_components_chroma + ): + if is_chroma: + components[comp] = block.reshape( + 1, self.chroma_patch_size, self.chroma_patch_size + ).astype(self.out_type) + else: + components[comp] = block.reshape( + 1, 
self.luma_patch_size, self.luma_patch_size
+                ).astype(self.out_type)
+
+        if self.description["type"].startswith("int"):
+            components = {
+                comp: block
+                / self.specific_norm_value.get(comp, self.default_norm_value)
+                for comp, block in components.items()
+            }
+        return self.transform(components) if self.transform else components
+
+
+class MockDataset(Dataset):
+    """A dataset mocking LumaChromaDataset, for debugging purposes. Additional *args and **kwargs are ignored to reduce config changes required for use."""
+
+    def __init__(
+        self, *args: Any, length: int, patch_size: int = 144, **kwargs: Dict[Any, Any]
+    ):
+        """
+        Args:
+            args: additional args, that are ignored
+            length: the length of the mock dataset
+            patch_size: the size of the generated patches
+            kwargs: additional kwargs, that are ignored
+        """
+        print(f"MockDataset is ignoring *args: '{list(args)}'")
+        print(f"MockDataset is ignoring **kwargs: '{dict(kwargs)}'")
+        self.length = length
+        self.patch_size = (1, patch_size, patch_size)
+        self.components = [
+            "org_Y",
+            "org_U",
+            "org_V",
+            "pred_Y",
+            "pred_U",
+            "pred_V",
+            "rec_before_dbf_Y",
+            "rec_before_dbf_U",
+            "rec_before_dbf_V",
+            "bs_Y",
+            "bs_U",
+            "bs_V",
+            "qp_base",
+            "qp_slice",
+            "ipb_Y",
+        ]
+
+    def __len__(self) -> int:
+        return self.length
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        return {comp: torch.rand(*(self.patch_size)) for comp in self.components}
diff --git a/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/fastDCT.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/fastDCT.py
new file mode 100644
index 0000000000000000000000000000000000000000..647b21933f51a3c6ca291d88463a65317e82b3a4
--- /dev/null
+++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/fastDCT.py
@@ -0,0 +1,124 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2024, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import torch.nn as nn + + +class FastDCT(): + def __init__(self, dct_size: int, device: str): + super().__init__() + self.dct_size = dct_size + self.device = device + + def blockDCT_ps(self, xx): + if self.dct_size == 2: + yy = self.blockDCT2x2_ps2(xx) + else: + raise ValueError('DCT size unknown', self.dct_size) + return yy + + def iblockDCT_ps(self, xx): + if self.dct_size == 2: + yy = self.iblockDCT2x2_ps2(xx) + else: + raise ValueError('DCT size unknown', self.dct_size) + return yy + + def blockDCT_ps_halfsize(self, xx): + if self.dct_size == 2: + yy = xx + else: + raise ValueError('DCT size unknown', self.dct_size) + return yy + + def iblockDCT_per_chUV(self, xx): + dct_ch = self.dct_size * self.dct_size + ch_n = xx.shape[1] // dct_ch + yy = torch.Tensor(0).to(self.device) + for i in range(ch_n): + yi = self.iblockDCT_ps(xx[:, i * dct_ch:(i + 1) * dct_ch, :, :]) + yy = torch.cat((yy, yi), dim=1) + return yy + + def blockDCT2x2_ps2(self, xx): + pix_unshuff = nn.PixelUnshuffle(self.dct_size) + yy = pix_unshuff(xx) + rec_dct_reshape = \ + torch.zeros((1, 1, 1, 1), dtype=torch.float32).to(self.device). \ + expand(yy.shape).clone() + + a = yy[:, 0, :, :] + b = yy[:, 1, :, :] + c = yy[:, 2, :, :] + d = yy[:, 3, :, :] + t1 = a + b + t2 = c + d + t3 = a - b + t4 = c - d + c1 = t1 + t2 + c2 = t3 + t4 + c3 = t1 - t2 + c4 = t3 - t4 + rec_dct_reshape[:, 0, :, :] = c1 + rec_dct_reshape[:, 1, :, :] = c2 + rec_dct_reshape[:, 2, :, :] = c3 + rec_dct_reshape[:, 3, :, :] = c4 + + return rec_dct_reshape / 4 + + def iblockDCT2x2_ps2(self, yy): + xx = torch.zeros((1, 1, 1, 1), dtype=torch.float32).to(self.device). \ + expand(yy.shape).clone() + + c1 = yy[:, 0, :, :] + c2 = yy[:, 1, :, :] + c3 = yy[:, 2, :, :] + c4 = yy[:, 3, :, :] + t1 = c1 + c2 + t2 = c3 + c4 + t3 = c1 - c2 + t4 = c3 - c4 + a = t1 + t2 + b = t3 + t4 + c = t1 - t2 + d = t3 - t4 + xx[:, 0, :, :] = a + xx[:, 1, :, :] = b + xx[:, 2, :, :] = c + xx[:, 3, :, :] = d + + pix_shuff = nn.PixelShuffle(self.dct_size) + xx = pix_shuff(xx) + return xx diff --git a/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/logger.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..69a2c4dd19807ae87881d16538a3a35671ab6e3d --- /dev/null +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/logger.py @@ -0,0 +1,346 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. 
+* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +from typing import Iterable, Optional, Dict, Any +import shutil +import datetime +import os + + +class BaseLogger: + """A base class representing the interface for all loggers.""" + + def on_train_start(self) -> None: + pass + + def on_train_end(self) -> None: + pass + + def on_train_epoch_start(self, epoch: int, info: str) -> None: + pass + + def on_train_epoch_end(self, epoch: int) -> None: + pass + + def on_train_iter_start(self, epoch: int, iteration: int) -> None: + pass + + def on_train_iter_end( + self, epoch: int, iteration: int, train_metrics: Dict[str, Any] + ) -> None: + pass + + def on_val_epoch_start(self, epoch: int) -> None: + pass + + def on_val_epoch_end( + self, epoch: int, val_metrics: Dict[str, Dict[str, Any]] + ) -> None: + pass + + +class LoggerList(BaseLogger): + """A container for multiple loggers.""" + + def __init__(self, loggers: Iterable[BaseLogger]): + self.loggers = loggers + + def on_train_start(self) -> None: + for logger in self.loggers: + logger.on_train_start() + + def on_train_end(self) -> None: + for logger in self.loggers: + logger.on_train_end() + + def on_train_epoch_start(self, epoch: int, info: str) -> None: + for logger in self.loggers: + logger.on_train_epoch_start(epoch, info) + + def on_train_epoch_end(self, epoch: int) -> None: + for logger in self.loggers: + logger.on_train_epoch_end(epoch) + + def on_train_iter_start(self, epoch: int, iteration: int) -> None: + for logger in self.loggers: + logger.on_train_iter_start(epoch, iteration) + + def on_train_iter_end( + self, epoch: int, iteration: int, train_metrics: Dict[str, Any] + ) -> None: + for logger in self.loggers: + logger.on_train_iter_end(epoch, iteration, train_metrics) + + def on_val_epoch_start(self, epoch: int) -> None: + for logger in self.loggers: + logger.on_val_epoch_start(epoch) + + def on_val_epoch_end( + self, epoch: int, val_metrics: Dict[str, Dict[str, Any]] + ) -> None: + for logger in self.loggers: + logger.on_val_epoch_end(epoch, val_metrics) + + +class PrintLogger(BaseLogger): + """Prints metrics to 
file (stdout by default)."""
+
+    def __init__(
+        self,
+        log_train_interval: int = 0,
+        log_val_metrics: bool = True,
+        out_filename: Optional[str] = None,
+    ):
+        """
+        Args:
+            log_train_interval: log train metrics every n iterations. 0 to disable
+            log_val_metrics: log validation metrics. Interval determined by trainer
+            out_filename: log metrics to file (stdout if None)
+        """
+        self.log_train_interval = log_train_interval
+        self.log_val_metrics = log_val_metrics
+        # initialise the timer before the early return so that logging to
+        # stdout (out_filename=None) does not leave self.time unset
+        self.time = datetime.datetime.now()
+        if out_filename is None:
+            self.out_file = None
+            return
+        os.makedirs(os.path.dirname(out_filename), exist_ok=True)
+        out_bak = out_filename
+        i = 1
+        while os.path.exists(out_bak):
+            out_bak = out_filename + "." + str(i)
+            i = i + 1
+        if i > 1:
+            shutil.copy2(out_filename, out_bak)
+        self.out_file = open(out_filename, "w")
+
+    def __del__(self):
+        if self.out_file is not None:
+            self.out_file.close()
+
+    @classmethod
+    def format_metrics(cls, metrics: Dict[str, Any]) -> str:
+        return ", ".join([f"{metric}={value:.4e}" for metric, value in metrics.items()])
+
+    def on_train_iter_end(
+        self, epoch: int, iteration: int, train_metrics: Dict[str, Any]
+    ) -> None:
+        if (
+            self.log_train_interval > 0
+            and (iteration + 1) % self.log_train_interval == 0
+        ):
+            print(
+                f"Epoch {epoch}, iteration {iteration}: {self.format_metrics(train_metrics)} dt={datetime.datetime.now()-self.time}",
+                file=self.out_file,
+            )
+            self.time = datetime.datetime.now()
+            if self.out_file is not None:
+                self.out_file.flush()
+
+    def on_val_epoch_end(
+        self, epoch: int, val_metrics: Dict[str, Dict[str, Any]]
+    ) -> None:
+        if self.log_val_metrics:
+            print(f"Val epoch {epoch}", file=self.out_file)
+            for val_tag, val_tag_metrics in val_metrics.items():
+                print(
+                    f"\t{val_tag}: {self.format_metrics(val_tag_metrics)}",
+                    file=self.out_file,
+                )
+            if self.out_file is not None:
+                self.out_file.flush()
+
+
+class ProgressLogger(BaseLogger):
+    """Prints progress updates to file (stdout by default)."""
+
+    def __init__(
+        self,
+        log_train_epochs: bool = True,
+        log_val_epochs: bool = True,
+        log_train_iterations: bool = False,
+        log_stage_ends: bool = False,
+        out_filename: Optional[str] = None,
+    ):
+        """
+        Args:
+            log_train_epochs: log updates every training epoch
+            log_val_epochs: log updates every validation epoch
+            log_train_iterations: log updates every training iteration
+            log_stage_ends: also log updates at completion of logged stages. If false, only log updates at start of logged stages
+            out_filename: log metrics to file (stdout if None)
+        """
+        self.log_train_epochs = log_train_epochs
+        self.log_val_epochs = log_val_epochs
+        self.log_train_iterations = log_train_iterations
+        self.log_stage_ends = log_stage_ends
+        if out_filename is None:
+            self.out_file = None
+            return
+
+        os.makedirs(os.path.dirname(out_filename), exist_ok=True)
+        out_bak = out_filename
+        i = 1
+        while os.path.exists(out_bak):
+            out_bak = out_filename + "."
+ str(i) + i = i + 1 + if i > 1: + shutil.copy2(out_filename, out_bak) + self.out_file = open(out_filename, "w") + + def __del__(self): + if self.out_file is not None: + self.out_file.close() + + def on_train_start(self): + print(f"{datetime.datetime.now()}: Training started", file=self.out_file) + if self.out_file is not None: + self.out_file.flush() + + def on_train_end(self): + if self.log_stage_ends: + print(f"{datetime.datetime.now()}: Training finished", file=self.out_file) + if self.out_file is not None: + self.out_file.flush() + + def on_train_epoch_start(self, epoch: int, info: str): + if self.log_train_epochs: + print( + f"{datetime.datetime.now()}: Training started epoch {epoch}\n{info}", + file=self.out_file, + ) + if self.out_file is not None: + self.out_file.flush() + + def on_train_epoch_end(self, epoch: int): + if self.log_train_epochs and self.log_stage_ends: + print( + f"{datetime.datetime.now()}: Training finished epoch {epoch}", + file=self.out_file, + ) + if self.out_file is not None: + self.out_file.flush() + + def on_train_iter_start(self, epoch: int, iteration: int): + if self.log_train_iterations: + print( + f"{datetime.datetime.now()}: Training starting epoch {epoch}, iteration {iteration}", + file=self.out_file, + ) + if self.out_file is not None: + self.out_file.flush() + + def on_train_iter_end( + self, epoch: int, iteration: int, train_metrics: Dict[str, Any] + ): + if self.log_train_iterations and self.log_stage_ends: + print( + f"{datetime.datetime.now()}: Training finished epoch {epoch}, iteration {iteration}", + file=self.out_file, + ) + if self.out_file is not None: + self.out_file.flush() + + def on_val_epoch_start(self, epoch: int): + if self.log_val_epochs: + print( + f"{datetime.datetime.now()}: Validation starting epoch {epoch}", + file=self.out_file, + ) + if self.out_file is not None: + self.out_file.flush() + + def on_val_epoch_end(self, epoch: int, val_metrics: Dict[str, Dict[str, Any]]): + if self.log_val_epochs and self.log_stage_ends: + print( + f"{datetime.datetime.now()}: Validation finished epoch {epoch}", + file=self.out_file, + ) + if self.out_file is not None: + self.out_file.flush() + + +from torch.utils.tensorboard import SummaryWriter # noqa: E402 + + +class TensorboardLogger(BaseLogger): + """Log metrics to tensorboard""" + + def __init__( + self, + log_dir: str, + log_train_progress: bool = True, + log_train_interval: int = 0, + log_val_metrics: bool = True, + ): + """ + Args: + log_dir: log tensorboards to directory + log_train_progress: log progress by epochs + log_train_interval: log train metrics every n iterations. 0 to disable + log_val_metrics: log validation metrics. 
Interval determined by trainer + """ + self.log_train_progress = log_train_progress + self.log_train_interval = log_train_interval + self.log_val_metrics = log_val_metrics + self.writer = SummaryWriter(log_dir) + + if self.log_train_interval > 0: + self.global_iteration = 0 + + def on_train_epoch_start(self, epoch: int, info: str): + if self.log_train_progress: + self.writer.add_scalar("epochs", epoch, global_step=epoch) + + def on_train_iter_end( + self, epoch: int, iteration: int, train_metrics: Dict[str, Any] + ) -> None: + if ( + self.log_train_interval > 0 + and (iteration + 1) % self.log_train_interval == 0 + ): + self.global_iteration += self.log_train_interval + for metric, value in train_metrics.items(): + self.writer.add_scalar( + f"train/{metric}", value, global_step=self.global_iteration + ) + + def on_val_epoch_end( + self, epoch: int, val_metrics: Dict[str, Dict[str, Any]] + ) -> None: + if self.log_val_metrics: + for val_dataset_tag, val_dataset_metrics in val_metrics.items(): + for metric, value in val_dataset_metrics.items(): + self.writer.add_scalar( + f"{val_dataset_tag}/{metric}", value, global_step=epoch + ) diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv_frame.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/main.py similarity index 67% rename from training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv_frame.py rename to training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/main.py index 03823392325ca8411a046f2641b5c1f435816440..bf7364fb4910a8149171d785cbfb2cf96312e996 100644 --- a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv_frame.py +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/main.py @@ -4,7 +4,7 @@ * and contributor rights, including patent rights, and no such rights are * granted under this license. * -* Copyright (c) 2010-2022, ITU/ISO/IEC +* Copyright (c) 2010-2023, ITU/ISO/IEC * All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -32,23 +32,38 @@ * THE POSSIBILITY OF SUCH DAMAGE. 
""" -import os +import argparse +import json +from trainer import Trainer -yuv_dir = "./bvi_dvc_YUV" -yuv_dir_frame = "./bvi_dvc_YUV_frame" -if not os.path.exists(yuv_dir_frame): - os.makedirs(yuv_dir_frame) +parser = argparse.ArgumentParser( + prog="train NN Filter model", + usage="main.py --json_config my_config.json --stage N\nN: stage number in [1-3]", + formatter_class=argparse.RawDescriptionHelpFormatter, +) -for yuv_file in sorted(os.listdir(yuv_dir)): - if yuv_file[0] != "A": - continue +parser.add_argument( + "--json_config", + action="store", + type=str, + help="global configuration file", + required=True, +) - yuv_path = os.path.join(yuv_dir, yuv_file) - tar = os.path.join(yuv_dir_frame, yuv_file[:-4]) - if not os.path.exists(tar): - os.makedirs(tar) +parser.add_argument( + "--stage", action="store", type=int, help="stage number in [1-3]", required=True +) - size = yuv_file.split("_")[1] +args = parser.parse_args() +try: + with open(args.json_config) as file: + config = json.load(file) +except Exception: + quit("[ERROR] unable to open json config") - cmd = f"ffmpeg -s {size} -pix_fmt yuv420p10le -i {yuv_path} -c copy -f segment -segment_time 0.001 {tar}/frame_%03d.yuv" - os.system(cmd) + +if config["verbose"] > 1: + print(json.dumps(config, indent=1)) + +trainer = Trainer(config, args.stage) +trainer.train() diff --git a/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/trainer.py b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd0261287ad3d54decbb18aac97074e00394e6b --- /dev/null +++ b/training/training_scripts/NN_SR_WaveletLoss_Multiratio/training/trainer.py @@ -0,0 +1,482 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import os +import json +from typing import Dict, Any +import random +import numpy as np +import torch +from torch import nn +from torch.utils.data import ConcatDataset, DataLoader +import dataset +import logger + +import importlib.util +from fastDCT import FastDCT +import SWT +import pywt + + +class Trainer: + """Training procedure.""" + + def __init__(self, config: Dict[str, Any], stage: int): + """ + Args: + config: config hyperparameters stored in dictionary. + See cfg/base_desc.txt for more details. + """ + self.config = config + self.config_stage = config[f"stage{stage}"] + self.config_training = self.config_stage["training"] + self.current_epoch = 0 + self.current_iteration = 0 + self.device = self.config_training["device"] or ( + torch.device("cuda") if torch.cuda.is_available() + else torch.device("cpu") + ) + print( + f"[INFO] tf32 {torch.backends.cuda.matmul.allow_tf32} \ + {torch.backends.cudnn.allow_tf32}" + ) + self.base_dir = self.config_training["path"] + self.save_dir = os.path.join(self.base_dir, + self.config_training["ckpt_dir"]) + os.makedirs(self.save_dir, exist_ok=True) + print(f"[INFO] All outputs will be saved in '{self.save_dir}'") + + with open(f"{self.base_dir}/training_config.json", "w") as f: + json.dump(self.config, f, indent=4) + + # training set + DataSetCls = getattr( + dataset, self.config_training["dataset_config"].pop("class") + ) + train_args = [] + if "dataset_files" in self.config_stage["dataset"]: + for name, value in\ + self.config_stage["dataset"]["dataset_files"].items(): + arg = {} + arg["desc_file"] = os.path.join( + self.config_stage["dataset"]["path"], value["dataset_file"] + ) + arg["norm_value"] = \ + self.config_training["dataset_config"]["quantizer"] + arg["augs"] = value["augs"] + arg["aug_prob"] = value["aug_prob"] + arg["nb_worker"] = \ + self.config_training["dataloader"]["num_workers"] + train_args.append(arg) + else: # old way for stage1 + arg = {} + arg["desc_file"] = os.path.join( + self.config_stage["dataset"]["path"], + self.config_stage["dataset"]["dataset_file"], + ) + arg["norm_value"] = \ + self.config_training["dataset_config"]["quantizer"] + arg["augs"] = [] + arg["aug_prob"] = 0.2 + arg["nb_worker"] = \ + self.config_training["dataloader"]["num_workers"] + train_args.append(arg) + train_dataset = \ + ConcatDataset([DataSetCls(**args) for args in train_args]) + print(f"[INFO] total number of samples {train_dataset.__len__()}") + self.train_dataloader = DataLoader( + train_dataset, **self.config_training["dataloader"] + ) + + # validation set + valid_args = [] + if "dataset_files" in self.config_stage["dataset_valid"]: + for name, value in self.config_stage["dataset_valid"][ + "dataset_files" + ].items(): + arg = {} + arg["desc_file"] = os.path.join( + self.config_stage["dataset_valid"]["path"], + value["dataset_file"] + ) + arg["norm_value"] = \ + self.config_training["dataset_config"]["quantizer"] + arg["augs"] = value["augs"] + arg["aug_prob"] = value["aug_prob"] + arg["nb_worker"] = \ + self.config_training["dataloader"]["num_workers"] + valid_args.append(arg) + val_dataset = \ + ConcatDataset([DataSetCls(**args) for args in valid_args]) + else: # old way for stage1 + arg = {} + arg["desc_file"] = os.path.join( + self.config_stage["dataset_valid"]["path"], + self.config_stage["dataset_valid"]["dataset_file"], + ) + arg["norm_value"] = \ + self.config_training["dataset_config"]["quantizer"] + arg["augs"] = [] + arg["aug_prob"] = 0.2 + arg["nb_worker"] = \ + self.config_training["dataloader"]["num_workers"] + 
valid_args.append(arg)
+            val_dataset = \
+                ConcatDataset([DataSetCls(**args) for args in valid_args])
+        print(f"[INFO] total number of validation samples {len(val_dataset)}")
+        self.val_dataloaders = {
+            "val_dataset": DataLoader(val_dataset,
+                                      **self.config_training["dataloader"])
+        }
+
+        # Model
+        if "model" in self.config_stage:  # use particular model for this stage
+            json_model = self.config_stage["model"]
+        else:
+            json_model = self.config["model"]
+        spec_model = \
+            importlib.util.spec_from_file_location("model", json_model["path"])
+        model = importlib.util.module_from_spec(spec_model)
+        spec_model.loader.exec_module(model)
+        self.model = self.instantiate_from_dict(model, json_model)
+        if self.config["verbose"] > 2:
+            print(f"[INFO] model: {json_model}")
+        self.model.to(self.device)
+
+        # Optimisation
+        self.optimizer = self.instantiate_from_dict(
+            torch.optim,
+            self.config_training["optimizer"],
+            self.model.parameters()
+        )
+        self.lr_scheduler = self.instantiate_from_dict(
+            torch.optim.lr_scheduler,
+            self.config_training["lr_scheduler"],
+            self.optimizer,
+        )
+
+        # reload
+        last_ckpt = os.path.join(self.save_dir,
+                                 self.config_training["ckpt_reload"])
+        if os.path.isfile(last_ckpt):
+            print(f"[INFO] reload checkpoint {last_ckpt}")
+            self.load_checkpoint(last_ckpt)
+        else:
+            print(f"[WARNING] checkpoint {last_ckpt} not found, "
+                  "training from scratch")
+
+        # Logging
+        # put path relative to path
+        for name, logitem in self.config_training["loggers"].items():
+            if self.config["verbose"] > 1:
+                print(f"[INFO] logger {name}: {logitem}")
+            if "out_filename" in \
+                    logitem and logitem["out_filename"] is not None:
+                logitem["out_filename"] = os.path.join(
+                    self.base_dir, logitem["out_filename"]
+                )
+            elif "log_dir" in logitem and logitem["log_dir"] is not None:
+                logitem["log_dir"] = \
+                    os.path.join(self.base_dir, logitem["log_dir"])
+
+        self.loggers = logger.LoggerList(
+            [
+                getattr(logger, LoggerCls)(**kwargs)
+                for LoggerCls, kwargs in
+                self.config_training["loggers"].items()
+            ]
+        )
+
+        self.loss_function = nn.L1Loss()
+        if "dct_size" in \
+                self.config_training or "dct_ch" in self.config['model']:
+            if "dct_size" not in self.config_training:
+                raise KeyError('Training config missing dct_size')
+            if "dct_ch" not in self.config['model']:
+                raise KeyError('Model config missing dct_ch')
+            self.dct_size = self.config_training["dct_size"]
+            self.fast_dct = FastDCT(self.dct_size, self.device)
+            model_dct_ch = self.config['model']['dct_ch']
+            print(f"[INFO] DCT transformed inputs used, size {self.dct_size}")
+            assert model_dct_ch == self.dct_size * self.dct_size, \
+                'Unequal DCT size in training config and model config.'
+        else:
+            print("[INFO] DCT transformed inputs not used")
+
+        # wavelet init
+        wave = "sym7"
+        wavelet = pywt.Wavelet(wave)
+
+        dlo = wavelet.dec_lo
+        an_lo = np.divide(dlo, sum(dlo))
+        an_hi = wavelet.dec_hi
+        rlo = wavelet.rec_lo
+        syn_lo = 2 * np.divide(rlo, sum(rlo))
+        syn_hi = wavelet.rec_hi
+
+        filters = pywt.Wavelet('wavelet_normalized',
+                               [an_lo, an_hi, syn_lo, syn_hi])
+        self.sfm = SWT.SWTForward(1, filters, 'periodic').to(self.device)
+
+    @staticmethod
+    def instantiate_from_dict(
+        namespace, dict: Dict[str, Any], *args: Any, **kwargs: Dict[str, Any]
+    ) -> Any:
+        """Instantiate an instance from a config dict.
+        Value of 'class' in dict should be a class available in namespace.
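+        For example (illustrative),
+            instantiate_from_dict(torch.optim, {"class": "Adam", "lr": 1e-4},
+                                  model.parameters())
+        returns torch.optim.Adam(model.parameters(), lr=1e-4). Note that
+        'class' is popped, so the passed dict is consumed.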
+        Args:
+            namespace: load class definition from this namespace
+            dict: dict containing kwargs for instantiation
+            *args: args for instantiation
+            **kwargs: additional kwargs for instantiation
+        """
+        return getattr(namespace, dict.pop("class"))(*args, **dict, **kwargs)
+
+    def save_checkpoint(self, filename: str):
+        torch.save(
+            {
+                "epoch": self.current_epoch,
+                "model": self.model.state_dict(),
+                "optimizer": self.optimizer.state_dict(),
+                "lr_scheduler": self.lr_scheduler.state_dict(),
+            },
+            filename,
+        )
+
+    def load_checkpoint(self, filename: str):
+        checkpoint = torch.load(filename, map_location=self.device)
+        self.current_epoch = checkpoint["epoch"] + 1
+        self.model.load_state_dict(checkpoint["model"])
+        self.optimizer.load_state_dict(checkpoint["optimizer"])
+        if (
+            "override_scheduler" in self.config_training
+            and self.config_training["override_scheduler"]
+        ):
+            # hack lr scheduler
+            self.lr_scheduler._step_count = self.current_epoch + 1
+            self.lr_scheduler.last_epoch = self.current_epoch
+            self.lr_scheduler._last_lr = checkpoint[
+                "optimizer"]["param_groups"][0][
+                "lr"
+            ]
+            print("[INFO] scheduler parameters from new file")
+        else:
+            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+            print("[INFO] scheduler parameters from checkpoint")
+        print(f"[INFO] resume at epoch {self.current_epoch} "
+              f"with lr={self.lr_scheduler._last_lr}")
+
+    def loss_merger(self, Y_loss: torch.Tensor,
+                    UV_loss: torch.Tensor) -> torch.Tensor:
+        return (
+            self.config_training["component_loss_weightings"][0] * Y_loss
+            + self.config_training["component_loss_weightings"][1] * UV_loss
+        ) / sum(self.config_training["component_loss_weightings"])
+
+    def seed_everything(self):
+        if self.config_training["seed"] is not None:
+            random.seed(self.config_training["seed"])
+            np.random.seed(self.config_training["seed"])
+            torch.manual_seed(self.config_training["seed"])
+            torch.cuda.manual_seed_all(self.config_training["seed"])
+
+    def train(self):
+        self.seed_everything()
+        self.loggers.on_train_start()
+        for epoch in range(self.current_epoch,
+                           self.config_training["max_epochs"]):
+            self.current_epoch = epoch
+            print("")
+            info = ""
+            if epoch >= self.config_training["mse_epoch"]:
+                info += (f"[INFO] epoch {epoch} L2 loss, "
+                         f"lr={self.lr_scheduler.get_last_lr()} "
+                         f"{self.optimizer.param_groups[0]['lr']}\n")
+                self.loss_function = nn.MSELoss()
+            else:
+                info += (f"[INFO] epoch {epoch} L1 loss, "
+                         f"lr={self.lr_scheduler.get_last_lr()} "
+                         f"{self.optimizer.param_groups[0]['lr']}\n")
+            info += "[INFO] optimizer: "
+            for p in self.optimizer.param_groups[0]:
+                if p != "params":
+                    info += f"{p}={self.optimizer.param_groups[0][p]} "
+            info += f"\n[INFO] scheduler: {self.lr_scheduler.state_dict()}"
+            print(info)
+            self.loggers.on_train_epoch_start(self.current_epoch, info)
+            self.train_epoch()
+            self.loggers.on_train_epoch_end(self.current_epoch)
+
+            if (epoch + 1) % self.config_training["validation_interval"] == 0:
+                self.loggers.on_val_epoch_start(self.current_epoch)
+                val_metrics = self.val_epoch()
+                self.loggers.on_val_epoch_end(self.current_epoch, val_metrics)
+
+            self.lr_scheduler.step()
+            if (
+                self.config_training[
+                    "model_ckpt"]["every_n_epochs"] is not None
+                and (epoch + 1) % self.config_training[
+                    "model_ckpt"]["every_n_epochs"]
+                == 0
+            ):
+                self.save_checkpoint(f"{self.save_dir}/epoch_{epoch}.ckpt")
+            if self.config_training["model_ckpt"]["save_last"]:
+                self.save_checkpoint(f"{self.save_dir}/last.ckpt")
+            if self.config_training["model_ckpt"]["export_last"]:
self.model.SADL_model.to_onnx(f"{self.save_dir}/last.onnx") + + self.loggers.on_train_end() + + def train_epoch(self): + for i, sample in enumerate(self.train_dataloader): + self.current_iteration = i + self.loggers.on_train_iter_start(self.current_epoch, + self.current_iteration) + train_metrics = self.train_iteration(sample) + self.loggers.on_train_iter_end( + self.current_epoch, self.current_iteration, train_metrics + ) + + def train_iteration(self, sample): + lossY, lossUV = self.iteration(sample) + lossYUV = self.loss_merger(lossY, lossUV) + + self.optimizer.zero_grad() + lossYUV.backward() + self.optimizer.step() + + return {"lossY": lossY, "lossUV": lossUV, "lossYUV": lossYUV} + + def iteration(self, sample): + if "dct_size" in self.config_training and self.dct_size >= 2: + sample_dct = self.applyDCT2(sample) + Ycoeff_res, UVcoeff_res = self.model( + {name: tensor.to(self.device) for name, + tensor in sample_dct.items()} + ) + Y_res, UV_res = self.applyIDCT2(Ycoeff_res, UVcoeff_res) + Y = sample["rec_before_dbf_Y"].to(self.device) + Y_res + UV = UV_res + torch.cat((sample["rec_before_dbf_U"], + sample["rec_before_dbf_V"]), dim=1)[ + ..., ::2, ::2 + ].to(self.device) + else: + Y, UV = self.model( + {name: tensor.to(self.device) for name, + tensor in sample.items()} + ) + + target_Y = sample["org_Y"][..., 8:-8, 8:-8].to(self.device) + target_UV = torch.cat((sample["org_U"], sample["org_V"]), dim=1)[ + ..., 8:-8:2, 8:-8:2 + ].to(self.device) + + if "use_dwt_loss" in \ + self.config_training and \ + self.config_training["use_dwt_loss"] is True: + y_dwt_sr = self.sfm(Y[..., 8:-8, 8:-8])[0] + uv_dwt_sr = self.sfm(UV[..., 4:-4, 4:-4])[0] + y_dwt_org = self.sfm(target_Y)[0] + uv_dwt_org = self.sfm(target_UV)[0] + + loss_ratio = [0.5, 0.05, 0.05, 0.25] + lossY = 0 + lossUV = 0 + for i in range(4): + loss = loss_ratio[i] * self.loss_function( + y_dwt_sr[:, i:i + 1, :, :], y_dwt_org[:, i:i + 1, :, :]) + lossY += loss + + loss = loss_ratio[i] * self.loss_function( + uv_dwt_sr[:, i:i + 1, :, :], + uv_dwt_org[:, i:i + 1, :, :]) \ + + loss_ratio[i] * self.loss_function( + uv_dwt_sr[:, i + 4:i + 5, :, :], + uv_dwt_org[:, i + 4:i + 5, :, :]) + lossUV += loss + else: + lossY = self.loss_function(Y[..., 8:-8, 8:-8], target_Y) + lossUV = self.loss_function(UV[..., 4:-4, 4:-4], target_UV) + return lossY, lossUV + + def val_epoch(self): + with torch.no_grad(): + val_metrics = {} + for tag, dataloader in self.val_dataloaders.items(): + Y_loss_sum = 0 + UV_loss_sum = 0 + for sample in dataloader: + lossY, lossUV = self.iteration(sample) + Y_loss_sum += lossY + UV_loss_sum += lossUV + val_metric = { + "lossY": Y_loss_sum / len(dataloader), + "lossUV": UV_loss_sum / len(dataloader), + } + val_metric["lossYUV"] = self.loss_merger( + val_metric["lossY"], val_metric["lossUV"] + ) + val_metrics[tag] = val_metric + return val_metrics + + def applyDCT2(self, sample): + sample_dct = {} + for name, tensor in sample.items(): + if 'qp' in name: + yy = tensor[:, 0, 0, 0].to(self.device) + dim = tensor.shape + yy = yy.unsqueeze(1).unsqueeze(1).unsqueeze(1).expand( + (dim[0], + 1, + dim[2] // self.dct_size, + dim[3] // self.dct_size)) + elif '_Y' in name: + yy = self.fast_dct.blockDCT_ps(tensor.to(self.device)) + elif '_U' in name or '_V' in name: + tensor = tensor.to(self.device) + yy = self.fast_dct.blockDCT_ps_halfsize(tensor[:, :, ::2, ::2]) + else: + raise ValueError('Unknown input name', + name, ' in applyDCT2().') + sample_dct[name] = yy + return sample_dct + + def applyIDCT2(self, YCoeff, UVCoeff): + y = 
self.fast_dct.iblockDCT_ps(YCoeff) + uv = self.fast_dct.iblockDCT_per_chUV(UVCoeff) + return y, uv diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/1_ReadMe.md b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/1_ReadMe.md deleted file mode 100644 index 58d18ceeb97dbfa408aed7fe4ae2d3f5d21d3a6a..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/1_ReadMe.md +++ /dev/null @@ -1,43 +0,0 @@ -## [TVD](https://multimedia.tencent.com/resources/tvd) -All the sequences with 10 bit-depth (74 sequences) are used. (Partial dataset) -Generation: -1. Run tvd_to_yuv.py (Convert mp4 files to YUV files) -2. Run tvd_to_yuv_frame.py (Split YUV videos) -3. Obtain the raw data in **TVD_Video_YUV** to be used for the subsequent compression and as the ground truth - -## [BVI-DVC](https://vilab.blogs.bristol.ac.uk/2020/02/2375/) -All the 3840x2176 sequences (200 sequences) are used. (Partial dataset) -Generation: -1. Run bvi_dvc_to_yuv.py (Convert mp4 files to YUV files) -2. Run bvi_dvc_to_yuv_frame.py (Split YUV videos) -2. Obtain the raw data in **bvi_dvc_YUV** to be used for the subsequent compression and as the ground truth - -## The file structure of the raw dataset -When the raw dataset generation is finished, it should be the following file structure. - -### TVD -``` - TVD_Video_YUV - │ Bamboo_3840x2160_25fps_10bit_420.yuv - │ │ frame_000.yuv - │ │ frame_001.yuv - │ │ ... - │ BlackBird_3840x2160_25fps_10bit_420.yuv - │ │ frame_000.yuv - │ │ frame_001.yuv - │ │ ... - │ ... -``` -### BVI-DVC -``` - bvi_dvc_YUV - │ AAdvertisingMassagesBangkokVidevo_3840x2176_25fps_10bit_420.yuv - │ │ frame_000.yuv - │ │ frame_001.yuv - │ │ ... - │ AAmericanFootballS2Harmonics_3840x2176_60fps_10bit_420.yuv - │ │ frame_000.yuv - │ │ frame_001.yuv - │ │ ... - │ ... -``` \ No newline at end of file diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv.py b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv.py deleted file mode 100644 index 91d3db3c4949235cfb8128460c5c571478a5f265..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. 
-* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import os - -mp4_dir = "./Videos" -yuv_dir = "./bvi_dvc_YUV" -if not os.path.exists(yuv_dir): - os.makedirs(yuv_dir) - -for mp4_file in sorted(os.listdir(mp4_dir)): - if mp4_file[0] != "A": - continue - - mp4_path = os.path.join(mp4_dir, mp4_file) - yuv_path = os.path.join(yuv_dir, mp4_file[:-4] + ".yuv") - if not os.path.exists(tar): # noqa: F821 - os.makedirs(tar) # noqa: F821 - - size = mp4_file.split("_")[1] - - cmd = f"ffmpeg -i {mp4_path} -pix_fmt yuv420p10le {yuv_path}" - os.system(cmd) diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv.py b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv.py deleted file mode 100644 index afaa82c6be90a592be6eba00b2187b6fc0a10d45..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. 
-""" - -import os - -mp4_dir = "./Video" -yuv_dir = "./TVD_Video_YUV" -if not os.path.exists(yuv_dir): - os.makedirs(yuv_dir) - -for mp4_file in sorted(os.listdir(mp4_dir)): - if "_10bit_" not in mp4_file: - continue - - mp4_path = os.path.join(mp4_dir, mp4_file) - yuv_path = os.path.join(yuv_dir, mp4_file[:-4] + ".yuv") - if not os.path.exists(tar): # noqa: F821 - os.makedirs(tar) # noqa: F821 - - size = mp4_file.split("_")[1] - - cmd = f"ffmpeg -i {mp4_path} -pix_fmt yuv420p10le {yuv_path}" - os.system(cmd) diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv_frame.py b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv_frame.py deleted file mode 100644 index bdf62f5f39e9e21251d2a4013e82092239f966ce..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv_frame.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. 
-""" - -import os - -yuv_dir = "./TVD_Video_YUV" -yuv_dir_frame = "./TVD_Video_YUV_frame" -if not os.path.exists(yuv_dir_frame): - os.makedirs(yuv_dir_frame) - -for yuv_file in sorted(os.listdir(yuv_dir)): - if "_10bit_" not in yuv_file: - continue - - yuv_path = os.path.join(yuv_dir, yuv_file) - tar = os.path.join(yuv_dir_frame, yuv_file) - if not os.path.exists(tar): - os.makedirs(tar) - - size = yuv_file.split("_")[1] - - cmd = f"ffmpeg -s {size} -pix_fmt yuv420p10le -i {yuv_path} -c copy -f segment -segment_time 0.001 {tar}/frame_%03d.yuv" - os.system(cmd) diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/2_ReadMe.md b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/2_ReadMe.md deleted file mode 100644 index 7b4b724ece03812609570b62d35a5f6124210078..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/2_ReadMe.md +++ /dev/null @@ -1,114 +0,0 @@ -## Generate the dataset -For the convenience, the detailed codec information, including resolution, encode level and so on, is provided by the python files bvi_dvc_codec_info.py and tvd_codec_info.py, which can make it easier to build your own script on the cluster. - -The corresponding raw dataset in sequence level YUV format to be compressed is generated by the scripts in '../1_generate_raw_data/'. -Finally, the compression dataset in frame level YUV format is obtained at the decoder based on VTM-11.0_nnvc-2.0 (https://vcgit.hhi.fraunhofer.de/jvet-ahg-nnvc/VVCSoftware_VTM/-/tree/VTM-11.0_nnvc-2.0). -Two patches (generate_SR_datasets_for_I_slices.patch and generate_SR_datasets_for_B_slices.patch) are applied to VTM-11.0_nnvc-2.0 to generate I-frame datasets and B-frame datasets, respectively. - -Specifically, the compression dataset generated from decoder includes reconstruction images, prediction images and RPR images. - -### Generate the I data -TVD and BVI-DVC are both compressed under AI configuration and then all I slices are selected to build this dataset. -VTM-11.0_nnvc-2.0 with generate_SR_datasets_for_I_slices.patch is used to generate I data, and TemporalSubsampleRatio is set to 1 as follows. -``` ---TemporalSubsampleRatio=1 -``` - -The macro configurations are provided as follows. Note that this macro DATA_GEN_DEC should be turned off on the encoder, and turned on on the decoder. -``` -#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling -#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder -#define DATA_PREDICTION 1 // Prediction data -``` - -### Generate the B data -TVD and BVI-DVC are both compressed under RA configuration and then all B slices are selected to build this dataset. - -The macro configurations are provided as follows. Note that this macro DATA_GEN_DEC should be turned off on the encoder, and turned on on the decoder. -``` -#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling -#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder -#define DATA_PREDICTION 1 // Prediction data -``` - -## The file structure of the compression dataset -When the compression dataset generation is finished, it should be adjusted into the following file structure. - -### AI -``` - AI_TVD - └───yuv - │ │ Bin_T1AI_A_S01_R32_qp32_s0_f65_t1_poc000.yuv - │ │ Bin_T1AI_A_S02_R27_qp27_s0_f65_t1_poc064.yuv - │ │ Bin_T1AI_A_S03_R32_qp32_s0_f65_t1_poc015.yuv - │ │ ... 
- │
- └───prediction_image
- │ │ Bin_T1AI_A_S01_R32_qp32_s0_f65_t1_poc000_prediction.yuv
- │ │ Bin_T1AI_A_S02_R27_qp27_s0_f65_t1_poc064_prediction.yuv
- │ │ Bin_T1AI_A_S03_R32_qp32_s0_f65_t1_poc015_prediction.yuv
- │ │ ...
- └───rpr_image
- │ Bin_T1AI_A_S01_R32_qp32_s0_f65_t1_poc000_rpr.yuv
- │ Bin_T1AI_A_S02_R27_qp27_s0_f65_t1_poc064_rpr.yuv
- │ Bin_T1AI_A_S03_R32_qp32_s0_f65_t1_poc015_rpr.yuv
- │ ...
-
- AI_BVI_DVC
- └───yuv
- │ │ Bin_T1AI_A_S001_R32_qp32_s0_f64_t1_poc000.yuv
- │ │ Bin_T1AI_A_S002_R27_qp27_s0_f64_t1_poc063.yuv
- │ │ Bin_T1AI_A_S003_R32_qp32_s0_f64_t1_poc015.yuv
- │ │ ...
- │
- └───prediction_image
- │ │ Bin_T1AI_A_S001_R32_qp32_s0_f64_t1_poc000_prediction.yuv
- │ │ Bin_T1AI_A_S002_R27_qp27_s0_f64_t1_poc063_prediction.yuv
- │ │ Bin_T1AI_A_S003_R32_qp32_s0_f64_t1_poc015_prediction.yuv
- │ │ ...
- │
- └───rpr_image
- │ Bin_T1AI_A_S001_R32_qp32_s0_f64_t1_poc000_rpr.yuv
- │ Bin_T1AI_A_S002_R27_qp27_s0_f64_t1_poc063_rpr.yuv
- │ Bin_T1AI_A_S003_R32_qp32_s0_f64_t1_poc015_rpr.yuv
- │ ...
-```
-### RA
-```
- RA_TVD
- └───yuv
- │ │ Bin_T2RA_A_S01_R27_qp27_s0_f65_t1_poc063_qp36.yuv
- │ │ Bin_T2RA_A_S02_R32_qp32_s0_f65_t1_poc031_qp41.yuv
- │ │ Bin_T2RA_A_S03_R42_qp42_s0_f65_t1_poc062_qp50.yuv
- │ │ ...
- │
- └───prediction_image
- │ │ Bin_T2RA_A_S01_R27_qp27_s0_f65_t1_poc063_qp36_prediction.yuv
- │ │ Bin_T2RA_A_S02_R32_qp32_s0_f65_t1_poc031_qp41_prediction.yuv
- │ │ Bin_T2RA_A_S03_R42_qp42_s0_f65_t1_poc062_qp50_prediction.yuv
- │ │ ...
- └───rpr_image
- │ Bin_T2RA_A_S01_R27_qp27_s0_f65_t1_poc063_qp36_rpr.yuv
- │ Bin_T2RA_A_S02_R32_qp32_s0_f65_t1_poc031_qp41_rpr.yuv
- │ Bin_T2RA_A_S03_R42_qp42_s0_f65_t1_poc062_qp50_rpr.yuv
- │ ...
-
- RA_BVI_DVC
- └───yuv
- │ │ Bin_T2RA_A_S001_R27_qp27_s0_f64_t1_poc063_qp36.yuv
- │ │ Bin_T2RA_A_S002_R32_qp32_s0_f64_t1_poc031_qp41.yuv
- │ │ Bin_T2RA_A_S003_R42_qp42_s0_f64_t1_poc062_qp50.yuv
- │ │ ...
- │
- └───prediction_image
- │ │ Bin_T2RA_A_S001_R27_qp27_s0_f64_t1_poc063_qp36_prediction.yuv
- │ │ Bin_T2RA_A_S002_R32_qp32_s0_f64_t1_poc031_qp41_prediction.yuv
- │ │ Bin_T2RA_A_S003_R42_qp42_s0_f64_t1_poc062_qp50_prediction.yuv
- │ │ ...
- │
- └───rpr_image
- │ │ Bin_T2RA_A_S001_R27_qp27_s0_f64_t1_poc063_qp36_rpr.yuv
- │ │ Bin_T2RA_A_S002_R32_qp32_s0_f64_t1_poc031_qp41_rpr.yuv
- │ │ Bin_T2RA_A_S003_R42_qp42_s0_f64_t1_poc062_qp50_rpr.yuv
- │ ...
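-```
-
-## Build your own encoding script
-The exact batch script depends on your cluster, so only a minimal sketch is given below; it is an illustration, not the reference tool. It assumes the field order of SequenceTable in bvi_dvc_codec_info.py (sequence ID, YUV file name, width, height, a start offset, frame count, frame rate, bit depth, QP set, rate-point labels, level), a patched EncoderAppStatic and the encoder_intra_vtm.cfg configuration in the working directory, and bitstream names that merely mimic the listings above. The patched decoder (built with DATA_GEN_DEC on) is then run on each resulting bitstream to dump the reconstruction, prediction and RPR YUV files.
-```
-from bvi_dvc_codec_info import SequenceTable
-
-for seq_id, yuv, width, height, _offset, frames, fps, depth, qps, _, level in SequenceTable:
-    for qp in qps.split():
-        # One AI encode job per (sequence, QP); with DATA_GEN_ENC enabled the
-        # patched encoder applies the 2x RPR downsampling ratios by default.
-        print(
-            "./EncoderAppStatic -c encoder_intra_vtm.cfg"
-            f" -i {yuv} -wdt {width} -hgt {height} -fr {fps} -f {frames}"
-            f" --InputBitDepth={depth} --Level={level} -q {qp}"
-            " --TemporalSubsampleRatio=1"
-            f" -b Bin_T1AI_A_{seq_id.split('_')[1]}_R{qp}_qp{qp}.bin"
-        )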
-``` diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_B_slices.patch b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_B_slices.patch deleted file mode 100644 index 731bafebbf61a1d749fd8489ff380ed7766883c0..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_B_slices.patch +++ /dev/null @@ -1,561 +0,0 @@ -From 73a78fab59f91a4f30f24ab520baf70edbe855fc Mon Sep 17 00:00:00 2001 -From: renjiechang <renjiechang@tencent.com> -Date: Tue, 14 Feb 2023 15:21:58 +0800 -Subject: [PATCH] Generate SR datasets for B slices - ---- - source/App/DecoderApp/DecApp.cpp | 26 +++++++- - source/App/DecoderApp/DecAppCfg.cpp | 4 ++ - source/App/EncoderApp/EncAppCfg.cpp | 8 +++ - source/Lib/CommonLib/CodingStructure.cpp | 48 ++++++++++++++ - source/Lib/CommonLib/CodingStructure.h | 17 +++++ - source/Lib/CommonLib/Picture.cpp | 19 ++++++ - source/Lib/CommonLib/Picture.h | 8 +++ - source/Lib/CommonLib/Rom.cpp | 5 ++ - source/Lib/CommonLib/Rom.h | 4 ++ - source/Lib/CommonLib/TypeDef.h | 3 + - source/Lib/DecoderLib/DecCu.cpp | 14 +++- - source/Lib/EncoderLib/EncLib.cpp | 4 ++ - source/Lib/Utilities/VideoIOYuv.cpp | 82 +++++++++++++++++++++++- - source/Lib/Utilities/VideoIOYuv.h | 6 +- - 14 files changed, 243 insertions(+), 5 deletions(-) - -diff --git a/source/App/DecoderApp/DecApp.cpp b/source/App/DecoderApp/DecApp.cpp -index 85f63bb0..1842d346 100644 ---- a/source/App/DecoderApp/DecApp.cpp -+++ b/source/App/DecoderApp/DecApp.cpp -@@ -88,6 +88,17 @@ uint32_t DecApp::decode() - EXIT( "Failed to open bitstream file " << m_bitstreamFileName.c_str() << " for reading" ) ; - } - -+#if DATA_GEN_DEC -+ strcpy(global_str_name, m_bitstreamFileName.c_str()); -+ { -+ size_t len = strlen(global_str_name); -+ for (size_t i = len - 1; i > len - 5; i--) -+ { -+ global_str_name[i] = 0; -+ } -+ } -+#endif -+ - InputByteStream bytestream(bitstreamFile); - - if (!m_outputDecodedSEIMessagesFilename.empty() && m_outputDecodedSEIMessagesFilename!="-") -@@ -678,7 +689,14 @@ void DecApp::xWriteOutput( PicList* pcListPic, uint32_t tId ) - ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); - if( m_upscaledOutput ) - { -- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); -+ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range -+#if DATA_GEN_DEC -+ , pcPic -+#endif -+ ); -+#if DATA_PREDICTION -+ pcPic->m_bufs[PIC_TRUE_PREDICTION].destroy(); -+#endif - } - else - { -@@ -825,7 +843,11 @@ void DecApp::xFlushOutput( PicList* pcListPic, const int layerId ) - ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); - if( m_upscaledOutput ) - { -- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); -+ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range -+#if DATA_GEN_DEC -+ , 
pcPic -+#endif -+ ); - } - else - { -diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp -index d96c2049..ad912462 100644 ---- a/source/App/DecoderApp/DecAppCfg.cpp -+++ b/source/App/DecoderApp/DecAppCfg.cpp -@@ -124,7 +124,11 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) - #endif - ("MCTSCheck", m_mctsCheck, false, "If enabled, the decoder checks for violations of mc_exact_sample_value_match_flag in Temporal MCTS ") - ("targetSubPicIdx", m_targetSubPicIdx, 0, "Specify which subpicture shall be written to output, using subpic index, 0: disabled, subpicIdx=m_targetSubPicIdx-1 \n" ) -+#if DATA_GEN_DEC -+ ( "UpscaledOutput", m_upscaledOutput, 2, "Upscaled output for RPR" ) -+#else - ( "UpscaledOutput", m_upscaledOutput, 0, "Upscaled output for RPR" ) -+#endif - ; - - po::setDefaults(opts); -diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp -index b38001eb..c43cd0e1 100644 ---- a/source/App/EncoderApp/EncAppCfg.cpp -+++ b/source/App/EncoderApp/EncAppCfg.cpp -@@ -1411,11 +1411,19 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) - ( "CCALF", m_ccalf, true, "Cross-component Adaptive Loop Filter" ) - ( "CCALFQpTh", m_ccalfQpThreshold, 37, "QP threshold above which encoder reduces CCALF usage") - ( "RPR", m_rprEnabledFlag, true, "Reference Sample Resolution" ) -+#if DATA_GEN_ENC -+ ( "ScalingRatioHor", m_scalingRatioHor, 2.0, "Scaling ratio in hor direction" ) -+ ( "ScalingRatioVer", m_scalingRatioVer, 2.0, "Scaling ratio in ver direction" ) -+ ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) -+ ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) -+ ( "UpscaledOutput", m_upscaledOutput, 2, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) -+#else - ( "ScalingRatioHor", m_scalingRatioHor, 1.0, "Scaling ratio in hor direction" ) - ( "ScalingRatioVer", m_scalingRatioVer, 1.0, "Scaling ratio in ver direction" ) - ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) - ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) - ( "UpscaledOutput", m_upscaledOutput, 0, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) -+#endif - ( "MaxLayers", m_maxLayers, 1, "Max number of layers" ) - #if JVET_S0163_ON_TARGETOLS_SUBLAYERS - ( "EnableOperatingPointInformation", m_OPIEnabled, false, "Enables writing of Operating Point Information (OPI)" ) -diff --git a/source/Lib/CommonLib/CodingStructure.cpp b/source/Lib/CommonLib/CodingStructure.cpp -index b655d445..15e542ba 100644 ---- a/source/Lib/CommonLib/CodingStructure.cpp -+++ b/source/Lib/CommonLib/CodingStructure.cpp -@@ -107,6 +107,9 @@ void CodingStructure::destroy() - parent = nullptr; - - m_pred.destroy(); -+#if DATA_PREDICTION -+ m_predTrue.destroy(); -+#endif - m_resi.destroy(); - m_reco.destroy(); - m_orgr.destroy(); -@@ -895,6 +898,9 @@ void CodingStructure::create(const ChromaFormat &_chromaFormat, const Area& _are - - m_reco.create( area ); - m_pred.create( area ); -+#if DATA_PREDICTION -+ m_predTrue.create( area ); -+#endif - m_resi.create( area ); - m_orgr.create( area ); - } -@@ -910,6 +916,9 @@ void CodingStructure::create(const UnitArea& _unit, const bool isTopLayer, const - - m_reco.create( area ); - m_pred.create( area ); -+#if DATA_PREDICTION -+ m_predTrue.create( area ); 
-+#endif - m_resi.create( area ); - m_orgr.create( area ); - } -@@ -1082,6 +1091,16 @@ void CodingStructure::rebindPicBufs() - { - m_pred.destroy(); - } -+#if DATA_PREDICTION -+ if (!picture->M_BUFS(0, PIC_TRUE_PREDICTION).bufs.empty()) -+ { -+ m_predTrue.createFromBuf(picture->M_BUFS(0, PIC_TRUE_PREDICTION)); -+ } -+ else -+ { -+ m_predTrue.destroy(); -+ } -+#endif - if (!picture->M_BUFS(0, PIC_RESIDUAL).bufs.empty()) - { - m_resi.createFromBuf(picture->M_BUFS(0, PIC_RESIDUAL)); -@@ -1240,12 +1259,20 @@ void CodingStructure::useSubStructure( const CodingStructure& subStruct, const C - if( parent ) - { - // copy data to picture -+#if DATA_PREDICTION -+ getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+ getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+#endif - if( cpyPred ) getPredBuf ( clippedArea ).copyFrom( subPredBuf ); - if( cpyResi ) getResiBuf ( clippedArea ).copyFrom( subResiBuf ); - if( cpyReco ) getRecoBuf ( clippedArea ).copyFrom( subRecoBuf ); - if( cpyOrgResi ) getOrgResiBuf( clippedArea ).copyFrom( subStruct.getOrgResiBuf( clippedArea ) ); - } - -+#if DATA_PREDICTION -+ picture->getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+ picture->getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+#endif - if( cpyPred ) picture->getPredBuf( clippedArea ).copyFrom( subPredBuf ); - if( cpyResi ) picture->getResiBuf( clippedArea ).copyFrom( subResiBuf ); - if( cpyReco ) picture->getRecoBuf( clippedArea ).copyFrom( subRecoBuf ); -@@ -1562,6 +1589,13 @@ const CPelBuf CodingStructure::getPredBuf(const CompArea &blk) const { r - PelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } - const CPelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } - -+#if DATA_PREDICTION -+ PelBuf CodingStructure::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } -+const CPelBuf CodingStructure::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } -+ PelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } -+const CPelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit)const { return getBuf(unit, PIC_TRUE_PREDICTION); } -+#endif -+ - PelBuf CodingStructure::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } - const CPelBuf CodingStructure::getResiBuf(const CompArea &blk) const { return getBuf(blk, PIC_RESIDUAL); } - PelUnitBuf CodingStructure::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } -@@ -1603,6 +1637,13 @@ PelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &type ) - - PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? &m_orgr : nullptr ) ) ); - -+#if DATA_PREDICTION -+ if (type == PIC_TRUE_PREDICTION) -+ { -+ buf = &m_predTrue; -+ } -+#endif -+ - CHECK( !buf, "Unknown buffer requested" ); - - CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); -@@ -1637,6 +1678,13 @@ const CPelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &t - - const PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? 
&m_orgr : nullptr ) ) ); - -+#if DATA_PREDICTION -+ if (type == PIC_TRUE_PREDICTION) -+ { -+ buf = &m_predTrue; -+ } -+#endif -+ - CHECK( !buf, "Unknown buffer requested" ); - - CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); -diff --git a/source/Lib/CommonLib/CodingStructure.h b/source/Lib/CommonLib/CodingStructure.h -index b5ae7ac6..cdd3fbf1 100644 ---- a/source/Lib/CommonLib/CodingStructure.h -+++ b/source/Lib/CommonLib/CodingStructure.h -@@ -62,6 +62,9 @@ enum PictureType - PIC_ORIGINAL_INPUT, - PIC_TRUE_ORIGINAL_INPUT, - PIC_FILTERED_ORIGINAL_INPUT, -+#if DATA_PREDICTION -+ PIC_TRUE_PREDICTION, -+#endif - NUM_PIC_TYPES - }; - extern XUCache g_globalUnitCache; -@@ -228,6 +231,9 @@ private: - std::vector<SAOBlkParam> m_sao; - - PelStorage m_pred; -+#if DATA_PREDICTION -+ PelStorage m_predTrue; -+#endif - PelStorage m_resi; - PelStorage m_reco; - PelStorage m_orgr; -@@ -268,6 +274,17 @@ public: - PelUnitBuf getPredBuf(const UnitArea &unit); - const CPelUnitBuf getPredBuf(const UnitArea &unit) const; - -+#if DATA_PREDICTION -+ PelBuf getTruePredBuf(const CompArea &blk); -+ const CPelBuf getTruePredBuf(const CompArea &blk) const; -+ PelUnitBuf getTruePredBuf(const UnitArea &unit); -+ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; -+#endif -+ -+#if DATA_PREDICTION -+ PelUnitBuf getTruePredBuf() { return m_predTrue; } -+#endif -+ - PelBuf getResiBuf(const CompArea &blk); - const CPelBuf getResiBuf(const CompArea &blk) const; - PelUnitBuf getResiBuf(const UnitArea &unit); -diff --git a/source/Lib/CommonLib/Picture.cpp b/source/Lib/CommonLib/Picture.cpp -index a7205bad..3cdc698a 100644 ---- a/source/Lib/CommonLib/Picture.cpp -+++ b/source/Lib/CommonLib/Picture.cpp -@@ -277,6 +277,12 @@ void Picture::createTempBuffers( const unsigned _maxCUSize ) - { - M_BUFS( jId, PIC_PREDICTION ).create( chromaFormat, a, _maxCUSize ); - M_BUFS( jId, PIC_RESIDUAL ).create( chromaFormat, a, _maxCUSize ); -+ -+#if DATA_PREDICTION -+ const Area a_old(Position{ 0, 0 }, lumaSize()); -+ M_BUFS(jId, PIC_TRUE_PREDICTION).create(chromaFormat, a_old, _maxCUSize); -+#endif -+ - #if ENABLE_SPLIT_PARALLELISM - if (jId > 0) - { -@@ -305,6 +311,11 @@ void Picture::destroyTempBuffers() - { - M_BUFS(jId, t).destroy(); - } -+#if DATA_PREDICTION -+#if !DATA_GEN_DEC -+ if (t == PIC_TRUE_PREDICTION) M_BUFS(jId, t).destroy(); -+#endif -+#endif - #if ENABLE_SPLIT_PARALLELISM - if (t == PIC_RECONSTRUCTION && jId > 0) - { -@@ -344,6 +355,14 @@ const CPelBuf Picture::getPredBuf(const CompArea &blk) const { return getBu - PelUnitBuf Picture::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } - const CPelUnitBuf Picture::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } - -+#if DATA_PREDICTION -+ PelBuf Picture::getTruePredBuf(const ComponentID compID, bool wrap) { return getBuf(compID, PIC_TRUE_PREDICTION); } -+ PelBuf Picture::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } -+const CPelBuf Picture::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } -+ PelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } -+const CPelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_TRUE_PREDICTION); } -+#endif -+ - PelBuf Picture::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } - const CPelBuf Picture::getResiBuf(const CompArea &blk) const { return getBuf(blk, 
PIC_RESIDUAL); } - PelUnitBuf Picture::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } -diff --git a/source/Lib/CommonLib/Picture.h b/source/Lib/CommonLib/Picture.h -index 66073bf6..b48a6099 100644 ---- a/source/Lib/CommonLib/Picture.h -+++ b/source/Lib/CommonLib/Picture.h -@@ -128,6 +128,14 @@ struct Picture : public UnitArea - PelUnitBuf getPredBuf(const UnitArea &unit); - const CPelUnitBuf getPredBuf(const UnitArea &unit) const; - -+#if DATA_PREDICTION -+ PelBuf getTruePredBuf(const ComponentID compID, bool wrap = false); -+ PelBuf getTruePredBuf(const CompArea &blk); -+ const CPelBuf getTruePredBuf(const CompArea &blk) const; -+ PelUnitBuf getTruePredBuf(const UnitArea &unit); -+ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; -+#endif -+ - PelBuf getResiBuf(const CompArea &blk); - const CPelBuf getResiBuf(const CompArea &blk) const; - PelUnitBuf getResiBuf(const UnitArea &unit); -diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp -index dc1c29ae..28ad2c4f 100644 ---- a/source/Lib/CommonLib/Rom.cpp -+++ b/source/Lib/CommonLib/Rom.cpp -@@ -53,6 +53,11 @@ CDTrace *g_trace_ctx = NULL; - #endif - bool g_mctsDecCheckEnabled = false; - -+#if DATA_GEN_DEC -+unsigned int global_cnt = 0; -+char global_str_name[200]; -+#endif -+ - //! \ingroup CommonLib - //! \{ - -diff --git a/source/Lib/CommonLib/Rom.h b/source/Lib/CommonLib/Rom.h -index e7352e3c..4d1b38a1 100644 ---- a/source/Lib/CommonLib/Rom.h -+++ b/source/Lib/CommonLib/Rom.h -@@ -44,6 +44,10 @@ - #include <stdio.h> - #include <iostream> - -+#if DATA_GEN_DEC -+extern unsigned int global_cnt; -+extern char global_str_name[200]; -+#endif - - //! \ingroup CommonLib - //! \{ -diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h -index 8af59c7f..2874459a 100644 ---- a/source/Lib/CommonLib/TypeDef.h -+++ b/source/Lib/CommonLib/TypeDef.h -@@ -50,6 +50,9 @@ - #include <assert.h> - #include <cassert> - -+#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling -+#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder -+#define DATA_PREDICTION 1 // Prediction data - // clang-format off - - //########### place macros to be removed in next cycle below this line ############### -diff --git a/source/Lib/DecoderLib/DecCu.cpp b/source/Lib/DecoderLib/DecCu.cpp -index eeec3474..844c7aac 100644 ---- a/source/Lib/DecoderLib/DecCu.cpp -+++ b/source/Lib/DecoderLib/DecCu.cpp -@@ -182,6 +182,9 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) - const ChannelType chType = toChannelType( compID ); - - PelBuf piPred = cs.getPredBuf( area ); -+#if DATA_PREDICTION -+ PelBuf piPredTrue = cs.getTruePredBuf(area); -+#endif - - const PredictionUnit &pu = *tu.cs->getPU( area.pos(), chType ); - const uint32_t uiChFinalMode = PU::getFinalIntraMode( pu, chType ); -@@ -311,10 +314,15 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) - } - #if KEEP_PRED_AND_RESI_SIGNALS - pReco.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); -+#else -+#if DATA_PREDICTION -+ piPredTrue.copyFrom(piPred); -+ pReco.reconstruct(piPred, piResi, tu.cu->cs->slice->clpRng(compID)); - #else - piPred.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); - #endif --#if !KEEP_PRED_AND_RESI_SIGNALS -+#endif -+#if !KEEP_PRED_AND_RESI_SIGNALS && !DATA_PREDICTION - pReco.copyFrom( piPred ); - #endif - if (slice.getLmcsEnabledFlag() && (m_pcReshape->getCTUFlag() || slice.isIntra()) && 
compID == COMPONENT_Y) -@@ -684,6 +692,10 @@ void DecCu::xReconInter(CodingUnit &cu) - DTRACE ( g_trace_ctx, D_TMP, "pred " ); - DTRACE_CRC( g_trace_ctx, D_TMP, *cu.cs, cu.cs->getPredBuf( cu ), &cu.Y() ); - -+#if DATA_PREDICTION -+ cu.cs->getTruePredBuf(cu).copyFrom(cu.cs->getPredBuf(cu)); -+#endif -+ - // inter recon - xDecodeInterTexture(cu); - -diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp -index bb5e51f6..f3287686 100644 ---- a/source/Lib/EncoderLib/EncLib.cpp -+++ b/source/Lib/EncoderLib/EncLib.cpp -@@ -657,6 +657,9 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu - } - #endif - -+#if DATA_GEN_ENC -+ ppsID = ENC_PPS_ID_RPR; -+#else - if( m_resChangeInClvsEnabled && m_intraPeriod == -1 ) - { - const int poc = m_iPOCLast + ( m_compositeRefEnabled ? 2 : 1 ); -@@ -675,6 +678,7 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu - { - ppsID = m_vps->getGeneralLayerIdx( m_layerId ); - } -+#endif - - xGetNewPicBuffer( rcListPicYuvRecOut, pcPicCurr, ppsID ); - -diff --git a/source/Lib/Utilities/VideoIOYuv.cpp b/source/Lib/Utilities/VideoIOYuv.cpp -index 8a30ccc5..7a271982 100644 ---- a/source/Lib/Utilities/VideoIOYuv.cpp -+++ b/source/Lib/Utilities/VideoIOYuv.cpp -@@ -1252,7 +1252,11 @@ void VideoIOYuv::ColourSpaceConvert(const CPelUnitBuf &src, PelUnitBuf &dest, co - } - } - --bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 ) -+bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 -+#if DATA_GEN_DEC -+ , Picture* pcPic -+#endif -+) - { - ChromaFormat chromaFormatIDC = sps.getChromaFormatIdc(); - bool ret = false; -@@ -1284,6 +1288,82 @@ bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPe - int xScale = ( ( refPicWidth << SCALE_RATIO_BITS ) + ( curPicWidth >> 1 ) ) / curPicWidth; - int yScale = ( ( refPicHeight << SCALE_RATIO_BITS ) + ( curPicHeight >> 1 ) ) / curPicHeight; - -+#if DATA_GEN_DEC -+ if (pcPic->cs->slice->getSliceType() == B_SLICE) -+ { -+ PelStorage upscaledRPR; -+ upscaledRPR.create( chromaFormatIDC, Area( Position(), Size( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples() ) ) ); -+ Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledRPR, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); -+ -+ char rec_out_name[200]; -+ strcpy(rec_out_name, global_str_name); -+ sprintf(rec_out_name + strlen(rec_out_name), "_poc%03d_qp%d.yuv", pcPic->cs->slice->getPOC(), pcPic->cs->slice->getSliceQp()); -+ FILE* fp_rec = fopen(rec_out_name, "wb"); -+ -+ char pre_out_name[200]; -+ strcpy(pre_out_name, global_str_name); -+ sprintf(pre_out_name + strlen(pre_out_name), "_poc%03d_qp%d_prediction.yuv", pcPic->cs->slice->getPOC(), pcPic->cs->slice->getSliceQp()); -+ FILE* fp_pre = fopen(pre_out_name, "wb"); -+ -+ char rpr_out_name[200]; -+ strcpy(rpr_out_name, global_str_name); -+ sprintf(rpr_out_name + strlen(rpr_out_name), "_poc%03d_qp%d_rpr.yuv", pcPic->cs->slice->getPOC(), pcPic->cs->slice->getSliceQp()); -+ FILE* fp_rpr = fopen(rpr_out_name, 
"wb"); -+ -+ int8_t temp[2]; -+ -+ uint32_t curLumaH = pps.getPicHeightInLumaSamples(); -+ uint32_t curLumaW = pps.getPicWidthInLumaSamples(); -+ -+ uint32_t oriLumaH = sps.getMaxPicHeightInLumaSamples(); -+ uint32_t oriLumaW = sps.getMaxPicWidthInLumaSamples(); -+ -+ for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) -+ { -+ ComponentID compID = ComponentID(compIdx); -+ const int chromascaleY = getComponentScaleY(compID, pic.chromaFormat); -+ const int chromascaleX = getComponentScaleX(compID, pic.chromaFormat); -+ -+ uint32_t curPicH = curLumaH >> chromascaleY; -+ uint32_t curPicW = curLumaW >> chromascaleX; -+ -+ uint32_t oriPicH = oriLumaH >> chromascaleY; -+ uint32_t oriPicW = oriLumaW >> chromascaleX; -+ -+ for (uint32_t j = 0; j < curPicH; j++) -+ { -+ for (uint32_t i = 0; i < curPicW; i++) -+ { -+ temp[0] = (pic.get(compID).at(i, j) >> 0) & 0xff; -+ temp[1] = (pic.get(compID).at(i, j) >> 8) & 0xff; -+ ::fwrite(temp, sizeof(temp[0]), 2, fp_rec); -+ -+ CHECK(pic.get(compID).at(i, j) < 0 || pic.get(compID).at(i, j) > 1023, ""); -+ -+ temp[0] = (pcPic->getTruePredBuf(compID).at(i, j) >> 0) & 0xff; -+ temp[1] = (pcPic->getTruePredBuf(compID).at(i, j) >> 8) & 0xff; -+ ::fwrite(temp, sizeof(temp[0]), 2, fp_pre); -+ -+ CHECK(pcPic->getTruePredBuf(compID).at(i, j) < 0 || pcPic->getTruePredBuf(compID).at(i, j) > 1023, ""); -+ } -+ } -+ for (uint32_t j = 0; j < oriPicH; j++) -+ { -+ for (uint32_t i = 0; i < oriPicW; i++) -+ { -+ temp[0] = (upscaledRPR.get(compID).at(i, j) >> 0) & 0xff; -+ temp[1] = (upscaledRPR.get(compID).at(i, j) >> 8) & 0xff; -+ ::fwrite(temp, sizeof(temp[0]), 2, fp_rpr); -+ -+ CHECK(upscaledRPR.get(compID).at(i, j) < 0 || upscaledRPR.get(compID).at(i, j) > 1023, ""); -+ } -+ } -+ } -+ ::fclose(fp_rec); -+ ::fclose(fp_pre); -+ ::fclose(fp_rpr); -+ } -+#endif - Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledPic, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); - - ret = write( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples(), upscaledPic, -diff --git a/source/Lib/Utilities/VideoIOYuv.h b/source/Lib/Utilities/VideoIOYuv.h -index bf2c4705..e4baec31 100644 ---- a/source/Lib/Utilities/VideoIOYuv.h -+++ b/source/Lib/Utilities/VideoIOYuv.h -@@ -101,7 +101,11 @@ public: - int getFileBitdepth( int ch ) { return m_fileBitdepth[ch]; } - - bool writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, -- const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false ); ///< write one upsaled YUV frame -+ const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false -+#if DATA_GEN_DEC -+ , Picture* pcPic = nullptr -+#endif -+ ); ///< write one upsaled YUV frame - - }; - --- -2.34.0.windows.1 - diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_I_slices.patch b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_I_slices.patch deleted file mode 100644 index df34598010965b2a58d63aa5e773d39b8a4429bd..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_I_slices.patch +++ 
/dev/null @@ -1,569 +0,0 @@ -From 10295eea930cf3be502f5c279618313443b2a5b6 Mon Sep 17 00:00:00 2001 -From: renjiechang <renjiechang@tencent.com> -Date: Tue, 14 Feb 2023 14:59:10 +0800 -Subject: [PATCH] Generate SR datasets for I slices - ---- - source/App/DecoderApp/DecApp.cpp | 26 ++++++- - source/App/DecoderApp/DecAppCfg.cpp | 4 ++ - source/App/EncoderApp/EncAppCfg.cpp | 8 +++ - source/Lib/CommonLib/CodingStructure.cpp | 48 +++++++++++++ - source/Lib/CommonLib/CodingStructure.h | 17 +++++ - source/Lib/CommonLib/Picture.cpp | 19 +++++ - source/Lib/CommonLib/Picture.h | 8 +++ - source/Lib/CommonLib/Rom.cpp | 5 ++ - source/Lib/CommonLib/Rom.h | 4 ++ - source/Lib/CommonLib/TypeDef.h | 3 + - source/Lib/DecoderLib/DecCu.cpp | 14 +++- - source/Lib/EncoderLib/EncLib.cpp | 4 ++ - source/Lib/Utilities/VideoIOYuv.cpp | 90 +++++++++++++++++++++++- - source/Lib/Utilities/VideoIOYuv.h | 6 +- - 14 files changed, 251 insertions(+), 5 deletions(-) - -diff --git a/source/App/DecoderApp/DecApp.cpp b/source/App/DecoderApp/DecApp.cpp -index 85f63bb0..1842d346 100644 ---- a/source/App/DecoderApp/DecApp.cpp -+++ b/source/App/DecoderApp/DecApp.cpp -@@ -88,6 +88,17 @@ uint32_t DecApp::decode() - EXIT( "Failed to open bitstream file " << m_bitstreamFileName.c_str() << " for reading" ) ; - } - -+#if DATA_GEN_DEC -+ strcpy(global_str_name, m_bitstreamFileName.c_str()); -+ { -+ size_t len = strlen(global_str_name); -+ for (size_t i = len - 1; i > len - 5; i--) -+ { -+ global_str_name[i] = 0; -+ } -+ } -+#endif -+ - InputByteStream bytestream(bitstreamFile); - - if (!m_outputDecodedSEIMessagesFilename.empty() && m_outputDecodedSEIMessagesFilename!="-") -@@ -678,7 +689,14 @@ void DecApp::xWriteOutput( PicList* pcListPic, uint32_t tId ) - ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); - if( m_upscaledOutput ) - { -- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); -+ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range -+#if DATA_GEN_DEC -+ , pcPic -+#endif -+ ); -+#if DATA_PREDICTION -+ pcPic->m_bufs[PIC_TRUE_PREDICTION].destroy(); -+#endif - } - else - { -@@ -825,7 +843,11 @@ void DecApp::xFlushOutput( PicList* pcListPic, const int layerId ) - ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); - if( m_upscaledOutput ) - { -- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); -+ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range -+#if DATA_GEN_DEC -+ , pcPic -+#endif -+ ); - } - else - { -diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp -index d96c2049..ad912462 100644 ---- a/source/App/DecoderApp/DecAppCfg.cpp -+++ b/source/App/DecoderApp/DecAppCfg.cpp -@@ -124,7 +124,11 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) - #endif - ("MCTSCheck", m_mctsCheck, false, "If enabled, the decoder checks for violations of mc_exact_sample_value_match_flag in Temporal MCTS ") - ("targetSubPicIdx", 
m_targetSubPicIdx, 0, "Specify which subpicture shall be written to output, using subpic index, 0: disabled, subpicIdx=m_targetSubPicIdx-1 \n" ) -+#if DATA_GEN_DEC -+ ( "UpscaledOutput", m_upscaledOutput, 2, "Upscaled output for RPR" ) -+#else - ( "UpscaledOutput", m_upscaledOutput, 0, "Upscaled output for RPR" ) -+#endif - ; - - po::setDefaults(opts); -diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp -index b38001eb..c43cd0e1 100644 ---- a/source/App/EncoderApp/EncAppCfg.cpp -+++ b/source/App/EncoderApp/EncAppCfg.cpp -@@ -1411,11 +1411,19 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) - ( "CCALF", m_ccalf, true, "Cross-component Adaptive Loop Filter" ) - ( "CCALFQpTh", m_ccalfQpThreshold, 37, "QP threshold above which encoder reduces CCALF usage") - ( "RPR", m_rprEnabledFlag, true, "Reference Sample Resolution" ) -+#if DATA_GEN_ENC -+ ( "ScalingRatioHor", m_scalingRatioHor, 2.0, "Scaling ratio in hor direction" ) -+ ( "ScalingRatioVer", m_scalingRatioVer, 2.0, "Scaling ratio in ver direction" ) -+ ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) -+ ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) -+ ( "UpscaledOutput", m_upscaledOutput, 2, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) -+#else - ( "ScalingRatioHor", m_scalingRatioHor, 1.0, "Scaling ratio in hor direction" ) - ( "ScalingRatioVer", m_scalingRatioVer, 1.0, "Scaling ratio in ver direction" ) - ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) - ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) - ( "UpscaledOutput", m_upscaledOutput, 0, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) -+#endif - ( "MaxLayers", m_maxLayers, 1, "Max number of layers" ) - #if JVET_S0163_ON_TARGETOLS_SUBLAYERS - ( "EnableOperatingPointInformation", m_OPIEnabled, false, "Enables writing of Operating Point Information (OPI)" ) -diff --git a/source/Lib/CommonLib/CodingStructure.cpp b/source/Lib/CommonLib/CodingStructure.cpp -index b655d445..15e542ba 100644 ---- a/source/Lib/CommonLib/CodingStructure.cpp -+++ b/source/Lib/CommonLib/CodingStructure.cpp -@@ -107,6 +107,9 @@ void CodingStructure::destroy() - parent = nullptr; - - m_pred.destroy(); -+#if DATA_PREDICTION -+ m_predTrue.destroy(); -+#endif - m_resi.destroy(); - m_reco.destroy(); - m_orgr.destroy(); -@@ -895,6 +898,9 @@ void CodingStructure::create(const ChromaFormat &_chromaFormat, const Area& _are - - m_reco.create( area ); - m_pred.create( area ); -+#if DATA_PREDICTION -+ m_predTrue.create( area ); -+#endif - m_resi.create( area ); - m_orgr.create( area ); - } -@@ -910,6 +916,9 @@ void CodingStructure::create(const UnitArea& _unit, const bool isTopLayer, const - - m_reco.create( area ); - m_pred.create( area ); -+#if DATA_PREDICTION -+ m_predTrue.create( area ); -+#endif - m_resi.create( area ); - m_orgr.create( area ); - } -@@ -1082,6 +1091,16 @@ void CodingStructure::rebindPicBufs() - { - m_pred.destroy(); - } -+#if DATA_PREDICTION -+ if (!picture->M_BUFS(0, PIC_TRUE_PREDICTION).bufs.empty()) -+ { -+ m_predTrue.createFromBuf(picture->M_BUFS(0, PIC_TRUE_PREDICTION)); -+ } -+ else -+ { -+ m_predTrue.destroy(); -+ } -+#endif - if (!picture->M_BUFS(0, PIC_RESIDUAL).bufs.empty()) - { - m_resi.createFromBuf(picture->M_BUFS(0, PIC_RESIDUAL)); 
-@@ -1240,12 +1259,20 @@ void CodingStructure::useSubStructure( const CodingStructure& subStruct, const C - if( parent ) - { - // copy data to picture -+#if DATA_PREDICTION -+ getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+ getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+#endif - if( cpyPred ) getPredBuf ( clippedArea ).copyFrom( subPredBuf ); - if( cpyResi ) getResiBuf ( clippedArea ).copyFrom( subResiBuf ); - if( cpyReco ) getRecoBuf ( clippedArea ).copyFrom( subRecoBuf ); - if( cpyOrgResi ) getOrgResiBuf( clippedArea ).copyFrom( subStruct.getOrgResiBuf( clippedArea ) ); - } - -+#if DATA_PREDICTION -+ picture->getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+ picture->getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); -+#endif - if( cpyPred ) picture->getPredBuf( clippedArea ).copyFrom( subPredBuf ); - if( cpyResi ) picture->getResiBuf( clippedArea ).copyFrom( subResiBuf ); - if( cpyReco ) picture->getRecoBuf( clippedArea ).copyFrom( subRecoBuf ); -@@ -1562,6 +1589,13 @@ const CPelBuf CodingStructure::getPredBuf(const CompArea &blk) const { r - PelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } - const CPelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } - -+#if DATA_PREDICTION -+ PelBuf CodingStructure::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } -+const CPelBuf CodingStructure::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } -+ PelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } -+const CPelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit)const { return getBuf(unit, PIC_TRUE_PREDICTION); } -+#endif -+ - PelBuf CodingStructure::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } - const CPelBuf CodingStructure::getResiBuf(const CompArea &blk) const { return getBuf(blk, PIC_RESIDUAL); } - PelUnitBuf CodingStructure::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } -@@ -1603,6 +1637,13 @@ PelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &type ) - - PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? &m_orgr : nullptr ) ) ); - -+#if DATA_PREDICTION -+ if (type == PIC_TRUE_PREDICTION) -+ { -+ buf = &m_predTrue; -+ } -+#endif -+ - CHECK( !buf, "Unknown buffer requested" ); - - CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); -@@ -1637,6 +1678,13 @@ const CPelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &t - - const PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? 
&m_orgr : nullptr ) ) ); - -+#if DATA_PREDICTION -+ if (type == PIC_TRUE_PREDICTION) -+ { -+ buf = &m_predTrue; -+ } -+#endif -+ - CHECK( !buf, "Unknown buffer requested" ); - - CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); -diff --git a/source/Lib/CommonLib/CodingStructure.h b/source/Lib/CommonLib/CodingStructure.h -index b5ae7ac6..cdd3fbf1 100644 ---- a/source/Lib/CommonLib/CodingStructure.h -+++ b/source/Lib/CommonLib/CodingStructure.h -@@ -62,6 +62,9 @@ enum PictureType - PIC_ORIGINAL_INPUT, - PIC_TRUE_ORIGINAL_INPUT, - PIC_FILTERED_ORIGINAL_INPUT, -+#if DATA_PREDICTION -+ PIC_TRUE_PREDICTION, -+#endif - NUM_PIC_TYPES - }; - extern XUCache g_globalUnitCache; -@@ -228,6 +231,9 @@ private: - std::vector<SAOBlkParam> m_sao; - - PelStorage m_pred; -+#if DATA_PREDICTION -+ PelStorage m_predTrue; -+#endif - PelStorage m_resi; - PelStorage m_reco; - PelStorage m_orgr; -@@ -268,6 +274,17 @@ public: - PelUnitBuf getPredBuf(const UnitArea &unit); - const CPelUnitBuf getPredBuf(const UnitArea &unit) const; - -+#if DATA_PREDICTION -+ PelBuf getTruePredBuf(const CompArea &blk); -+ const CPelBuf getTruePredBuf(const CompArea &blk) const; -+ PelUnitBuf getTruePredBuf(const UnitArea &unit); -+ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; -+#endif -+ -+#if DATA_PREDICTION -+ PelUnitBuf getTruePredBuf() { return m_predTrue; } -+#endif -+ - PelBuf getResiBuf(const CompArea &blk); - const CPelBuf getResiBuf(const CompArea &blk) const; - PelUnitBuf getResiBuf(const UnitArea &unit); -diff --git a/source/Lib/CommonLib/Picture.cpp b/source/Lib/CommonLib/Picture.cpp -index a7205bad..d5d1400a 100644 ---- a/source/Lib/CommonLib/Picture.cpp -+++ b/source/Lib/CommonLib/Picture.cpp -@@ -277,6 +277,12 @@ void Picture::createTempBuffers( const unsigned _maxCUSize ) - { - M_BUFS( jId, PIC_PREDICTION ).create( chromaFormat, a, _maxCUSize ); - M_BUFS( jId, PIC_RESIDUAL ).create( chromaFormat, a, _maxCUSize ); -+ -+#if DATA_PREDICTION -+ const Area a_old(Position{ 0, 0 }, lumaSize()); -+ M_BUFS(jId, PIC_TRUE_PREDICTION).create(chromaFormat, a_old, _maxCUSize); -+#endif -+ - #if ENABLE_SPLIT_PARALLELISM - if (jId > 0) - { -@@ -305,6 +311,11 @@ void Picture::destroyTempBuffers() - { - M_BUFS(jId, t).destroy(); - } -+#if DATA_PREDICTION -+#if !DATA_GEN_DEC -+ if (t == PIC_TRUE_PREDICTION) M_BUFS(jId, t).destroy(); -+#endif -+#endif - #if ENABLE_SPLIT_PARALLELISM - if (t == PIC_RECONSTRUCTION && jId > 0) - { -@@ -344,6 +355,14 @@ const CPelBuf Picture::getPredBuf(const CompArea &blk) const { return getBu - PelUnitBuf Picture::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } - const CPelUnitBuf Picture::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } - -+#if DATA_PREDICTION -+ PelBuf Picture::getTruePredBuf(const ComponentID compID, bool wrap) { return getBuf(compID, PIC_TRUE_PREDICTION); } -+ PelBuf Picture::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } -+const CPelBuf Picture::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } -+ PelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } -+const CPelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_TRUE_PREDICTION); } -+#endif -+ - PelBuf Picture::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } - const CPelBuf Picture::getResiBuf(const CompArea &blk) const { return getBuf(blk, 
PIC_RESIDUAL); } - PelUnitBuf Picture::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } -diff --git a/source/Lib/CommonLib/Picture.h b/source/Lib/CommonLib/Picture.h -index 66073bf6..b48a6099 100644 ---- a/source/Lib/CommonLib/Picture.h -+++ b/source/Lib/CommonLib/Picture.h -@@ -128,6 +128,14 @@ struct Picture : public UnitArea - PelUnitBuf getPredBuf(const UnitArea &unit); - const CPelUnitBuf getPredBuf(const UnitArea &unit) const; - -+#if DATA_PREDICTION -+ PelBuf getTruePredBuf(const ComponentID compID, bool wrap = false); -+ PelBuf getTruePredBuf(const CompArea &blk); -+ const CPelBuf getTruePredBuf(const CompArea &blk) const; -+ PelUnitBuf getTruePredBuf(const UnitArea &unit); -+ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; -+#endif -+ - PelBuf getResiBuf(const CompArea &blk); - const CPelBuf getResiBuf(const CompArea &blk) const; - PelUnitBuf getResiBuf(const UnitArea &unit); -diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp -index dc1c29ae..28ad2c4f 100644 ---- a/source/Lib/CommonLib/Rom.cpp -+++ b/source/Lib/CommonLib/Rom.cpp -@@ -53,6 +53,11 @@ CDTrace *g_trace_ctx = NULL; - #endif - bool g_mctsDecCheckEnabled = false; - -+#if DATA_GEN_DEC -+unsigned int global_cnt = 0; -+char global_str_name[200]; -+#endif -+ - //! \ingroup CommonLib - //! \{ - -diff --git a/source/Lib/CommonLib/Rom.h b/source/Lib/CommonLib/Rom.h -index e7352e3c..4d1b38a1 100644 ---- a/source/Lib/CommonLib/Rom.h -+++ b/source/Lib/CommonLib/Rom.h -@@ -44,6 +44,10 @@ - #include <stdio.h> - #include <iostream> - -+#if DATA_GEN_DEC -+extern unsigned int global_cnt; -+extern char global_str_name[200]; -+#endif - - //! \ingroup CommonLib - //! \{ -diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h -index 8af59c7f..2874459a 100644 ---- a/source/Lib/CommonLib/TypeDef.h -+++ b/source/Lib/CommonLib/TypeDef.h -@@ -50,6 +50,9 @@ - #include <assert.h> - #include <cassert> - -+#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling -+#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder -+#define DATA_PREDICTION 1 // Prediction data - // clang-format off - - //########### place macros to be removed in next cycle below this line ############### -diff --git a/source/Lib/DecoderLib/DecCu.cpp b/source/Lib/DecoderLib/DecCu.cpp -index eeec3474..844c7aac 100644 ---- a/source/Lib/DecoderLib/DecCu.cpp -+++ b/source/Lib/DecoderLib/DecCu.cpp -@@ -182,6 +182,9 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) - const ChannelType chType = toChannelType( compID ); - - PelBuf piPred = cs.getPredBuf( area ); -+#if DATA_PREDICTION -+ PelBuf piPredTrue = cs.getTruePredBuf(area); -+#endif - - const PredictionUnit &pu = *tu.cs->getPU( area.pos(), chType ); - const uint32_t uiChFinalMode = PU::getFinalIntraMode( pu, chType ); -@@ -311,10 +314,15 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) - } - #if KEEP_PRED_AND_RESI_SIGNALS - pReco.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); -+#else -+#if DATA_PREDICTION -+ piPredTrue.copyFrom(piPred); -+ pReco.reconstruct(piPred, piResi, tu.cu->cs->slice->clpRng(compID)); - #else - piPred.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); - #endif --#if !KEEP_PRED_AND_RESI_SIGNALS -+#endif -+#if !KEEP_PRED_AND_RESI_SIGNALS && !DATA_PREDICTION - pReco.copyFrom( piPred ); - #endif - if (slice.getLmcsEnabledFlag() && (m_pcReshape->getCTUFlag() || slice.isIntra()) && 
compID == COMPONENT_Y) -@@ -684,6 +692,10 @@ void DecCu::xReconInter(CodingUnit &cu) - DTRACE ( g_trace_ctx, D_TMP, "pred " ); - DTRACE_CRC( g_trace_ctx, D_TMP, *cu.cs, cu.cs->getPredBuf( cu ), &cu.Y() ); - -+#if DATA_PREDICTION -+ cu.cs->getTruePredBuf(cu).copyFrom(cu.cs->getPredBuf(cu)); -+#endif -+ - // inter recon - xDecodeInterTexture(cu); - -diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp -index bb5e51f6..f3287686 100644 ---- a/source/Lib/EncoderLib/EncLib.cpp -+++ b/source/Lib/EncoderLib/EncLib.cpp -@@ -657,6 +657,9 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu - } - #endif - -+#if DATA_GEN_ENC -+ ppsID = ENC_PPS_ID_RPR; -+#else - if( m_resChangeInClvsEnabled && m_intraPeriod == -1 ) - { - const int poc = m_iPOCLast + ( m_compositeRefEnabled ? 2 : 1 ); -@@ -675,6 +678,7 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu - { - ppsID = m_vps->getGeneralLayerIdx( m_layerId ); - } -+#endif - - xGetNewPicBuffer( rcListPicYuvRecOut, pcPicCurr, ppsID ); - -diff --git a/source/Lib/Utilities/VideoIOYuv.cpp b/source/Lib/Utilities/VideoIOYuv.cpp -index 8a30ccc5..3ea4d985 100644 ---- a/source/Lib/Utilities/VideoIOYuv.cpp -+++ b/source/Lib/Utilities/VideoIOYuv.cpp -@@ -1252,7 +1252,11 @@ void VideoIOYuv::ColourSpaceConvert(const CPelUnitBuf &src, PelUnitBuf &dest, co - } - } - --bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 ) -+bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 -+#if DATA_GEN_DEC -+ , Picture* pcPic -+#endif -+) - { - ChromaFormat chromaFormatIDC = sps.getChromaFormatIdc(); - bool ret = false; -@@ -1284,6 +1288,90 @@ bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPe - int xScale = ( ( refPicWidth << SCALE_RATIO_BITS ) + ( curPicWidth >> 1 ) ) / curPicWidth; - int yScale = ( ( refPicHeight << SCALE_RATIO_BITS ) + ( curPicHeight >> 1 ) ) / curPicHeight; - -+#if DATA_GEN_DEC -+ if (pcPic->cs->slice->getSliceType() == I_SLICE) -+ { -+ PelStorage upscaledRPR; -+ upscaledRPR.create( chromaFormatIDC, Area( Position(), Size( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples() ) ) ); -+ Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledRPR, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); -+ -+ char rec_out_name[200]; -+ strcpy(rec_out_name, global_str_name); -+ sprintf(rec_out_name + strlen(rec_out_name), "_poc%03d.yuv", pcPic->cs->slice->getPOC()); -+ FILE* fp_rec = fopen(rec_out_name, "wb"); -+ -+#if DATA_PREDICTION -+ char pre_out_name[200]; -+ strcpy(pre_out_name, global_str_name); -+ sprintf(pre_out_name + strlen(pre_out_name), "_poc%03d_prediction.yuv", pcPic->cs->slice->getPOC()); -+ FILE* fp_pre = fopen(pre_out_name, "wb"); -+#endif -+ -+ char rpr_out_name[200]; -+ strcpy(rpr_out_name, global_str_name); -+ sprintf(rpr_out_name + strlen(rpr_out_name), "_poc%03d_rpr.yuv", pcPic->cs->slice->getPOC()); -+ FILE* fp_rpr = fopen(rpr_out_name, "wb"); -+ -+ int8_t temp[2]; -+ -+ uint32_t curLumaH = 
pps.getPicHeightInLumaSamples(); -+ uint32_t curLumaW = pps.getPicWidthInLumaSamples(); -+ -+ uint32_t oriLumaH = sps.getMaxPicHeightInLumaSamples(); -+ uint32_t oriLumaW = sps.getMaxPicWidthInLumaSamples(); -+ -+ for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) -+ { -+ ComponentID compID = ComponentID(compIdx); -+ const int chromascaleY = getComponentScaleY(compID, pic.chromaFormat); -+ const int chromascaleX = getComponentScaleX(compID, pic.chromaFormat); -+ -+ uint32_t curPicH = curLumaH >> chromascaleY; -+ uint32_t curPicW = curLumaW >> chromascaleX; -+ -+ uint32_t oriPicH = oriLumaH >> chromascaleY; -+ uint32_t oriPicW = oriLumaW >> chromascaleX; -+ -+ for (uint32_t j = 0; j < curPicH; j++) -+ { -+ for (uint32_t i = 0; i < curPicW; i++) -+ { -+ temp[0] = (pic.get(compID).at(i, j) >> 0) & 0xff; -+ temp[1] = (pic.get(compID).at(i, j) >> 8) & 0xff; -+ ::fwrite(temp, sizeof(temp[0]), 2, fp_rec); -+ -+ CHECK(pic.get(compID).at(i, j) < 0 || pic.get(compID).at(i, j) > 1023, ""); -+ -+#if DATA_PREDICTION -+ temp[0] = (pcPic->getTruePredBuf(compID).at(i, j) >> 0) & 0xff; -+ temp[1] = (pcPic->getTruePredBuf(compID).at(i, j) >> 8) & 0xff; -+ ::fwrite(temp, sizeof(temp[0]), 2, fp_pre); -+ -+ CHECK(pcPic->getTruePredBuf(compID).at(i, j) < 0 || pcPic->getTruePredBuf(compID).at(i, j) > 1023, ""); -+#endif -+ } -+ } -+ for (uint32_t j = 0; j < oriPicH; j++) -+ { -+ for (uint32_t i = 0; i < oriPicW; i++) -+ { -+ temp[0] = (upscaledRPR.get(compID).at(i, j) >> 0) & 0xff; -+ temp[1] = (upscaledRPR.get(compID).at(i, j) >> 8) & 0xff; -+ ::fwrite(temp, sizeof(temp[0]), 2, fp_rpr); -+ -+ CHECK(upscaledRPR.get(compID).at(i, j) < 0 || upscaledRPR.get(compID).at(i, j) > 1023, ""); -+ } -+ } -+ } -+ ::fclose(fp_rec); -+#if DATA_PREDICTION -+ ::fclose(fp_pre); -+#endif -+ ::fclose(fp_rpr); -+ -+ global_cnt++; -+ } -+#endif - Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledPic, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); - - ret = write( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples(), upscaledPic, -diff --git a/source/Lib/Utilities/VideoIOYuv.h b/source/Lib/Utilities/VideoIOYuv.h -index bf2c4705..e4baec31 100644 ---- a/source/Lib/Utilities/VideoIOYuv.h -+++ b/source/Lib/Utilities/VideoIOYuv.h -@@ -101,7 +101,11 @@ public: - int getFileBitdepth( int ch ) { return m_fileBitdepth[ch]; } - - bool writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, -- const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false ); ///< write one upsaled YUV frame -+ const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false -+#if DATA_GEN_DEC -+ , Picture* pcPic = nullptr -+#endif -+ ); ///< write one upsaled YUV frame - - }; - --- -2.34.0.windows.1 - diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/bvi_dvc_codec_info.py b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/bvi_dvc_codec_info.py deleted file mode 100644 index 4e359b04ad9be3ef0ef646d26d2da799d5a7c6d7..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/bvi_dvc_codec_info.py +++ /dev/null @@ -1,2641 +0,0 @@ 
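-# Each SequenceTable entry describes one BVI-DVC source sequence, in order:
-# sequence ID, source YUV file name, width, height, a start/skip offset
-# (0 for all entries), frame count, frame rate, bit depth, the QP set to
-# encode, the matching rate-point labels, and the codec level.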
-SequenceTable = [ - [ - "B_S001", - "BAdvertisingMassagesBangkokVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S002", - "BAmericanFootballS2Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S003", - "BAmericanFootballS3Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S004", - "BAmericanFootballS4Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S005", - "BAnimalsS11Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S006", - "BAnimalsS1Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S007", - "BBangkokMarketVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S008", - "BBasketballGoalScoredS1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S009", - "BBasketballGoalScoredS2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S010", - "BBasketballS1YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S011", - "BBasketballS2YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S012", - "BBasketballS3YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S013", - "BBoatsChaoPhrayaRiverVidevo_1920x1088_23fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 23, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S014", - "BBobbleheadBVIHFR_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S015", - "BBookcaseBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S016", - "BBoxingPracticeHarmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S017", - "BBricksBushesStaticBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S018", - "BBricksLeavesBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S019", - "BBricksTiltingBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S020", - "BBubblesPitcherS1BVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S021", - 
"BBuildingRoofS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S022", - "BBuildingRoofS2IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S023", - "BBuildingRoofS3IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S024", - "BBuildingRoofS4IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S025", - "BBuntingHangingAcrossHongKongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S026", - "BBusyHongKongStreetVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S027", - "BCalmingWaterBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S028", - "BCarpetPanAverageBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S029", - "BCatchBVIHFR_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S030", - "BCeramicsandSpicesMoroccoVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S031", - "BCharactersYonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S032", - "BChristmasPresentsIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S033", - "BChristmasRoomDareful_1920x1088_29fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 29, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S034", - "BChurchInsideMCLJCV_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S035", - "BCityScapesS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S036", - "BCityScapesS2IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S037", - "BCityScapesS3IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S038", - "BCityStreetS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S039", - "BCityStreetS3IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S040", - "BCityStreetS4IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S041", - "BCityStreetS5IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S042", - 
"BCityStreetS6IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S043", - "BCityStreetS7IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S044", - "BCloseUpBasketballSceneVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S045", - "BCloudsStaticBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S046", - "BColourfulDecorationWatPhoVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S047", - "BColourfulKoreanLanternsVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S048", - "BColourfulPaperLanternsVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S049", - "BColourfulRugsMoroccoVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S050", - "BConstructionS2YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S051", - "BCostaRicaS3Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S052", - "BCrosswalkHarmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S053", - "BCrosswalkHongKong2S1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S054", - "BCrosswalkHongKong2S2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S055", - "BCrosswalkHongKongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S056", - "BCrowdRunMCLV_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S057", - "BCyclistS1BVIHFR_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S058", - "BCyclistVeniceBeachBoardwalkVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S059", - "BDollsScene1YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S060", - "BDollsScene2YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S061", - "BDowntownHongKongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S062", - "BDrivingPOVHarmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - 
"22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S063", - "BDropsOnWaterBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S064", - "BElFuenteMaskLIVENetFlix_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S065", - "BEnteringHongKongStallS1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S066", - "BEnteringHongKongStallS2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S067", - "BFerrisWheelTurningVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S068", - "BFireS18Mitch_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S069", - "BFireS21Mitch_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S070", - "BFireS71Mitch_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S071", - "BFirewoodS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S072", - "BFirewoodS2IRIS_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S073", - "BFitnessIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S074", - "BFjordsS1Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S075", - "BFlagShootTUMSVT_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S076", - "BFlowerChapelS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S077", - "BFlowerChapelS2IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S078", - "BFlyingCountrysideDareful_1920x1088_29fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 29, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S079", - "BFlyingMountainsDareful_1920x1088_29fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 29, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S080", - "BFlyingThroughLAStreetVidevo_1920x1088_23fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 23, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S081", - "BFungusZoomBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S082", - "BGrassBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S083", - "BGrazTowerIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 
R32 R37 R42", - 4.1, - ], - [ - "B_S084", - "BHamsterBVIHFR_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S085", - "BHarleyDavidsonIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S086", - "BHongKongIslandVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S087", - "BHongKongMarket1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S088", - "BHongKongMarket2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S089", - "BHongKongMarket3S1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S090", - "BHongKongMarket3S2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S091", - "BHongKongMarket4S1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S092", - "BHongKongMarket4S2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S093", - "BHongKongS1Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S094", - "BHongKongS2Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S095", - "BHongKongS3Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S096", - "BHorseDrawnCarriagesVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S097", - "BHorseStaringS1Videvo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S098", - "BHorseStaringS2Videvo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S099", - "BJockeyHarmonics_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S100", - "BJoggersS1BVIHFR_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S101", - "BJoggersS2BVIHFR_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S102", - "BKartingIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S103", - "BKoraDrumsVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S104", - "BLakeYonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 
4.1, - ], - [ - "B_S105", - "BLampLeavesBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S106", - "BLaundryHangingOverHongKongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S107", - "BLeaves1BVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S108", - "BLeaves3BVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S109", - "BLowLevelShotAlongHongKongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S110", - "BLungshanTempleS1Videvo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S111", - "BLungshanTempleS2Videvo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S112", - "BManMoTempleVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S113", - "BManStandinginProduceTruckVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S114", - "BManWalkingThroughBangkokVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S115", - "BMaplesS1YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S116", - "BMaplesS2YonseiUniversity_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S117", - "BMirabellParkS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S118", - "BMirabellParkS2IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S119", - "BMobileHarmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S120", - "BMoroccanCeramicsShopVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S121", - "BMoroccanSlippersVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S122", - "BMuralPaintingVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S123", - "BMyanmarS4Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S124", - "BMyanmarS6Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S125", - "BMyeongDongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 
32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S126", - "BNewYorkStreetDareful_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S127", - "BOrangeBuntingoverHongKongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S128", - "BPaintingTiltingBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S129", - "BParkViolinMCLJCV_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S130", - "BPedestriansSeoulatDawnVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S131", - "BPeopleWalkingS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S132", - "BPersonRunningOutsideVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S133", - "BPillowsTransBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S134", - "BPlasmaFreeBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S135", - "BPresentsChristmasTreeDareful_1920x1088_29fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 29, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S136", - "BReadySetGoS2TampereUniversity_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S137", - "BResidentialBuildingSJTU_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S138", - "BRollerCoaster2Netflix_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S139", - "BRunnersSJTU_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S140", - "BRuralSetupIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S141", - "BRuralSetupS2IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S142", - "BScarfSJTU_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S143", - "BSeasideWalkIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S144", - "BSeekingMCLV_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S145", - "BSeoulCanalatDawnVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S146", - "BShoppingCentreVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, 
- "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S147", - "BSignboardBoatLIVENetFlix_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S148", - "BSkyscraperBangkokVidevo_1920x1088_23fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 23, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S149", - "BSmokeClearBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S150", - "BSmokeS45Mitch_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S151", - "BSparklerBVIHFR_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S152", - "BSquareAndTimelapseHarmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S153", - "BSquareS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S154", - "BSquareS2IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S155", - "BStreetArtVidevo_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S156", - "BStreetDancerS1IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S157", - "BStreetDancerS2IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S158", - "BStreetDancerS3IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S159", - "BStreetDancerS4IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S160", - "BStreetDancerS5IRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S161", - "BStreetsOfIndiaS1Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S162", - "BStreetsOfIndiaS2Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S163", - "BStreetsOfIndiaS3Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S164", - "BTaiChiHongKongS1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S165", - "BTaiChiHongKongS2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S166", - "BTaipeiCityRooftops8Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S167", - "BTaipeiCityRooftopsS1Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - 
"22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S168", - "BTaipeiCityRooftopsS2Videvo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S169", - "BTaksinBridgeVidevo_1920x1088_23fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 23, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S170", - "BTallBuildingsSJTU_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S171", - "BTennisMCLV_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S172", - "BToddlerFountain2Netflix_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S173", - "BTouristsSatOutsideVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S174", - "BToyCalendarHarmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S175", - "BTrackingDownHongKongSideVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S176", - "BTrackingPastRestaurantVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S177", - "BTrackingPastStallHongKongVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S178", - "BTraditionalIndonesianKecakVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S179", - "BTrafficandBuildingSJTU_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S180", - "BTrafficFlowSJTU_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S181", - "BTrafficonTasksinBridgeVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S182", - "BTreeWillsBVITexture_1920x1088_120fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 120, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S183", - "BTruckIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S184", - "BTunnelFlagS1Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S185", - "BUnloadingVegetablesVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S186", - "BVegetableMarketS1LIVENetFlix_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S187", - "BVegetableMarketS2LIVENetFlix_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S188", - 
"BVegetableMarketS3LIVENetFlix_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S189", - "BVegetableMarketS4LIVENetFlix_1920x1088_30fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 30, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S190", - "BVeniceS1Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S191", - "BVeniceS2Harmonics_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S192", - "BVeniceSceneIRIS_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S193", - "BWalkingDownKhaoStreetVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S194", - "BWalkingDownNorthRodeoVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S195", - "BWalkingThroughFootbridgeVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S196", - "BWaterS65Mitch_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S197", - "BWaterS81Mitch_1920x1088_24fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S198", - "BWatPhoTempleVidevo_1920x1088_50fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S199", - "BWoodSJTU_1920x1088_60fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 60, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "B_S200", - "BWovenVidevo_1920x1088_25fps_10bit_420.yuv", - 1920, - 1088, - 0, - 64, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], -] - - -TemporalSubsampleRatio = 1 - -for seq_data in SequenceTable: - ( - seq_key, - seq_file_name, - width, - height, - StartFrame, - FramesToBeEncoded, - FrameRate, - InputBitDepth, - QPs, - RateNames, - level, - ) = seq_data - - QPs = QPs.split(" ") - RateNames = RateNames.split(" ") - QPs_RateNames = [] - for QP, RateName in zip(QPs, RateNames): - commonFileName = ( - "T2RA_" - + seq_key - + "_" - + RateName - + "_qp" - + QP - + "_s" - + str(StartFrame) - + "_f" - + str(FramesToBeEncoded) - + "_t" - + str(TemporalSubsampleRatio) - ) - binFile = "Bin_" + commonFileName + ".bin" - print(binFile) diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/tvd_codec_info.py b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/tvd_codec_info.py deleted file mode 100644 index ebdf574090ee4186041fde8c3c0012c98a108c4d..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/tvd_codec_info.py +++ /dev/null @@ -1,1003 +0,0 @@ -SequenceTable = [ - [ - "A_S01", - "Bamboo_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S02", - "BlackBird_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - 
"A_S03", - "BoyDressing1_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S04", - "BoyDressing2_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S05", - "BoyMakingUp1_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S06", - "BoyMakingUp2_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S07", - "BoyWithCostume_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S08", - "BuildingTouristAttraction1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S09", - "BuildingTouristAttraction2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S10", - "BuildingTouristAttraction3_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S11", - "CableCar_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S12", - "ChefCooking1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S13", - "ChefCooking2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S14", - "ChefCooking3_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S15", - "ChefCooking4_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S16", - "ChefCooking5_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S17", - "ChefCuttingUp1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S18", - "ChefCuttingUp2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S19", - "DryRedPepper_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S20", - "FilmMachine_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S21", - "FlowingWater_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S22", - "Fountain_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S23", - "GirlRunningOnGrass_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S24", - "GirlWithTeaSet1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S25", 
- "GirlWithTeaSet2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S26", - "GirlWithTeaSet3_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S27", - "GirlsOnGrass1_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S28", - "GirlsOnGrass2_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S29", - "HotPot_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S30", - "HotelClerks_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S31", - "LyingDog_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S32", - "ManWithFilmMachine_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S33", - "MountainsAndStairs1_3840x2160_24fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 24, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S34", - "MountainsAndStairs2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S35", - "MountainsAndStairs3_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S36", - "MountainsAndStairs4_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S37", - "MountainsView1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S38", - "MountainsView2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S39", - "MountainsView3_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S40", - "MountainsView4_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S41", - "MovingBikesAndPedestrian4_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S42", - "OilPainting1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S43", - "OilPainting2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S44", - "PeopleNearDesk_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S45", - "PeopleOnGrass_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S46", - "Plaque_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S47", - 
"PressureCooker_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S48", - "RawDucks_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S49", - "RedBush_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S50", - "RedRibbonsWithLocks_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S51", - "RestaurantWaitress1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S52", - "RestaurantWaitress2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S53", - "RiverAndTrees_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S54", - "RoastedDuck_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S55", - "RoomTouristAttraction1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S56", - "RoomTouristAttraction2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S57", - "RoomTouristAttraction3_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S58", - "RoomTouristAttraction4_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S59", - "RoomTouristAttraction5_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S60", - "RoomTouristAttraction6_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S61", - "RoomTouristAttraction7_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S62", - "StampCarving1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S63", - "StampCarving2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S64", - "StaticRocks_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S65", - "StaticWaterAndBikes2_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S66", - "SunAndTrees_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S67", - "SunriseMountainHuang_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S68", - "SunsetMountainHuang1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - 
"R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S69", - "SunsetMountainHuang2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S70", - "TreesAndLeaves_3840x2160_50fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 50, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S71", - "TreesOnMountains1_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S72", - "TreesOnMountains2_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S73", - "TreesOnMountains3_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], - [ - "A_S74", - "Weave_3840x2160_25fps_10bit_420.yuv", - 3840, - 2160, - 0, - 65, - 25, - 10, - "22 27 32 37 42", - "R22 R27 R32 R37 R42", - 4.1, - ], -] - - -TemporalSubsampleRatio = 1 - -for seq_data in SequenceTable: - ( - seq_key, - seq_file_name, - width, - height, - StartFrame, - FramesToBeEncoded, - FrameRate, - InputBitDepth, - QPs, - RateNames, - level, - ) = seq_data - - QPs = QPs.split(" ") - RateNames = RateNames.split(" ") - QPs_RateNames = [] - for QP, RateName in zip(QPs, RateNames): - commonFileName = ( - "T2RA_" - + seq_key - + "_" - + RateName - + "_qp" - + QP - + "_s" - + str(StartFrame) - + "_f" - + str(FramesToBeEncoded) - + "_t" - + str(TemporalSubsampleRatio) - ) - binFile = "Bin_" + commonFileName + ".bin" - print(binFile) diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/3_ReadMe.md b/training/training_scripts/NN_Super_Resolution/3_train_tasks/3_ReadMe.md deleted file mode 100644 index 927be4d440e8aac76987e33eaef441d369960b9f..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/3_ReadMe.md +++ /dev/null @@ -1,72 +0,0 @@ -## training process -For NNSR, a total of three networks including two luma networks and one chroma network need to be trained. -### How to monitor the training process - -Tensorboard is used here to monitor the training process. -1. Run the following command in the path 'Experiments' generated by the training script -``` -tensorboard --logdir=Tensorboard --port=6001 -``` -2. Access http://localhost:6001/ to view the result - - -### The training stage - -Launch the training stage by the following command: -``` -sh train.sh -``` -The following lines to set the dataset paths may need to be revised according to your local side. 
- -| path | row | dataset | format | -| :-------------------------- | :----- | :----------------------------- | :-------------- | - -| ./training_scripts/Luma-I | line 29 | The compression data of BVI-DVC | frame level YUV | -| ./training_scripts/Luma-I | line 30 | The raw data of BVI-DVC | frame level YUV | -| ./training_scripts/Luma-I | line 33 | The compression data of TVD | frame level YUV | -| ./training_scripts/Luma-I | line 34 | The raw data of TVD | frame level YUV | - -| ./training_scripts/Luma-B | line 29 | The compression data of BVI-DVC | frame level YUV | -| ./training_scripts/Luma-B | line 30 | The raw data of BVI-DVC | frame level YUV | -| ./training_scripts/Luma-B | line 33 | The compression data of TVD | frame level YUV | -| ./training_scripts/Luma-B | line 34 | The raw data of TVD | frame level YUV | - -| ./training_scripts/Chroma-IB | line 29 | The compression data of BVI-DVC | frame level YUV | -| ./training_scripts/Chroma-IB | line 30 | The raw data of BVI-DVC | frame level YUV | -| ./training_scripts/Chroma-IB | line 33 | The compression data of TVD | frame level YUV | -| ./training_scripts/Chroma-IB | line 34 | The raw data of TVD | frame level YUV | -| ./training_scripts/Chroma-IB | line 37 | The compression data of BVI-DVC | frame level YUV | -| ./training_scripts/Chroma-IB | line 38 | The raw data of BVI-DVC | frame level YUV | -| ./training_scripts/Chroma-IB | line 41 | The compression data of TVD | frame level YUV | -| ./training_scripts/Chroma-IB | line 42 | The raw data of TVD | frame level YUV | - -The convergence curve for different models on the validation set is shown below. - - - - - -The convergence curve above is enlarged to select the optimal model as follows. The model selection is decided based on the PSNR improvement of Y, Cb and Cr. -Finally, the selected optimal model and its training epoch are shown as follows in the training stage. - -| network | epoch | -| :----- | :------------ | -| Luma-I | model_0820.pth | -| Luma-B | model_0925.pth | -| Chroma-IB | model_0315.pth | - -### Convert to Libtorch model -1. Select the optimal model of the training stage as the final model. -2. Use the following command to get the converted models (Luma-I.pt Luma-B.pt Chroma-IB.pt). -``` -sh conversion.sh -``` -Please note that to successfully generate different models, the final model path (in model_conversion.py line 9), coversion code (in model_conversion.py line 21/24/27) and generated model name (in model_conversion.py line 29) should be consistent. - -### Other -Noted that the models in this training scripts are generated by **pytorch-1.9.0**. -The corresponding Libtorch or Pytorch version should be used when loading the model. Otherwise, it is likely to get an error. - -Empirically, it is best to perform training until the PSNR results on the validation set begin to decline. So for the provided training scripts, a fixed and relative large 2000 training epochs is set. -Because of the randomness in the training process, the number of training epochs provided in the training stage is only for the reference. -It is suggested to run the task until the number of training epochs is greater than the provided one (Luma-I: 0820, Luma-B: 0925, Chroma-IB:0315) directly, and then find the optimal result around. 
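As a quick sanity check after conversion, a traced model can be loaded back and run on dummy inputs. The following is a minimal sketch and not part of the original scripts: it assumes the same pytorch-1.9.0 environment noted above, and reuses the Chroma-IB tracing shapes and the forward() argument names from model_conversion.py and nn_model.py (both deleted below).

```
# Minimal loading sketch (assumes the pytorch-1.9.0 environment noted above).
import torch

# A mismatched Torch/Libtorch version is the most common cause of loading errors.
assert torch.__version__.startswith("1.9"), "load with the version used for tracing"

model = torch.jit.load("Chroma-IB.pt")
model.eval()

# Dummy inputs with the shapes used for tracing in model_conversion.py.
y_rec = torch.ones(1, 1, 144, 144)   # reconstructed luma patch
uv_rec = torch.ones(1, 2, 72, 72)    # reconstructed chroma, half resolution
uv_rpr = torch.ones(1, 2, 144, 144)  # upsampled chroma used as the residual base
slice_qp = torch.ones(1, 1, 72, 72)
base_qp = torch.ones(1, 1, 72, 72)
slice_type = torch.ones(1, 1, 72, 72)

with torch.no_grad():
    out = model(y_rec, uv_rec, uv_rpr, slice_qp, base_qp, slice_type)

print(out.shape)  # torch.Size([1, 2, 144, 144]): the refined full-resolution chroma pair
```

For Luma-I and Luma-B, the commented-out input lists in model_conversion.py give the corresponding shapes.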
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/conversion.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/conversion.sh deleted file mode 100644 index 36bb9a8fae84578d8d5af9dd61cd3cec6ef71dea..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/conversion.sh +++ /dev/null @@ -1,32 +0,0 @@ -# The copyright in this software is being made available under the BSD -# License, included below. This software may be subject to other third party -# and contributor rights, including patent rights, and no such rights are -# granted under this license. -# -# Copyright (c) 2010-2022, ITU/ISO/IEC -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -# be used to endorse or promote products derived from this software without -# specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -# THE POSSIBILITY OF SUCH DAMAGE. -python model_conversion.py \ No newline at end of file diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/model_conversion.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/model_conversion.py deleted file mode 100644 index 8202b81a56f58b9ab4a83e8745622a083fc37b03..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/model_conversion.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. 
-* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -from nn_model import Net -import torch - -if __name__ == "__main__": - model = Net() - model_name = "model_0315.pth" - model.load_state_dict( - torch.load(model_name, map_location=lambda storage, loc: storage)["network"] - ) - model.eval() - - example1 = torch.ones(1, 3, 144, 144) - example2 = torch.ones(1, 1, 144, 144) - example3 = torch.ones(1, 1, 72, 72) - example4 = torch.ones(1, 2, 72, 72) - example5 = torch.ones(1, 2, 144, 144) - example6 = torch.ones(1, 3, 72, 72) - - # Luma-I - # traced_script_module = torch.jit.trace(model, [example3, example3, example2, example3]) - - # Luma-B - # traced_script_module = torch.jit.trace(model, [example3, example3, example2, example3, example3]) - - # Chroma-IB - traced_script_module = torch.jit.trace( - model, [example2, example4, example5, example3, example3, example3] - ) - - traced_script_module.save("Chroma-IB.pt") diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/Utils.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/Utils.py deleted file mode 100644 index c13615186e3f42e15ade25e44a6191b3ef840103..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/Utils.py +++ /dev/null @@ -1,345 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. 
-* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import argparse -import logging -import random -from pathlib import Path -import numpy as np -import PIL.Image as Image -import os - - -import torch -from torch.utils.tensorboard import SummaryWriter - - -def parse_args(): - parser = argparse.ArgumentParser() - - path_cur = Path(os.path.split(os.path.realpath(__file__))[0]) - path_save = path_cur.joinpath("Experiments") - - # for loading data - parser.add_argument("--ext", type=str, default="yuv", help="data file extension") - - parser.add_argument( - "--data_range_bvi_AI", - type=str, - default="1-180/181-200", - help="train/test data range", - ) - parser.add_argument( - "--dir_data_bvi_AI", - type=str, - default="/path/EE1_2_2_train/AI_BVI_DVC", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori_bvi_AI", - type=str, - default="/path/EE1_2_2_train_ori/BVI_DVC", - help="raw dataset directory", - ) - - parser.add_argument( - "--data_range_tvd_AI", - type=str, - default="1-66/67-74", - help="train/test data range", - ) - parser.add_argument( - "--dir_data_tvd_AI", - type=str, - default="/path/EE1_2_2_train/AI_TVD", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori_tvd_AI", - type=str, - default="/path/EE1_2_2_train_ori/TVD", - help="raw dataset directory", - ) - - parser.add_argument( - "--data_range_bvi_RA", - type=str, - default="1-180/181-200", - help="train/test data range", - ) - parser.add_argument( - "--dir_data_bvi_RA", - type=str, - default="/path/EE1_2_2_train/RA_BVI_DVC", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori_bvi_RA", - type=str, - default="/path/EE1_2_2_train_ori/BVI_DVC", - help="raw dataset directory", - ) - - parser.add_argument( - "--data_range_tvd_RA", - type=str, - default="1-66/67-74", - help="train/test data range", - ) - parser.add_argument( - "--dir_data_tvd_RA", - type=str, - default="/path/EE1_2_2_train/RA_TVD", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori_tvd_RA", - type=str, - default="/path/EE1_2_2_train_ori/TVD", - help="raw dataset directory", - ) - - # for loading model - parser.add_argument("--checkpoints", type=str, help="checkpoints file path") - parser.add_argument("--pretrained", type=str, help="pretrained model path") - - # batch size - parser.add_argument( - "--batch_size", type=int, default=64, help="batch size for Fusion stage" - ) - # do validation - parser.add_argument( - "--test_every", type=int, default=1200, help="do test per every N batches" - ) - - # learning rate - 
parser.add_argument(
-        "--lr", type=float, default=1e-4, help="learning rate for Fusion stage"
-    )
-
-    parser.add_argument(
-        "--gpu", action="store_true", default=True, help="use gpu or cpu"
-    )
-
-    # epoch
-    parser.add_argument(
-        "--max_epoch", type=int, default=2000, help="max training epochs"
-    )
-
-    # patch_size
-    parser.add_argument(
-        "--patch_size", type=int, default=256, help="train/val patch size"
-    )
-    parser.add_argument("--shave", type=int, default=8, help="train/shave")
-
-    # for recording
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        default=True,
-        help="use tensorboard and logger",
-    )
-    parser.add_argument(
-        "--save_dir", type=str, default=path_save, help="directory for recording"
-    )
-    parser.add_argument(
-        "--eval_epochs", type=int, default=5, help="save model after epochs"
-    )
-
-    args = parser.parse_args()
-    return args
-
-
-def init():
-    # parse arguments
-    args = parse_args()
-
-    # create directory for recording
-    experiment_dir = Path(args.save_dir)
-    experiment_dir.mkdir(exist_ok=True)
-
-    ckpt_dir = experiment_dir.joinpath("Checkpoints/")
-    ckpt_dir.mkdir(exist_ok=True)
-    print(r"===========Save checkpoints to {0}===========".format(str(ckpt_dir)))
-
-    if args.verbose:
-        # initialize logger
-        log_dir = experiment_dir.joinpath("Log/")
-        log_dir.mkdir(exist_ok=True)
-        logger = logging.getLogger()
-        logger.setLevel(logging.INFO)
-        formatter = logging.Formatter(
-            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-        )
-        file_handler = logging.FileHandler(str(log_dir) + "/Log.txt")
-        file_handler.setLevel(logging.INFO)
-        file_handler.setFormatter(formatter)
-        logger.addHandler(file_handler)
-        logger.info("PARAMETER ...")
-        logger.info(args)
-        # initialize tensorboard
-        tb_dir_all = experiment_dir.joinpath("Tensorboard_all/")
-        tb_dir_all.mkdir(exist_ok=True)
-        tensorboard_all = SummaryWriter(log_dir=str(tb_dir_all), flush_secs=30)
-
-        tb_dir = experiment_dir.joinpath("Tensorboard/")
-        tb_dir.mkdir(exist_ok=True)
-        tensorboard = SummaryWriter(log_dir=str(tb_dir), flush_secs=30)
-        print(
-            r"===========Save tensorboard and logger to {0}===========".format(
-                str(tb_dir_all)
-            )
-        )
-    else:
-        print(
-            r"===========Disable tensorboard and logger to accelerate training==========="
-        )
-        logger = None
-        tensorboard_all = None
-        tensorboard = None
-
-    return args, logger, ckpt_dir, tensorboard_all, tensorboard
-
-
-def yuv_read(yuv_path, h, w, iy, ix, ip):
-    h_c = h // 2
-    w_c = w // 2
-
-    ip_c = ip // 2
-    iy_c = iy // 2
-    ix_c = ix // 2
-
-    fp = open(yuv_path, "rb")
-
-    # y
-    fp.seek(iy * w * 2, 0)
-    patch_y = np.fromfile(fp, np.uint16, ip * w).reshape(ip, w, 1)
-    patch_y = patch_y[:, ix : ix + ip, :]
-
-    # u
-    fp.seek((w * h + iy_c * w_c) * 2, 0)
-    patch_u = np.fromfile(fp, np.uint16, ip_c * w_c).reshape(ip_c, w_c, 1)
-    patch_u = patch_u[:, ix_c : ix_c + ip_c, :]
-
-    # v
-    fp.seek((w * h + w_c * h_c + iy_c * w_c) * 2, 0)
-    patch_v = np.fromfile(fp, np.uint16, ip_c * w_c).reshape(ip_c, w_c, 1)
-    patch_v = patch_v[:, ix_c : ix_c + ip_c, :]
-
-    fp.close()
-
-    return patch_y, patch_u, patch_v
-
-
-def upsample(img, height, width):
-    img = np.squeeze(img, axis=2)
-    img = np.array(
-        Image.fromarray(img.astype(np.float32)).resize((width, height), Image.NEAREST)
-    )
-    img = np.expand_dims(img, axis=2)
-    return img
-
-
-def patch_process(yuv_path, h, w, iy, ix, ip):
-    y, u, v = yuv_read(yuv_path, h, w, iy, ix, ip)
-    # u_up = upsample(u, ip, ip)
-    # v_up = upsample(v, ip, ip)
-    # yuv = np.concatenate((y, u_up, v_up), axis=2)
-    return y, u, v
-
-
-def 
get_patch( - image_yuv_path, image_yuv_rpr_path, image_yuv_org_path, w, h, patch_size, shave -): - ih = h - iw = w - - ip = patch_size - ih -= ih % ip - iw -= iw % ip - iy = random.randrange(ip, ih - ip, ip) - shave - ix = random.randrange(ip, iw - ip, ip) - shave - - # - patch_rec_y, patch_rec_u, patch_rec_v = patch_process( - image_yuv_path, h // 2, w // 2, iy // 2, ix // 2, (ip + 2 * shave) // 2 - ) - _, patch_rpr_u, patch_rpr_v = patch_process( - image_yuv_rpr_path, h, w, iy, ix, ip + 2 * shave - ) - _, patch_org_u, patch_org_v = patch_process( - image_yuv_org_path, h, w, iy, ix, ip + 2 * shave - ) - - patch_in = np.concatenate((patch_rec_u, patch_rec_v), axis=2) - patch_rpr = np.concatenate((patch_rpr_u, patch_rpr_v), axis=2) - patch_org = np.concatenate((patch_org_u, patch_org_v), axis=2) - - ret = [patch_rec_y, patch_in, patch_rpr, patch_org] - - return ret - - -def augment(*args): - x = random.random() - hflip = x < 0.2 - vflip = x >= 0.2 and x < 0.4 - rot90 = x >= 0.4 and x < 0.6 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - - return img - - return [_augment(a) for a in args] - - -def np2Tensor(*args): - def _np2Tensor(img): - np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) - tensor = torch.from_numpy(np_transpose.astype(np.int32)).float() / 1023.0 - - return tensor - - return [_np2Tensor(a) for a in args] - - -def cal_psnr(distortion: torch.Tensor): - psnr = -10 * torch.log10(distortion) - return psnr diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/nn_model.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/nn_model.py deleted file mode 100644 index a921005ad3edc1a00eefa5e7d4f2fc824a129acc..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/nn_model.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. 
-* IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-* THE POSSIBILITY OF SUCH DAMAGE.
-"""
-
-import torch
-import torch.nn as nn
-
-
-class Net(nn.Module):
-    def __init__(self):
-        super(Net, self).__init__()
-        # hyper-params
-        n_resblocks = 24
-        n_feats_k = 64
-        n_feats_m = 192
-
-        # define head module
-        self.head_rec_y = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1,
-                out_channels=n_feats_k,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-            ),  # stride = 2: downsample luma to the chroma resolution
-            nn.PReLU(),
-        )
-        self.head_rec = nn.Sequential(
-            nn.Conv2d(
-                in_channels=2,
-                out_channels=n_feats_k,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-            ),  # stride = 1: chroma branch keeps its resolution
-            nn.PReLU(),
-        )
-
-        # define fuse module
-        self.fuse = nn.Sequential(
-            nn.Conv2d(
-                in_channels=n_feats_k * 2 + 3,
-                out_channels=n_feats_k,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-            ),
-            nn.PReLU(),
-        )
-
-        # define body module
-        body = []
-        for _ in range(n_resblocks):
-            body.append(DscBlock(n_feats_k, n_feats_m))
-
-        self.body = nn.Sequential(*body)
-
-        # define tail module
-        self.tail = nn.Sequential(
-            nn.Conv2d(
-                in_channels=n_feats_k, out_channels=4 * 2, kernel_size=3, padding=1
-            ),
-            nn.PixelShuffle(2),  # feature_map:(B, 2x2x2, N, N) -> (B, 2, 2N, 2N)
-        )
-
-    def forward(self, y_rec, uv_rec, uv_rpr, slice_qp, base_qp, slice_type):
-        in_0 = self.head_rec_y(y_rec)
-        in_1 = self.head_rec(uv_rec)
-
-        x = self.fuse(torch.cat((in_0, in_1, slice_qp, base_qp, slice_type), 1))
-        x = self.body(x)
-        x = self.tail(x)
-        x[:, 0:1, :, :] += uv_rpr[:, 0:1, :, :]
-        x[:, 1:2, :, :] += uv_rpr[:, 1:2, :, :]
-
-        return x
-
-
-class DscBlock(nn.Module):
-    def __init__(self, n_feats_k, n_feats_m, expansion=1):
-        super(DscBlock, self).__init__()
-        self.expansion = expansion
-        self.c1 = nn.Conv2d(
-            in_channels=n_feats_k, out_channels=n_feats_m, kernel_size=1, padding=0
-        )
-        self.prelu = nn.PReLU()
-        self.c2 = nn.Conv2d(
-            in_channels=n_feats_m, out_channels=n_feats_k, kernel_size=1, padding=0
-        )
-        self.c3 = nn.Conv2d(
-            in_channels=n_feats_k, out_channels=n_feats_k, kernel_size=3, padding=1
-        )
-
-    def forward(self, x):
-        i = x
-        x = self.c2(self.prelu(self.c1(x)))
-        x = self.c3(x)
-        x += i
-
-        return x
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train.sh
deleted file mode 100644
index 7b026d766aefdfa8a5c92a46d1750661bc344f5a..0000000000000000000000000000000000000000
--- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-# The copyright in this software is being made available under the BSD
-# License, included below. This software may be subject to other third party
-# and contributor rights, including patent rights, and no such rights are
-# granted under this license.
-#
-# Copyright (c) 2010-2022, ITU/ISO/IEC
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice,
-# this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
-# be used to endorse or promote products derived from this software without
-# specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-# THE POSSIBILITY OF SUCH DAMAGE.
-python train_YUV.py
\ No newline at end of file
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train_YUV.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train_YUV.py
deleted file mode 100644
index 5340aec0c2823bff5561a03248d5985bc880e65c..0000000000000000000000000000000000000000
--- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train_YUV.py
+++ /dev/null
@@ -1,452 +0,0 @@
-"""
-/* The copyright in this software is being made available under the BSD
-* License, included below. This software may be subject to other third party
-* and contributor rights, including patent rights, and no such rights are
-* granted under this license.
-*
-* Copyright (c) 2010-2022, ITU/ISO/IEC
-* All rights reserved.
-*
-* Redistribution and use in source and binary forms, with or without
-* modification, are permitted provided that the following conditions are met:
-*
-* * Redistributions of source code must retain the above copyright notice,
-* this list of conditions and the following disclaimer.
-* * Redistributions in binary form must reproduce the above copyright notice,
-* this list of conditions and the following disclaimer in the documentation
-* and/or other materials provided with the distribution.
-* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
-* be used to endorse or promote products derived from this software without
-* specific prior written permission.
-*
-* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-* ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import torch -import torch.nn as nn -from torch.optim.adam import Adam -from torch.optim.lr_scheduler import MultiStepLR -from torch.utils.data.dataloader import DataLoader -import datetime -import os -import glob - -from yuv10bdata import YUV10bData -from Utils import init, cal_psnr -from nn_model import Net - -torch.backends.cudnn.enabled = True -torch.backends.cudnn.benchmark = True - - -class Trainer: - def __init__(self): - ( - self.args, - self.logger, - self.checkpoints_dir, - self.tensorboard_all, - self.tensorboard, - ) = init() - - self.net = Net().to("cuda" if self.args.gpu else "cpu") - - self.L1loss = nn.L1Loss().to("cuda" if self.args.gpu else "cpu") - self.L2loss = nn.MSELoss().to("cuda" if self.args.gpu else "cpu") - - self.optimizer = Adam(self.net.parameters(), lr=self.args.lr) - self.scheduler = MultiStepLR( - optimizer=self.optimizer, milestones=[4001, 4002], gamma=0.5 - ) - - print("============>loading data") - self.train_dataset = YUV10bData(self.args, train=True) - self.eval_dataset = YUV10bData(self.args, train=False) - - self.train_dataloader = DataLoader( - dataset=self.train_dataset, - batch_size=self.args.batch_size, - shuffle=True, - num_workers=12, - pin_memory=False, - ) - self.eval_dataloader = DataLoader( - dataset=self.eval_dataset, - batch_size=self.args.batch_size, - shuffle=True, - num_workers=12, - pin_memory=False, - ) - - self.train_steps = self.eval_steps = 0 - - def train(self): - start_epoch = self.load_checkpoints() - print("============>start training") - for epoch in range(start_epoch, self.args.max_epoch): - print("Epoch {}/{}".format(epoch, self.args.max_epoch)) - self.logger.info("Epoch {}/{}".format(epoch, self.args.max_epoch)) - self.train_one_epoch() - self.scheduler.step() - if (epoch + 1) % self.args.eval_epochs == 0: - self.eval(epoch=epoch) - self.save_ckpt(epoch=epoch) - - def train_one_epoch(self): - self.net.train() - for _, tensor in enumerate(self.train_dataloader): - img_lr, img_hr, filename = tensor - - img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") - img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") - - uv_rec = img_lr[:, 0:2, :, :] - slice_qp = img_lr[:, 2:3, :, :] - base_qp = img_lr[:, 3:4, :, :] - slice_type = img_lr[:, 4:5, :, :] - y_rec = img_hr[:, 0:1, :, :] - img_rpr = img_hr[:, 1:3, :, :] - img_ori = img_hr[:, 3:5, :, :] - - img_out = self.net(y_rec, uv_rec, img_rpr, slice_qp, base_qp, slice_type) - - # calculate distortion - shave = self.args.shave // 2 - - # L1_loss_pred_Y = self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - L1_loss_pred_Cb = self.L1loss( - img_out[:, 0:1, shave:-shave, shave:-shave], - img_ori[:, 0:1, shave:-shave, shave:-shave], - ) - L1_loss_pred_Cr = self.L1loss( - img_out[:, 1:2, shave:-shave, shave:-shave], - img_ori[:, 1:2, shave:-shave, shave:-shave], - ) - - # loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - loss_pred_Cb = self.L2loss( - 
img_out[:, 0:1, shave:-shave, shave:-shave], - img_ori[:, 0:1, shave:-shave, shave:-shave], - ) - loss_pred_Cr = self.L2loss( - img_out[:, 1:2, shave:-shave, shave:-shave], - img_ori[:, 1:2, shave:-shave, shave:-shave], - ) - - # loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr - loss_pred = L1_loss_pred_Cb + L1_loss_pred_Cr - - # loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - # loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - # visualization - self.train_steps += 1 - if self.train_steps % 20 == 0: - # psnr_pred_Y = cal_psnr(loss_pred_Y) - psnr_pred_Cb = cal_psnr(loss_pred_Cb) - psnr_pred_Cr = cal_psnr(loss_pred_Cr) - - # psnr_input_Y = cal_psnr(loss_rec_Y) - # psnr_input_Cb = cal_psnr(loss_rec_Cb) - # psnr_input_Cr = cal_psnr(loss_rec_Cr) - - time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") - - print( - "[{}/{}]\tCb:{:.8f}\tCr:{:.8f}\tPSNR_Cb:{:.8f}\tPSNR_Cr:{:.8f}------{}".format( - (self.train_steps % len(self.train_dataloader)), - len(self.train_dataloader), - loss_pred_Cb, - loss_pred_Cr, - psnr_pred_Cb, - psnr_pred_Cr, - time, - ) - ) - self.logger.info( - "[{}/{}]\tCb:{:.8f}\tCr:{:.8f}\tPSNR_Cb:{:.8f}\tPSNR_Cr:{:.8f}".format( - (self.train_steps % len(self.train_dataloader)), - len(self.train_dataloader), - loss_pred_Cb, - loss_pred_Cr, - psnr_pred_Cb, - psnr_pred_Cr, - ) - ) - - # print("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), - # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y, time)) - # self.logger.info("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), - # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y)) - - self.tensorboard_all.add_scalars( - main_tag="Train/PSNR", - tag_scalar_dict={"pred_Cb": psnr_pred_Cb.data}, - global_step=self.train_steps, - ) - self.tensorboard_all.add_scalars( - main_tag="Train/PSNR", - tag_scalar_dict={"pred_Cr": psnr_pred_Cr.data}, - global_step=self.train_steps, - ) - self.tensorboard_all.add_image( - "rec", - uv_rec[0:1, 0:1, :, :].squeeze(dim=0), - global_step=self.train_steps, - ) - # self.tensorboard_all.add_image("pre", pre[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - self.tensorboard_all.add_image( - "rpr", - img_rpr[0:1, 0:1, :, :].squeeze(dim=0), - global_step=self.train_steps, - ) - self.tensorboard_all.add_image( - "out", - img_out[0:1, 0:1, :, :].squeeze(dim=0), - global_step=self.train_steps, - ) - self.tensorboard_all.add_image( - "ori", - img_ori[0:1, 0:1, :, :].squeeze(dim=0), - global_step=self.train_steps, - ) - - # self.tensorboard_all.add_scalars(main_tag="Train/PSNR", - # tag_scalar_dict={"input_Cb": psnr_input_Cb.data, - # "pred_Cb": psnr_pred_Cb.data}, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalars(main_tag="Train/PSNR", - # tag_scalar_dict={"input_Cr": psnr_input_Cr.data, - # "pred_Cr": psnr_pred_Cr.data}, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Y", - # scalar_value = psnr_pred_Y - psnr_input_Y, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cb", - # scalar_value = psnr_pred_Cb - psnr_input_Cb, - # 
global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cr", - # scalar_value = psnr_pred_Cr - psnr_input_Cr, - # global_step=self.train_steps) - - self.tensorboard_all.add_scalar( - tag="Train/train_loss_pred", - scalar_value=loss_pred, - global_step=self.train_steps, - ) - - # backward - self.optimizer.zero_grad() - loss_pred.backward() - self.optimizer.step() - - @torch.no_grad() - def eval(self, epoch: int): - print("============>start evaluating") - eval_cnt = 0 - # ave_psnr_Y = 0.000 - ave_psnr_Cb = 0.000 - ave_psnr_Cr = 0.000 - self.net.eval() - for _, tensor in enumerate(self.eval_dataloader): - img_lr, img_hr, filename = tensor - - img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") - img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") - - uv_rec = img_lr[:, 0:2, :, :] - slice_qp = img_lr[:, 2:3, :, :] - base_qp = img_lr[:, 3:4, :, :] - slice_type = img_lr[:, 4:5, :, :] - y_rec = img_hr[:, 0:1, :, :] - img_rpr = img_hr[:, 1:3, :, :] - img_ori = img_hr[:, 3:5, :, :] - - img_out = self.net(y_rec, uv_rec, img_rpr, slice_qp, base_qp, slice_type) - - # calculate distortion - shave = self.args.shave // 2 - - # L1_loss_pred_Y = self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - L1_loss_pred_Cb = self.L1loss( - img_out[:, 0:1, shave:-shave, shave:-shave], - img_ori[:, 0:1, shave:-shave, shave:-shave], - ) - L1_loss_pred_Cr = self.L1loss( - img_out[:, 1:2, shave:-shave, shave:-shave], - img_ori[:, 1:2, shave:-shave, shave:-shave], - ) - - # loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - loss_pred_Cb = self.L2loss( - img_out[:, 0:1, shave:-shave, shave:-shave], - img_ori[:, 0:1, shave:-shave, shave:-shave], - ) - loss_pred_Cr = self.L2loss( - img_out[:, 1:2, shave:-shave, shave:-shave], - img_ori[:, 1:2, shave:-shave, shave:-shave], - ) - - # loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr - loss_pred = L1_loss_pred_Cb + L1_loss_pred_Cr - - # loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - # loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - # psnr_pred_Y = cal_psnr(loss_pred_Y) - psnr_pred_Cb = cal_psnr(loss_pred_Cb) - psnr_pred_Cr = cal_psnr(loss_pred_Cr) - - # psnr_input_Y = cal_psnr(loss_rec_Y) - # psnr_input_Cb = cal_psnr(loss_rec_Cb) - # psnr_input_Cr = cal_psnr(loss_rec_Cr) - - # ave_psnr_Y += psnr_pred_Y - ave_psnr_Cb += psnr_pred_Cb - ave_psnr_Cr += psnr_pred_Cr - - eval_cnt += 1 - # visualization - self.eval_steps += 1 - if self.eval_steps % 2 == 0: - # self.tensorboard_all.add_scalar(tag="Eval/PSNR_Cb", - # scalar_value = psnr_pred_Cb, - # global_step=self.eval_steps) - - self.tensorboard_all.add_scalar( - tag="Eval/PSNR_Cb", - scalar_value=psnr_pred_Cb, - global_step=self.eval_steps, - ) - - self.tensorboard_all.add_scalar( - tag="Eval/PSNR_Cr", - scalar_value=psnr_pred_Cr, - global_step=self.eval_steps, - ) - - self.tensorboard_all.add_scalar( - tag="Eval/eval_loss_pred", - scalar_value=loss_pred, - global_step=self.eval_steps, - ) - - time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") - # print("PSNR_Y:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, time)) - # self.logger.info("PSNR_Y:{:.3f}".format(ave_psnr_Y / eval_cnt)) - - print( - "delta_Cb:{:.3f}\tdelta_Cr:{:.3f}------{}".format( - 
ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt, time - ) - ) - self.logger.info( - "delta_Cb:{:.3f}\tdelta_Cr:{:.3f}".format( - ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt - ) - ) - - # self.tensorboard.add_scalar(tag = "Eval/PSNR_Y_ave", - # scalar_value = ave_psnr_Y / eval_cnt, - # global_step = epoch + 1) - self.tensorboard.add_scalar( - tag="Eval/PSNR_Cb_ave", - scalar_value=ave_psnr_Cb / eval_cnt, - global_step=epoch + 1, - ) - self.tensorboard.add_scalar( - tag="Eval/PSNR_Cr_ave", - scalar_value=ave_psnr_Cr / eval_cnt, - global_step=epoch + 1, - ) - - def load_checkpoints(self): - if not self.args.checkpoints: - ckpt_list = sorted(glob.glob(os.path.join(self.checkpoints_dir, "*.pth"))) - num = len(ckpt_list) - if num > 1: - if os.path.getsize(ckpt_list[-1]) == os.path.getsize(ckpt_list[-2]): - self.args.checkpoints = ckpt_list[-1] - else: - self.args.checkpoints = ckpt_list[-2] - - if self.args.checkpoints: - print( - "===========Load checkpoints {0}===========".format( - self.args.checkpoints - ) - ) - self.logger.info("Load checkpoints {0}".format(self.args.checkpoints)) - ckpt = torch.load(self.args.checkpoints) - # load network weights - try: - self.net.load_state_dict(ckpt["network"]) - except Exception: - print("Can not find network weights") - # load optimizer params - try: - self.optimizer.load_state_dict(ckpt["optimizer"]) - self.scheduler.load_state_dict(ckpt["scheduler"]) - except Exception: - print("Can not find some optimizers params, just ignore") - start_epoch = ckpt["epoch"] + 1 - self.train_steps = ckpt["train_step"] + 1 - self.eval_steps = ckpt["eval_step"] + 1 - elif self.args.pretrained: - ckpt = torch.load(self.args.pretrained) - print( - "===========Load network weights {0}===========".format( - self.args.checkpoints - ) - ) - self.logger.info("Load network weights {0}".format(self.args.checkpoints)) - # load codec weights - try: - self.net.load_state_dict(ckpt["network"]) - except Exception: - print("Can not find network weights") - start_epoch = 0 - else: - print("===========Training from scratch===========") - self.logger.info("Training from scratch") - start_epoch = 0 - return start_epoch - - def save_ckpt(self, epoch: int): - checkpoint = { - "network": self.net.state_dict(), - "epoch": epoch, - "train_step": self.train_steps, - "eval_step": self.eval_steps, - "optimizer": self.optimizer.state_dict(), - "scheduler": self.scheduler.state_dict(), - } - - torch.save(checkpoint, "%s/model_%.4d.pth" % (self.checkpoints_dir, epoch + 1)) - self.logger.info("Save model..") - print( - "======================Saving model {0}======================".format( - str(epoch) - ) - ) - - -if __name__ == "__main__": - trainer = Trainer() - trainer.train() diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/yuv10bdata.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/yuv10bdata.py deleted file mode 100644 index 38a9f39f7d4e5ceb605261c949a58583c0ed0e2c..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/yuv10bdata.py +++ /dev/null @@ -1,339 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. 
-* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import os -import glob -from torch.utils.data import Dataset -import numpy as np -import Utils -import math - - -class YUV10bData(Dataset): - def __init__(self, args, name="YuvData", train=True): - super(YUV10bData, self).__init__() - self.args = args - self.split = "train" if train else "valid" - self.image_ext = args.ext - self.name = name - self.train = train - - data_range_bvi_AI = [r.split("-") for r in args.data_range_bvi_AI.split("/")] - data_range_tvd_AI = [r.split("-") for r in args.data_range_tvd_AI.split("/")] - data_range_bvi_RA = [r.split("-") for r in args.data_range_bvi_RA.split("/")] - data_range_tvd_RA = [r.split("-") for r in args.data_range_tvd_RA.split("/")] - - if train: - data_range_bvi_AI = data_range_bvi_AI[0] - data_range_tvd_AI = data_range_tvd_AI[0] - data_range_bvi_RA = data_range_bvi_RA[0] - data_range_tvd_RA = data_range_tvd_RA[0] - else: - data_range_bvi_AI = data_range_bvi_AI[1] - data_range_tvd_AI = data_range_tvd_AI[1] - data_range_bvi_RA = data_range_bvi_RA[1] - data_range_tvd_RA = data_range_tvd_RA[1] - - self.begin_bvi_AI, self.end_bvi_AI = list( # noqa: C417 - map(lambda x: int(x), data_range_bvi_AI) - ) - self.begin_tvd_AI, self.end_tvd_AI = list( # noqa: C417 - map(lambda x: int(x), data_range_tvd_AI) - ) - self.begin_bvi_RA, self.end_bvi_RA = list( # noqa: C417 - map(lambda x: int(x), data_range_bvi_RA) - ) - self.begin_tvd_RA, self.end_tvd_RA = list( # noqa: C417 - map(lambda x: int(x), data_range_tvd_RA) - ) - - self._set_data() - self._get_image_list() - - if train: - # n_patches = args.batch_size * args.test_every - n_images = len(self.images_yuv) - if n_images == 0: - self.repeat = 0 - else: - self.repeat = 1 - # self.repeat = max(n_patches // n_images, 1) - print(f"repeating dataset {self.repeat} for one epoch") - else: - # n_patches = args.batch_size * args.test_every // 25 - n_images = len(self.images_yuv) - if n_images == 0: - self.repeat = 0 - else: - self.repeat = 5 - print(f"repeating dataset {self.repeat} for one epoch") - - def _set_data(self): - self.dir_in_bvi_AI = 
os.path.join(self.args.dir_data_bvi_AI, "yuv") - self.dir_org_bvi_AI = os.path.join(self.args.dir_data_ori_bvi_AI) - - self.dir_in_tvd_AI = os.path.join(self.args.dir_data_tvd_AI, "yuv") - self.dir_org_tvd_AI = os.path.join(self.args.dir_data_ori_tvd_AI) - - self.dir_in_bvi_RA = os.path.join(self.args.dir_data_bvi_RA, "yuv") - self.dir_org_bvi_RA = os.path.join(self.args.dir_data_ori_bvi_RA) - - self.dir_in_tvd_RA = os.path.join(self.args.dir_data_tvd_RA, "yuv") - self.dir_org_tvd_RA = os.path.join(self.args.dir_data_ori_tvd_RA) - - def _scan_class(self, is_tvd, class_name, mode): - QPs = ["22", "27", "32", "37", "42"] - - if is_tvd: - if mode == "AI": - dir_in = self.dir_in_tvd_AI - dir_org = self.dir_org_tvd_AI - else: - dir_in = self.dir_in_tvd_RA - dir_org = self.dir_org_tvd_RA - else: - if mode == "AI": - dir_in = self.dir_in_bvi_AI - dir_org = self.dir_org_bvi_AI - else: - dir_in = self.dir_in_bvi_RA - dir_org = self.dir_org_bvi_RA - - list_temp = glob.glob( - os.path.join(dir_in, "*_" + class_name + "_*." + self.image_ext) - ) - file_rec_list = [] - for i in list_temp: - index = i.find("poc") - poc = int(i[index + 3 : index + 6]) - if poc % 3 == 0 and poc != 0: - file_rec_list.append(i) - - if is_tvd: - file_rec_list = sorted(file_rec_list) - else: - file_rec_list = sorted(file_rec_list, key=str.lower) - - list_temp = glob.glob( - os.path.join(dir_org, class_name + "*/*." + self.image_ext) - ) - # print(list_temp) - file_org_list = [] - for i in list_temp: - index = i.find("frame_") - poc = int(i[index + 6 : index + 9]) - if poc % 3 == 0 and poc != 0: - file_org_list.append(i) - - if is_tvd: - file_org_list = sorted(file_org_list) - else: - file_org_list = sorted(file_org_list, key=str.lower) - - frame_num = 62 - frame_num_sampled = math.ceil(frame_num / 3) - - if is_tvd: - if mode == "AI": - begin = self.begin_tvd_AI - end = self.end_tvd_AI - else: - begin = self.begin_tvd_RA - end = self.end_tvd_RA - else: - if mode == "AI": - begin = self.begin_bvi_AI - end = self.end_bvi_AI - else: - begin = self.begin_bvi_RA - end = self.end_bvi_RA - - class_names_yuv = [] - class_names_yuv_org = [] - - for qp in QPs: - file_list = file_rec_list[ - (begin - 1) * frame_num_sampled * 5 : end * frame_num_sampled * 5 - ] - for filename in file_list: - idx = filename.find("qp") - if int(filename[idx + 2 : idx + 4]) == int(qp): - class_names_yuv.append(filename) - - file_list = file_org_list[ - (begin - 1) * frame_num_sampled : end * frame_num_sampled - ] - for filename in file_list: - class_names_yuv_org.append(filename) - - return class_names_yuv, class_names_yuv_org - - def _scan(self): - bvi_class_set = ["A"] - - names_yuv = [] - names_yuv_org = [] - for class_name in bvi_class_set: - class_names_yuv, class_names_yuv_org = self._scan_class( - False, class_name, "AI" - ) - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - class_names_yuv, class_names_yuv_org = self._scan_class( - False, class_name, "RA" - ) - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - class_names_yuv, class_names_yuv_org = self._scan_class(True, "A", "AI") - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - class_names_yuv, class_names_yuv_org = self._scan_class(True, "A", "RA") - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - print(len(names_yuv)) - print(len(names_yuv_org)) - - return names_yuv, names_yuv_org - - def 
_get_image_list(self): - self.images_yuv, self.images_yuv_org = self._scan() - - def __getitem__(self, idx): - patch_in, patch_org, filename = self._load_file_get_patch(idx) - pair_t = Utils.np2Tensor(patch_in, patch_org) - - return pair_t[0], pair_t[1], filename - - def __len__(self): - if self.train: - return len(self.images_yuv) * self.repeat - else: - return len(self.images_yuv) * self.repeat - - def _get_index(self, idx): - if self.train: - return idx % len(self.images_yuv) - else: - return idx % len(self.images_yuv) - - def _load_file_get_patch(self, idx): - idx = self._get_index(idx) - - # reconstruction - image_yuv_path = self.images_yuv[idx] - - slice_qp_idx = int(image_yuv_path.rfind("qp")) - slice_qp = int(image_yuv_path[slice_qp_idx + 2 : slice_qp_idx + 4]) - slice_qp_map = np.uint16( - np.ones( - ( - (self.args.patch_size + 2 * self.args.shave) // 4, - (self.args.patch_size + 2 * self.args.shave) // 4, - 1, - ) - ) - * slice_qp - ) - - base_qp_idx = int(image_yuv_path.find("qp")) - base_qp = int(image_yuv_path[base_qp_idx + 2 : base_qp_idx + 4]) - base_qp_map = np.uint16( - np.ones( - ( - (self.args.patch_size + 2 * self.args.shave) // 4, - (self.args.patch_size + 2 * self.args.shave) // 4, - 1, - ) - ) - * base_qp - ) - - if ( - self.args.dir_data_bvi_AI in image_yuv_path - or self.args.dir_data_tvd_AI in image_yuv_path - ): - is_AI = 1 - else: - is_AI = 0 - - if is_AI: - slice_type = 0 - else: - slice_type = 1023 - slice_type_map = np.uint16( - np.ones( - ( - (self.args.patch_size + 2 * self.args.shave) // 4, - (self.args.patch_size + 2 * self.args.shave) // 4, - 1, - ) - ) - * slice_type - ) - - # RPR - rpr_str = "_rpr" - pos = image_yuv_path.find(".yuv") - image_yuv_rpr_path = image_yuv_path[:pos] + rpr_str + image_yuv_path[pos:] - image_yuv_rpr_path = image_yuv_rpr_path.replace("/yuv/", "/rpr_image/") - - # original - image_yuv_org_path = self.images_yuv_org[idx] - org_splits = os.path.basename(os.path.dirname(image_yuv_org_path)).split("_") - wh_org = org_splits[1].split("x") - w, h = list(map(lambda x: int(x), wh_org)) # noqa: C417 - - patch_rec_y, patch_in, patch_rpr, patch_org = Utils.get_patch( - image_yuv_path, - image_yuv_rpr_path, - image_yuv_org_path, - w, - h, - self.args.patch_size, - self.args.shave, - ) - - patch_lr = np.concatenate( - (patch_in, slice_qp_map, base_qp_map, slice_type_map), axis=2 - ) - patch_hr = np.concatenate((patch_rec_y, patch_rpr, patch_org), axis=2) - - if self.train: - patch_lr, patch_hr = Utils.augment(patch_lr, patch_hr) - - return patch_lr, patch_hr, image_yuv_path diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/Utils.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/Utils.py deleted file mode 100644 index 4db1c06bdf686208813e2411c682a1ce2fd20d8b..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/Utils.py +++ /dev/null @@ -1,305 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. 
-* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import argparse -import logging -import random -from pathlib import Path -import numpy as np -import PIL.Image as Image -import os - - -import torch -from torch.utils.tensorboard import SummaryWriter - - -def parse_args(): - parser = argparse.ArgumentParser() - - path_cur = Path(os.path.split(os.path.realpath(__file__))[0]) - path_save = path_cur.joinpath("Experiments") - - # for loading data - parser.add_argument("--ext", type=str, default="yuv", help="data file extension") - - parser.add_argument( - "--data_range", type=str, default="1-180/181-200", help="train/test data range" - ) - parser.add_argument( - "--dir_data", - type=str, - default="/path/EE1_2_2_train/RA_BVI_DVC", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori", - type=str, - default="/path/EE1_2_2_train_ori/BVI_DVC", - help="raw dataset directory", - ) - - parser.add_argument( - "--data_range_tvd", type=str, default="1-66/67-74", help="train/test data range" - ) - parser.add_argument( - "--dir_data_tvd", - type=str, - default="/path/EE1_2_2_train/RA_TVD", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori_tvd", - type=str, - default="/path/EE1_2_2_train_ori/TVD", - help="raw dataset directory", - ) - - # for loading model - parser.add_argument("--checkpoints", type=str, help="checkpoints file path") - parser.add_argument("--pretrained", type=str, help="pretrained model path") - - # batch size - parser.add_argument( - "--batch_size", type=int, default=64, help="batch size for Fusion stage" - ) - # do validation - parser.add_argument( - "--test_every", type=int, default=1200, help="do test per every N batches" - ) - - # learning rate - parser.add_argument( - "--lr", type=float, default=1e-4, help="learning rate for Fusion stage" - ) - - parser.add_argument( - "--gpu", action="store_true", default=True, help="use gpu or cpu" - ) - - # epoch - parser.add_argument( - "--max_epoch", type=int, default=2000, help="max training epochs" - ) - - # patch_size - parser.add_argument( - "--patch_size", 
type=int, default=128, help="train/val patch size" - ) - parser.add_argument("--shave", type=int, default=8, help="train/shave") - - # for recording - parser.add_argument( - "--verbose", - action="store_true", - default=True, - help="use tensorboard and logger", - ) - parser.add_argument( - "--save_dir", type=str, default=path_save, help="directory for recording" - ) - parser.add_argument( - "--eval_epochs", type=int, default=5, help="save model after epochs" - ) - - args = parser.parse_args() - return args - - -def init(): - # parse arguments - args = parse_args() - - # create directory for recording - experiment_dir = Path(args.save_dir) - experiment_dir.mkdir(exist_ok=True) - - ckpt_dir = experiment_dir.joinpath("Checkpoints/") - ckpt_dir.mkdir(exist_ok=True) - print(r"===========Save checkpoints to {0}===========".format(str(ckpt_dir))) - - if args.verbose: - # initialize logger - log_dir = experiment_dir.joinpath("Log/") - log_dir.mkdir(exist_ok=True) - logger = logging.getLogger() - logger.setLevel(logging.INFO) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - file_handler = logging.FileHandler(str(log_dir) + "/Log.txt") - file_handler.setLevel(logging.INFO) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - logger.info("PARAMETER ...") - logger.info(args) - # initialize tensorboard - tb_dir_all = experiment_dir.joinpath("Tensorboard_all/") - tb_dir_all.mkdir(exist_ok=True) - tensorboard_all = SummaryWriter(log_dir=str(tb_dir_all), flush_secs=30) - - tb_dir = experiment_dir.joinpath("Tensorboard/") - tb_dir.mkdir(exist_ok=True) - tensorboard = SummaryWriter(log_dir=str(tb_dir), flush_secs=30) - print( - r"===========Save tensorboard and logger to {0}===========".format( - str(tb_dir_all) - ) - ) - else: - print( - r"===========Disable tensorboard and logger to accelerate training===========" - ) - logger = None - tensorboard_all = None - tensorboard = None - - return args, logger, ckpt_dir, tensorboard_all, tensorboard - - -def yuv_read(yuv_path, h, w, iy, ix, ip): - h_c = h // 2 - w_c = w // 2 - - ip_c = ip // 2 - iy_c = iy // 2 - ix_c = ix // 2 - - fp = open(yuv_path, "rb") - - # y - fp.seek(iy * w * 2, 0) - patch_y = np.fromfile(fp, np.uint16, ip * w).reshape(ip, w, 1) - patch_y = patch_y[:, ix : ix + ip, :] - - # u - fp.seek((w * h + iy_c * w_c) * 2, 0) - patch_u = np.fromfile(fp, np.uint16, ip_c * w_c).reshape(ip_c, w_c, 1) - patch_u = patch_u[:, ix_c : ix_c + ip_c, :] - - # v - fp.seek((w * h + w_c * h_c + iy_c * w_c) * 2, 0) - patch_v = np.fromfile(fp, np.uint16, ip_c * w_c).reshape(ip_c, w_c, 1) - patch_v = patch_v[:, ix_c : ix_c + ip_c, :] - - fp.close() - - return patch_y, patch_u, patch_v - - -def upsample(img, height, width): - img = np.squeeze(img, axis=2) - img = np.array( - Image.fromarray(img.astype(np.float)).resize((width, height), Image.NEAREST) - ) - img = np.expand_dims(img, axis=2) - return img - - -def patch_process(yuv_path, h, w, iy, ix, ip): - y, u, v = yuv_read(yuv_path, h, w, iy, ix, ip) - # u_up = upsample(u, ip, ip) - # v_up = upsample(v, ip, ip) - # yuv = np.concatenate((y, u_up, v_up), axis=2) - return y - - -def get_patch( - image_yuv_path, - image_yuv_pred_path, - image_yuv_rpr_path, - image_yuv_org_path, - w, - h, - patch_size, - shave, -): - ih = h - iw = w - - ip = patch_size - ih -= ih % ip - iw -= iw % ip - iy = random.randrange(ip, ih - ip, ip) - shave - ix = random.randrange(ip, iw - ip, ip) - shave - - # - patch_rec = patch_process( - image_yuv_path, h // 2, w // 2, iy // 
2, ix // 2, (ip + 2 * shave) // 2 - ) - patch_pre = patch_process( - image_yuv_pred_path, h // 2, w // 2, iy // 2, ix // 2, (ip + 2 * shave) // 2 - ) - patch_rpr = patch_process(image_yuv_rpr_path, h, w, iy, ix, ip + 2 * shave) - patch_org = patch_process(image_yuv_org_path, h, w, iy, ix, ip + 2 * shave) - - patch_in = np.concatenate((patch_rec, patch_pre), axis=2) - - ret = [patch_in, patch_rpr, patch_org] - - return ret - - -def augment(*args): - x = random.random() - hflip = x < 0.2 - vflip = x >= 0.2 and x < 0.4 - rot90 = x >= 0.4 and x < 0.6 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - - return img - - return [_augment(a) for a in args] - - -def np2Tensor(*args): - def _np2Tensor(img): - np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) - tensor = torch.from_numpy(np_transpose.astype(np.int32)).float() / 1023.0 - - return tensor - - return [_np2Tensor(a) for a in args] - - -def cal_psnr(distortion: torch.Tensor): - psnr = -10 * torch.log10(distortion) - return psnr diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/nn_model.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/nn_model.py deleted file mode 100644 index ec9023cce5f0b776c00625032fed8db930d2ef43..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/nn_model.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. 
-""" - -import torch -import torch.nn as nn - - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - # hyper-params - n_resblocks = 24 - n_feats_k = 64 - n_feats_m = 192 - - # define head module - self.head_rec = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=n_feats_k, - kernel_size=3, - stride=1, - padding=1, - ), # downsmaple by stride = 2 - nn.PReLU(), - ) - self.head_pre = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=n_feats_k, - kernel_size=3, - stride=1, - padding=1, - ), # downsmaple by stride = 2 - nn.PReLU(), - ) - - # define fuse module - self.fuse = nn.Sequential( - nn.Conv2d( - in_channels=n_feats_k * 2 + 2, - out_channels=n_feats_k, - kernel_size=1, - stride=1, - padding=0, - ), - nn.PReLU(), - ) - - # define body module - body = [] - for _ in range(n_resblocks): - body.append(DscBlock(n_feats_k, n_feats_m)) - - self.body = nn.Sequential(*body) - - # define tail module - self.tail = nn.Sequential( - nn.Conv2d( - in_channels=n_feats_k, out_channels=4 * 1, kernel_size=3, padding=1 - ), - nn.PixelShuffle(2), # feature_map:(B, 2x2x1, N, N) -> (B, 1, 2N, 2N) - ) - - def forward(self, rec, pre, rpr, slice_qp, base_qp): - in_0 = self.head_rec(rec) - in_1 = self.head_pre(pre) - - x = self.fuse(torch.cat((in_0, in_1, slice_qp, base_qp), 1)) - x = self.body(x) - x = self.tail(x) - x += rpr - - return x - - -class DscBlock(nn.Module): - def __init__(self, n_feats_k, n_feats_m, expansion=1): - super(DscBlock, self).__init__() - self.expansion = expansion - self.c1 = nn.Conv2d( - in_channels=n_feats_k, out_channels=n_feats_m, kernel_size=1, padding=0 - ) - self.prelu = nn.PReLU() - self.c2 = nn.Conv2d( - in_channels=n_feats_m, out_channels=n_feats_k, kernel_size=1, padding=0 - ) - self.c3 = nn.Conv2d( - in_channels=n_feats_k, out_channels=n_feats_k, kernel_size=3, padding=1 - ) - - def forward(self, x): - i = x - x = self.c2(self.prelu(self.c1(x))) - x = self.c3(x) - x += i - - return x diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train.sh deleted file mode 100644 index 7b026d766aefdfa8a5c92a46d1750661bc344f5a..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train.sh +++ /dev/null @@ -1,32 +0,0 @@ -# The copyright in this software is being made available under the BSD -# License, included below. This software may be subject to other third party -# and contributor rights, including patent rights, and no such rights are -# granted under this license. -# -# Copyright (c) 2010-2022, ITU/ISO/IEC -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -# be used to endorse or promote products derived from this software without -# specific prior written permission. 
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-# THE POSSIBILITY OF SUCH DAMAGE.
-python train_YUV.py
\ No newline at end of file
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train_YUV.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train_YUV.py
deleted file mode 100644
index 2b81e93bd7a85fa79c40106d1aa1f87be7a29dd5..0000000000000000000000000000000000000000
--- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train_YUV.py
+++ /dev/null
@@ -1,400 +0,0 @@
-"""
-/* The copyright in this software is being made available under the BSD
-* License, included below. This software may be subject to other third party
-* and contributor rights, including patent rights, and no such rights are
-* granted under this license.
-*
-* Copyright (c) 2010-2022, ITU/ISO/IEC
-* All rights reserved.
-*
-* Redistribution and use in source and binary forms, with or without
-* modification, are permitted provided that the following conditions are met:
-*
-* * Redistributions of source code must retain the above copyright notice,
-* this list of conditions and the following disclaimer.
-* * Redistributions in binary form must reproduce the above copyright notice,
-* this list of conditions and the following disclaimer in the documentation
-* and/or other materials provided with the distribution.
-* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
-* be used to endorse or promote products derived from this software without
-* specific prior written permission.
-*
-* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-* THE POSSIBILITY OF SUCH DAMAGE.
-""" - -import torch -import torch.nn as nn -from torch.optim.adam import Adam -from torch.optim.lr_scheduler import MultiStepLR -from torch.utils.data.dataloader import DataLoader -import datetime -import os -import glob - -from yuv10bdata import YUV10bData -from Utils import init, cal_psnr -from nn_model import Net - -torch.backends.cudnn.enabled = True -torch.backends.cudnn.benchmark = True - - -class Trainer: - def __init__(self): - ( - self.args, - self.logger, - self.checkpoints_dir, - self.tensorboard_all, - self.tensorboard, - ) = init() - - self.net = Net().to("cuda" if self.args.gpu else "cpu") - - self.L1loss = nn.L1Loss().to("cuda" if self.args.gpu else "cpu") - self.L2loss = nn.MSELoss().to("cuda" if self.args.gpu else "cpu") - - self.optimizer = Adam(self.net.parameters(), lr=self.args.lr) - self.scheduler = MultiStepLR( - optimizer=self.optimizer, milestones=[4001, 4002], gamma=0.5 - ) - - print("============>loading data") - self.train_dataset = YUV10bData(self.args, train=True) - self.eval_dataset = YUV10bData(self.args, train=False) - - self.train_dataloader = DataLoader( - dataset=self.train_dataset, - batch_size=self.args.batch_size, - shuffle=True, - num_workers=12, - pin_memory=False, - ) - self.eval_dataloader = DataLoader( - dataset=self.eval_dataset, - batch_size=self.args.batch_size, - shuffle=True, - num_workers=12, - pin_memory=False, - ) - - self.train_steps = self.eval_steps = 0 - - def train(self): - start_epoch = self.load_checkpoints() - print("============>start training") - for epoch in range(start_epoch, self.args.max_epoch): - print("Epoch {}/{}".format(epoch, self.args.max_epoch)) - self.logger.info("Epoch {}/{}".format(epoch, self.args.max_epoch)) - self.train_one_epoch() - self.scheduler.step() - if (epoch + 1) % self.args.eval_epochs == 0: - self.eval(epoch=epoch) - self.save_ckpt(epoch=epoch) - - def train_one_epoch(self): - self.net.train() - for _, tensor in enumerate(self.train_dataloader): - img_lr, img_hr, filename = tensor - - img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") - img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") - - rec = img_lr[:, 0:1, :, :] - pre = img_lr[:, 1:2, :, :] - slice_qp = img_lr[:, 2:3, :, :] - base_qp = img_lr[:, 3:4, :, :] - img_rpr = img_hr[:, 0:1, :, :] - img_ori = img_hr[:, 1:2, :, :] - - img_out = self.net(rec, pre, img_rpr, slice_qp, base_qp) - - # calculate distortion - shave = self.args.shave - - L1_loss_pred_Y = self.L1loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - loss_pred_Y = self.L2loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - # loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr - loss_pred = L1_loss_pred_Y - - # loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - # loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 
2,shave:-shave,shave:-shave]) - - # visualization - self.train_steps += 1 - if self.train_steps % 20 == 0: - psnr_pred_Y = cal_psnr(loss_pred_Y) - # psnr_pred_Cb = cal_psnr(loss_pred_Cb) - # psnr_pred_Cr = cal_psnr(loss_pred_Cr) - - # psnr_input_Y = cal_psnr(loss_rec_Y) - # psnr_input_Cb = cal_psnr(loss_rec_Cb) - # psnr_input_Cr = cal_psnr(loss_rec_Cr) - - time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") - - print( - "[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}------{}".format( - (self.train_steps % len(self.train_dataloader)), - len(self.train_dataloader), - loss_pred_Y, - psnr_pred_Y, - time, - ) - ) - self.logger.info( - "[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}".format( - (self.train_steps % len(self.train_dataloader)), - len(self.train_dataloader), - loss_pred_Y, - psnr_pred_Y, - ) - ) - - # print("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), - # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y, time)) - # self.logger.info("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), - # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y)) - - self.tensorboard_all.add_scalars( - main_tag="Train/PSNR", - tag_scalar_dict={"pred_Y": psnr_pred_Y.data}, - global_step=self.train_steps, - ) - # self.tensorboard_all.add_image("rec", rec[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("pre", pre[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("rpr", img_rpr[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("out", img_out[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("ori", img_ori[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - - # self.tensorboard_all.add_scalars(main_tag="Train/PSNR", - # tag_scalar_dict={"input_Cb": psnr_input_Cb.data, - # "pred_Cb": psnr_pred_Cb.data}, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalars(main_tag="Train/PSNR", - # tag_scalar_dict={"input_Cr": psnr_input_Cr.data, - # "pred_Cr": psnr_pred_Cr.data}, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Y", - # scalar_value = psnr_pred_Y - psnr_input_Y, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cb", - # scalar_value = psnr_pred_Cb - psnr_input_Cb, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cr", - # scalar_value = psnr_pred_Cr - psnr_input_Cr, - # global_step=self.train_steps) - - self.tensorboard_all.add_scalar( - tag="Train/train_loss_pred", - scalar_value=loss_pred, - global_step=self.train_steps, - ) - - # backward - self.optimizer.zero_grad() - loss_pred.backward() - self.optimizer.step() - - @torch.no_grad() - def eval(self, epoch: int): - print("============>start evaluating") - eval_cnt = 0 - ave_psnr_Y = 0.000 - # ave_psnr_Cb = 0.000 - # ave_psnr_Cr = 0.000 - self.net.eval() - for _, tensor in enumerate(self.eval_dataloader): - img_lr, img_hr, filename = tensor - - img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") - img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") - - rec = img_lr[:, 0:1, :, :] - pre = img_lr[:, 1:2, :, :] - slice_qp = img_lr[:, 2:3, :, :] - base_qp = img_lr[:, 3:4, :, :] - img_rpr = img_hr[:, 0:1, :, :] - img_ori = 
img_hr[:, 1:2, :, :] - img_out = self.net(rec, pre, img_rpr, slice_qp, base_qp) - - # calculate distortion and psnr - shave = self.args.shave - - L1_loss_pred_Y = self.L1loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - loss_pred_Y = self.L2loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - # loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr - loss_pred = L1_loss_pred_Y - - # loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - # loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - psnr_pred_Y = cal_psnr(loss_pred_Y) - # psnr_pred_Cb = cal_psnr(loss_pred_Cb) - # psnr_pred_Cr = cal_psnr(loss_pred_Cr) - - # psnr_input_Y = cal_psnr(loss_rec_Y) - # psnr_input_Cb = cal_psnr(loss_rec_Cb) - # psnr_input_Cr = cal_psnr(loss_rec_Cr) - - ave_psnr_Y += psnr_pred_Y - # ave_psnr_Cb += psnr_pred_Cb - psnr_input_Cb - # ave_psnr_Cr += psnr_pred_Cr - psnr_input_Cr - - eval_cnt += 1 - # visualization - self.eval_steps += 1 - if self.eval_steps % 2 == 0: - self.tensorboard_all.add_scalar( - tag="Eval/PSNR_Y", - scalar_value=psnr_pred_Y, - global_step=self.eval_steps, - ) - - # self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cb", - # scalar_value = psnr_pred_Cb - psnr_input_Cb, - # global_step=self.eval_steps) - - # self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cr", - # scalar_value = psnr_pred_Cr - psnr_input_Cr, - # global_step=self.eval_steps) - - self.tensorboard_all.add_scalar( - tag="Eval/eval_loss_pred", - scalar_value=loss_pred, - global_step=self.eval_steps, - ) - - time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") - print("PSNR_Y:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, time)) - self.logger.info("PSNR_Y:{:.3f}".format(ave_psnr_Y / eval_cnt)) - - # print("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt, time)) - # self.logger.info("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt)) - - self.tensorboard.add_scalar( - tag="Eval/PSNR_Y_ave", - scalar_value=ave_psnr_Y / eval_cnt, - global_step=epoch + 1, - ) - # self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cb_ave", - # scalar_value = ave_psnr_Cb / eval_cnt, - # global_step = epoch + 1) - # self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cr_ave", - # scalar_value = ave_psnr_Cr / eval_cnt, - # global_step = epoch + 1) - - def load_checkpoints(self): - if not self.args.checkpoints: - ckpt_list = sorted(glob.glob(os.path.join(self.checkpoints_dir, "*.pth"))) - num = len(ckpt_list) - if num > 1: - if os.path.getsize(ckpt_list[-1]) == os.path.getsize(ckpt_list[-2]): - self.args.checkpoints = ckpt_list[-1] - else: - self.args.checkpoints = ckpt_list[-2] - - if self.args.checkpoints: - 
print(
-                "===========Load checkpoints {0}===========".format(
-                    self.args.checkpoints
-                )
-            )
-            self.logger.info("Load checkpoints {0}".format(self.args.checkpoints))
-            ckpt = torch.load(self.args.checkpoints)
-            # load network weights
-            try:
-                self.net.load_state_dict(ckpt["network"])
-            except Exception:
-                print("Cannot find network weights")
-            # load optimizer params
-            try:
-                self.optimizer.load_state_dict(ckpt["optimizer"])
-                self.scheduler.load_state_dict(ckpt["scheduler"])
-            except Exception:
-                print("Cannot find some optimizer params; ignoring")
-            start_epoch = ckpt["epoch"] + 1
-            self.train_steps = ckpt["train_step"] + 1
-            self.eval_steps = ckpt["eval_step"] + 1
-        elif self.args.pretrained:
-            ckpt = torch.load(self.args.pretrained)
-            print(
-                "===========Load network weights {0}===========".format(
-                    self.args.pretrained
-                )
-            )
-            self.logger.info("Load network weights {0}".format(self.args.pretrained))
-            # load network weights
-            try:
-                self.net.load_state_dict(ckpt["network"])
-            except Exception:
-                print("Cannot find network weights")
-            start_epoch = 0
-        else:
-            print("===========Training from scratch===========")
-            self.logger.info("Training from scratch")
-            start_epoch = 0
-        return start_epoch
-
-    def save_ckpt(self, epoch: int):
-        checkpoint = {
-            "network": self.net.state_dict(),
-            "epoch": epoch,
-            "train_step": self.train_steps,
-            "eval_step": self.eval_steps,
-            "optimizer": self.optimizer.state_dict(),
-            "scheduler": self.scheduler.state_dict(),
-        }
-
-        torch.save(checkpoint, "%s/model_%.4d.pth" % (self.checkpoints_dir, epoch + 1))
-        self.logger.info("Saving model...")
-        print(
-            "======================Saving model {0}======================".format(
-                str(epoch)
-            )
-        )
-
-
-if __name__ == "__main__":
-    trainer = Trainer()
-    trainer.train()
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/yuv10bdata.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/yuv10bdata.py
deleted file mode 100644
index c9232cf11e6c1c1818056ce6d605b5d0edeaf68c..0000000000000000000000000000000000000000
--- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/yuv10bdata.py
+++ /dev/null
@@ -1,268 +0,0 @@
-"""
-/* The copyright in this software is being made available under the BSD
-* License, included below. This software may be subject to other third party
-* and contributor rights, including patent rights, and no such rights are
-* granted under this license.
-*
-* Copyright (c) 2010-2022, ITU/ISO/IEC
-* All rights reserved.
-*
-* Redistribution and use in source and binary forms, with or without
-* modification, are permitted provided that the following conditions are met:
-*
-* * Redistributions of source code must retain the above copyright notice,
-* this list of conditions and the following disclaimer.
-* * Redistributions in binary form must reproduce the above copyright notice,
-* this list of conditions and the following disclaimer in the documentation
-* and/or other materials provided with the distribution.
-* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
-* be used to endorse or promote products derived from this software without
-* specific prior written permission.
-*
-* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-* ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import os -import glob -from torch.utils.data import Dataset -import numpy as np -import Utils -import math - - -class YUV10bData(Dataset): - def __init__(self, args, name="YuvData", train=True): - super(YUV10bData, self).__init__() - self.args = args - self.split = "train" if train else "valid" - self.image_ext = args.ext - self.name = name - self.train = train - - data_range = [r.split("-") for r in args.data_range.split("/")] - if train: - data_range = data_range[0] - else: - data_range = data_range[1] - self.begin, self.end = list(map(lambda x: int(x), data_range)) # noqa: C417 - - data_range_tvd = [r.split("-") for r in args.data_range_tvd.split("/")] - if train: - data_range_tvd = data_range_tvd[0] - else: - data_range_tvd = data_range_tvd[1] - self.begin_tvd, self.end_tvd = list( # noqa: C417 - map(lambda x: int(x), data_range_tvd) - ) - - self._set_data() - self._get_image_list() - - if train: - n_patches = args.batch_size * args.test_every - n_images = len(self.images_yuv) - if n_images == 0: - self.repeat = 0 - else: - self.repeat = max(n_patches // n_images, 1) - print(f"repeating dataset {self.repeat} for one epoch") - else: - n_patches = args.batch_size * args.test_every // 25 - n_images = len(self.images_yuv) - if n_images == 0: - self.repeat = 0 - else: - self.repeat = 5 - print(f"repeating dataset {self.repeat} for one epoch") - - def _set_data(self): - self.dir_in = os.path.join(self.args.dir_data, "yuv") - self.dir_org = os.path.join(self.args.dir_data_ori) - - self.dir_in_tvd = os.path.join(self.args.dir_data_tvd, "yuv") - self.dir_org_tvd = os.path.join(self.args.dir_data_ori_tvd) - - def _scan_class(self, is_tvd, class_name): - QPs = ["22", "27", "32", "37", "42"] - - dir_in = self.dir_in_tvd if is_tvd else self.dir_in - list_temp = glob.glob( - os.path.join(dir_in, "*_" + class_name + "_*." + self.image_ext) - ) - file_rec_list = [] - for i in list_temp: - print(i) - index = i.find("poc") - poc = int(i[index + 3 : index + 6]) - if poc % 3 == 0 and poc != 0: - file_rec_list.append(i) - - if is_tvd: - file_rec_list = sorted(file_rec_list) - else: - file_rec_list = sorted(file_rec_list, key=str.lower) - - dir_org = self.dir_org_tvd if is_tvd else self.dir_org - list_temp = glob.glob( - os.path.join(dir_org, class_name + "*/*." 
+ self.image_ext) - ) - # print(list_temp) - file_org_list = [] - for i in list_temp: - index = i.find("frame_") - poc = int(i[index + 6 : index + 9]) - if poc % 3 == 0 and poc != 0: - file_org_list.append(i) - - if is_tvd: - file_org_list = sorted(file_org_list) - else: - file_org_list = sorted(file_org_list, key=str.lower) - - frame_num = 62 - frame_num_sampled = math.ceil(frame_num / 3) - begin = self.begin_tvd if is_tvd else self.begin - end = self.end_tvd if is_tvd else self.end - - class_names_yuv = [] - class_names_yuv_org = [] - - for qp in QPs: - file_list = file_rec_list[ - (begin - 1) * frame_num_sampled * 5 : end * frame_num_sampled * 5 - ] - for filename in file_list: - idx = filename.find("qp") - if int(filename[idx + 2 : idx + 4]) == int(qp): - class_names_yuv.append(filename) - - file_list = file_org_list[ - (begin - 1) * frame_num_sampled : end * frame_num_sampled - ] - for filename in file_list: - class_names_yuv_org.append(filename) - - return class_names_yuv, class_names_yuv_org - - def _scan(self): - bvi_class_set = ["A"] - - names_yuv = [] - names_yuv_org = [] - for class_name in bvi_class_set: - class_names_yuv, class_names_yuv_org = self._scan_class(False, class_name) - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - class_names_yuv, class_names_yuv_org = self._scan_class(True, "A") - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - print(len(names_yuv)) - print(len(names_yuv_org)) - - return names_yuv, names_yuv_org - - def _get_image_list(self): - self.images_yuv, self.images_yuv_org = self._scan() - - def __getitem__(self, idx): - patch_in, patch_org, filename = self._load_file_get_patch(idx) - pair_t = Utils.np2Tensor(patch_in, patch_org) - - return pair_t[0], pair_t[1], filename - - def __len__(self): - if self.train: - return len(self.images_yuv) * self.repeat - else: - return len(self.images_yuv) * self.repeat - - def _get_index(self, idx): - if self.train: - return idx % len(self.images_yuv) - else: - return idx % len(self.images_yuv) - - def _load_file_get_patch(self, idx): - idx = self._get_index(idx) - - # reconstruction - image_yuv_path = self.images_yuv[idx] - - slice_qp_idx = int(image_yuv_path.rfind("qp")) - slice_qp = int(image_yuv_path[slice_qp_idx + 2 : slice_qp_idx + 4]) - slice_qp_map = np.uint16( - np.ones( - ( - (self.args.patch_size + 2 * self.args.shave) // 2, - (self.args.patch_size + 2 * self.args.shave) // 2, - 1, - ) - ) - * slice_qp - ) - - base_qp_idx = int(image_yuv_path.find("qp")) - base_qp = int(image_yuv_path[base_qp_idx + 2 : base_qp_idx + 4]) - base_qp_map = np.uint16( - np.ones( - ( - (self.args.patch_size + 2 * self.args.shave) // 2, - (self.args.patch_size + 2 * self.args.shave) // 2, - 1, - ) - ) - * base_qp - ) - - # prediction - pred_str = "_prediction" - pos = image_yuv_path.find(".yuv") - image_yuv_pred_path = image_yuv_path[:pos] + pred_str + image_yuv_path[pos:] - image_yuv_pred_path = image_yuv_pred_path.replace("/yuv/", "/prediction_image/") - - # RPR - rpr_str = "_rpr" - pos = image_yuv_path.find(".yuv") - image_yuv_rpr_path = image_yuv_path[:pos] + rpr_str + image_yuv_path[pos:] - image_yuv_rpr_path = image_yuv_rpr_path.replace("/yuv/", "/rpr_image/") - - # original - image_yuv_org_path = self.images_yuv_org[idx] - org_splits = os.path.basename(os.path.dirname(image_yuv_org_path)).split("_") - wh_org = org_splits[1].split("x") - w, h = list(map(lambda x: int(x), wh_org)) # noqa: C417 - - patch_in, patch_rpr, patch_org 
= Utils.get_patch( - image_yuv_path, - image_yuv_pred_path, - image_yuv_rpr_path, - image_yuv_org_path, - w, - h, - self.args.patch_size, - self.args.shave, - ) - - patch_in = np.concatenate((patch_in, slice_qp_map, base_qp_map), axis=2) - - if self.train: - patch_in, patch_rpr, patch_org = Utils.augment( - patch_in, patch_rpr, patch_org - ) - - patch_lr = patch_in - patch_hr = np.concatenate((patch_rpr, patch_org), axis=2) - - return patch_lr, patch_hr, image_yuv_path diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/Utils.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/Utils.py deleted file mode 100644 index e03cb6fe870ef048065b7e9b5e6d07e9dfb82c34..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/Utils.py +++ /dev/null @@ -1,305 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. 
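
For orientation, the `_load_file_get_patch` method above packs its outputs channel-wise: a four-channel low-resolution patch (reconstruction, prediction, slice-QP plane, base-QP plane) and a two-channel high-resolution patch (RPR upsample, original), which is exactly how the trainer slices them back apart. A minimal numpy sketch of that layout, assuming the default patch_size=128 and shave=8:

import numpy as np

patch_size, shave = 128, 8           # default CLI values
lr = (patch_size + 2 * shave) // 2   # low-resolution grid: 72
hr = patch_size + 2 * shave          # high-resolution grid: 144

rec = np.zeros((lr, lr, 1), np.uint16)   # reconstructed low-res luma
pre = np.zeros_like(rec)                 # prediction signal
slice_qp = np.full_like(rec, 32)         # constant slice-QP plane (example QP)
base_qp = np.full_like(rec, 27)          # constant base-QP plane (example QP)

patch_lr = np.concatenate((rec, pre, slice_qp, base_qp), axis=2)
assert patch_lr.shape == (72, 72, 4)     # channels 0..3, as sliced in train.py

The Luma-I variant further below drops the base-QP plane and keeps only three low-resolution channels.
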
-""" - -import argparse -import logging -import random -from pathlib import Path -import numpy as np -import PIL.Image as Image -import os - - -import torch -from torch.utils.tensorboard import SummaryWriter - - -def parse_args(): - parser = argparse.ArgumentParser() - - path_cur = Path(os.path.split(os.path.realpath(__file__))[0]) - path_save = path_cur.joinpath("Experiments") - - # for loading data - parser.add_argument("--ext", type=str, default="yuv", help="data file extension") - - parser.add_argument( - "--data_range", type=str, default="1-180/181-200", help="train/test data range" - ) - parser.add_argument( - "--dir_data", - type=str, - default="/path/EE1_2_2_train/AI_BVI_DVC", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori", - type=str, - default="/path/EE1_2_2_train_ori/BVI_DVC", - help="raw dataset directory", - ) - - parser.add_argument( - "--data_range_tvd", type=str, default="1-66/67-74", help="train/test data range" - ) - parser.add_argument( - "--dir_data_tvd", - type=str, - default="/path/EE1_2_2_train/AI_TVD", - help="distorted dataset directory", - ) - parser.add_argument( - "--dir_data_ori_tvd", - type=str, - default="/path/EE1_2_2_train_ori/TVD", - help="raw dataset directory", - ) - - # for loading model - parser.add_argument("--checkpoints", type=str, help="checkpoints file path") - parser.add_argument("--pretrained", type=str, help="pretrained model path") - - # batch size - parser.add_argument( - "--batch_size", type=int, default=64, help="batch size for Fusion stage" - ) - # do validation - parser.add_argument( - "--test_every", type=int, default=1200, help="do test per every N batches" - ) - - # learning rate - parser.add_argument( - "--lr", type=float, default=1e-4, help="learning rate for Fusion stage" - ) - - parser.add_argument( - "--gpu", action="store_true", default=True, help="use gpu or cpu" - ) - - # epoch - parser.add_argument( - "--max_epoch", type=int, default=2000, help="max training epochs" - ) - - # patch_size - parser.add_argument( - "--patch_size", type=int, default=128, help="train/val patch size" - ) - parser.add_argument("--shave", type=int, default=8, help="train/shave") - - # for recording - parser.add_argument( - "--verbose", - action="store_true", - default=True, - help="use tensorboard and logger", - ) - parser.add_argument( - "--save_dir", type=str, default=path_save, help="directory for recording" - ) - parser.add_argument( - "--eval_epochs", type=int, default=5, help="save model after epochs" - ) - - args = parser.parse_args() - return args - - -def init(): - # parse arguments - args = parse_args() - - # create directory for recording - experiment_dir = Path(args.save_dir) - experiment_dir.mkdir(exist_ok=True) - - ckpt_dir = experiment_dir.joinpath("Checkpoints/") - ckpt_dir.mkdir(exist_ok=True) - print(r"===========Save checkpoints to {0}===========".format(str(ckpt_dir))) - - if args.verbose: - # initialize logger - log_dir = experiment_dir.joinpath("Log/") - log_dir.mkdir(exist_ok=True) - logger = logging.getLogger() - logger.setLevel(logging.INFO) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - file_handler = logging.FileHandler(str(log_dir) + "/Log.txt") - file_handler.setLevel(logging.INFO) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - logger.info("PARAMETER ...") - logger.info(args) - # initialize tensorboard - tb_dir_all = experiment_dir.joinpath("Tensorboard_all/") - tb_dir_all.mkdir(exist_ok=True) - tensorboard_all 
= SummaryWriter(log_dir=str(tb_dir_all), flush_secs=30) - - tb_dir = experiment_dir.joinpath("Tensorboard/") - tb_dir.mkdir(exist_ok=True) - tensorboard = SummaryWriter(log_dir=str(tb_dir), flush_secs=30) - print( - r"===========Save tensorboard and logger to {0}===========".format( - str(tb_dir_all) - ) - ) - else: - print( - r"===========Disable tensorboard and logger to accelerate training===========" - ) - logger = None - tensorboard_all = None - tensorboard = None - - return args, logger, ckpt_dir, tensorboard_all, tensorboard - - -def yuv_read(yuv_path, h, w, iy, ix, ip): - h_c = h // 2 - w_c = w // 2 - - ip_c = ip // 2 - iy_c = iy // 2 - ix_c = ix // 2 - - fp = open(yuv_path, "rb") - - # y - fp.seek(iy * w * 2, 0) - patch_y = np.fromfile(fp, np.uint16, ip * w).reshape(ip, w, 1) - patch_y = patch_y[:, ix : ix + ip, :] - - # u - fp.seek((w * h + iy_c * w_c) * 2, 0) - patch_u = np.fromfile(fp, np.uint16, ip_c * w_c).reshape(ip_c, w_c, 1) - patch_u = patch_u[:, ix_c : ix_c + ip_c, :] - - # v - fp.seek((w * h + w_c * h_c + iy_c * w_c) * 2, 0) - patch_v = np.fromfile(fp, np.uint16, ip_c * w_c).reshape(ip_c, w_c, 1) - patch_v = patch_v[:, ix_c : ix_c + ip_c, :] - - fp.close() - - return patch_y, patch_u, patch_v - - -def upsample(img, height, width): - img = np.squeeze(img, axis=2) - img = np.array( - Image.fromarray(img.astype(np.float)).resize((width, height), Image.NEAREST) - ) - img = np.expand_dims(img, axis=2) - return img - - -def patch_process(yuv_path, h, w, iy, ix, ip): - y, u, v = yuv_read(yuv_path, h, w, iy, ix, ip) - # u_up = upsample(u, ip, ip) - # v_up = upsample(v, ip, ip) - # yuv = np.concatenate((y, u_up, v_up), axis=2) - return y - - -def get_patch( - image_yuv_path, - image_yuv_pred_path, - image_yuv_rpr_path, - image_yuv_org_path, - w, - h, - patch_size, - shave, -): - ih = h - iw = w - - ip = patch_size - ih -= ih % ip - iw -= iw % ip - iy = random.randrange(ip, ih - ip, ip) - shave - ix = random.randrange(ip, iw - ip, ip) - shave - - # - patch_rec = patch_process( - image_yuv_path, h // 2, w // 2, iy // 2, ix // 2, (ip + 2 * shave) // 2 - ) - patch_pre = patch_process( - image_yuv_pred_path, h // 2, w // 2, iy // 2, ix // 2, (ip + 2 * shave) // 2 - ) - patch_rpr = patch_process(image_yuv_rpr_path, h, w, iy, ix, ip + 2 * shave) - patch_org = patch_process(image_yuv_org_path, h, w, iy, ix, ip + 2 * shave) - - patch_in = np.concatenate((patch_rec, patch_pre), axis=2) - - ret = [patch_in, patch_rpr, patch_org] - - return ret - - -def augment(*args): - x = random.random() - hflip = x < 0.2 - vflip = x >= 0.2 and x < 0.4 - rot90 = x >= 0.4 and x < 0.6 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - - return img - - return [_augment(a) for a in args] - - -def np2Tensor(*args): - def _np2Tensor(img): - np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) - tensor = torch.from_numpy(np_transpose.astype(np.int32)).float() / 1023.0 - - return tensor - - return [_np2Tensor(a) for a in args] - - -def cal_psnr(distortion: torch.Tensor): - psnr = -10 * torch.log10(distortion) - return psnr diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/nn_model.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/nn_model.py deleted file mode 100644 index 74e5f973df3ab3019b4bc168f5976afd04250d75..0000000000000000000000000000000000000000 --- 
a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/nn_model.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. 
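
As a reading aid for `yuv_read` in the Utils.py above: the files are planar 10-bit 4:2:0 with two bytes per sample, so the U plane starts w*h samples into the file and the V plane another (w/2)*(h/2) samples after that. A sketch re-deriving those offsets for a whole frame; the frame_idx parameter is an illustrative addition, since the script's per-POC files appear to hold a single frame and always seek from offset 0:

import numpy as np

def read_yuv420_10bit(path, w, h, frame_idx=0):
    # Planar 4:2:0, 10 bits stored in uint16: Y is w*h samples,
    # U and V are (w//2)*(h//2) samples each -> 3*w*h bytes per frame.
    frame_bytes = 3 * w * h
    with open(path, "rb") as f:
        f.seek(frame_idx * frame_bytes)
        y = np.fromfile(f, np.uint16, w * h).reshape(h, w)
        u = np.fromfile(f, np.uint16, (w // 2) * (h // 2)).reshape(h // 2, w // 2)
        v = np.fromfile(f, np.uint16, (w // 2) * (h // 2)).reshape(h // 2, w // 2)
    return y, u, v
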
-""" - -import torch -import torch.nn as nn - - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - # hyper-params - n_resblocks = 24 - n_feats_k = 64 - n_feats_m = 192 - - # define head module - self.head_rec = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=n_feats_k, - kernel_size=3, - stride=1, - padding=1, - ), # downsmaple by stride = 2 - nn.PReLU(), - ) - self.head_pre = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=n_feats_k, - kernel_size=3, - stride=1, - padding=1, - ), # downsmaple by stride = 2 - nn.PReLU(), - ) - - # define fuse module - self.fuse = nn.Sequential( - nn.Conv2d( - in_channels=n_feats_k * 2 + 1, - out_channels=n_feats_k, - kernel_size=1, - stride=1, - padding=0, - ), - nn.PReLU(), - ) - - # define body module - body = [] - for _ in range(n_resblocks): - body.append(DscBlock(n_feats_k, n_feats_m)) - - self.body = nn.Sequential(*body) - - # define tail module - self.tail = nn.Sequential( - nn.Conv2d( - in_channels=n_feats_k, out_channels=4 * 1, kernel_size=3, padding=1 - ), - nn.PixelShuffle(2), # feature_map:(B, 2x2x1, N, N) -> (B, 1, 2N, 2N) - ) - - def forward(self, rec, pre, rpr, slice_qp): - in_0 = self.head_rec(rec) - in_1 = self.head_pre(pre) - - x = self.fuse(torch.cat((in_0, in_1, slice_qp), 1)) - x = self.body(x) - x = self.tail(x) - x += rpr - - return x - - -class DscBlock(nn.Module): - def __init__(self, n_feats_k, n_feats_m, expansion=1): - super(DscBlock, self).__init__() - self.expansion = expansion - self.c1 = nn.Conv2d( - in_channels=n_feats_k, out_channels=n_feats_m, kernel_size=1, padding=0 - ) - self.prelu = nn.PReLU() - self.c2 = nn.Conv2d( - in_channels=n_feats_m, out_channels=n_feats_k, kernel_size=1, padding=0 - ) - self.c3 = nn.Conv2d( - in_channels=n_feats_k, out_channels=n_feats_k, kernel_size=3, padding=1 - ) - - def forward(self, x): - i = x - x = self.c2(self.prelu(self.c1(x))) - x = self.c3(x) - x += i - - return x diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train.sh deleted file mode 100644 index 7b026d766aefdfa8a5c92a46d1750661bc344f5a..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train.sh +++ /dev/null @@ -1,32 +0,0 @@ -# The copyright in this software is being made available under the BSD -# License, included below. This software may be subject to other third party -# and contributor rights, including patent rights, and no such rights are -# granted under this license. -# -# Copyright (c) 2010-2022, ITU/ISO/IEC -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -# be used to endorse or promote products derived from this software without -# specific prior written permission. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -# THE POSSIBILITY OF SUCH DAMAGE. -python train_YUV.py \ No newline at end of file diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train_YUV.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train_YUV.py deleted file mode 100644 index dc5791e5794072c43074350947f936792e9fe43e..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train_YUV.py +++ /dev/null @@ -1,398 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. 
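
One scheduling detail worth noting before the next trainer: both train_YUV.py variants construct MultiStepLR with milestones [4001, 4002], while --max_epoch defaults to 2000 in the Utils.py shown here, so under the default settings no milestone is ever reached and the learning rate stays at --lr for the entire run; the gamma=0.5 decay only matters if training is extended past epoch 4000. A small self-contained check:

import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import MultiStepLR

opt = Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = MultiStepLR(opt, milestones=[4001, 4002], gamma=0.5)
for _ in range(2000):     # one scheduler step per epoch, as in Trainer.train()
    opt.step()            # keep the optimizer/scheduler call order valid
    sched.step()
print(opt.param_groups[0]["lr"])  # still 1e-4: no milestone was crossed
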
-""" - -import torch -import torch.nn as nn -from torch.optim.adam import Adam -from torch.optim.lr_scheduler import MultiStepLR -from torch.utils.data.dataloader import DataLoader -import datetime -import os -import glob - -from yuv10bdata import YUV10bData -from Utils import init, cal_psnr -from nn_model import Net - -torch.backends.cudnn.enabled = True -torch.backends.cudnn.benchmark = True - - -class Trainer: - def __init__(self): - ( - self.args, - self.logger, - self.checkpoints_dir, - self.tensorboard_all, - self.tensorboard, - ) = init() - - self.net = Net().to("cuda" if self.args.gpu else "cpu") - - self.L1loss = nn.L1Loss().to("cuda" if self.args.gpu else "cpu") - self.L2loss = nn.MSELoss().to("cuda" if self.args.gpu else "cpu") - - self.optimizer = Adam(self.net.parameters(), lr=self.args.lr) - self.scheduler = MultiStepLR( - optimizer=self.optimizer, milestones=[4001, 4002], gamma=0.5 - ) - - print("============>loading data") - self.train_dataset = YUV10bData(self.args, train=True) - self.eval_dataset = YUV10bData(self.args, train=False) - - self.train_dataloader = DataLoader( - dataset=self.train_dataset, - batch_size=self.args.batch_size, - shuffle=True, - num_workers=12, - pin_memory=False, - ) - self.eval_dataloader = DataLoader( - dataset=self.eval_dataset, - batch_size=self.args.batch_size, - shuffle=True, - num_workers=12, - pin_memory=False, - ) - - self.train_steps = self.eval_steps = 0 - - def train(self): - start_epoch = self.load_checkpoints() - print("============>start training") - for epoch in range(start_epoch, self.args.max_epoch): - print("Epoch {}/{}".format(epoch, self.args.max_epoch)) - self.logger.info("Epoch {}/{}".format(epoch, self.args.max_epoch)) - self.train_one_epoch() - self.scheduler.step() - if (epoch + 1) % self.args.eval_epochs == 0: - self.eval(epoch=epoch) - self.save_ckpt(epoch=epoch) - - def train_one_epoch(self): - self.net.train() - for _, tensor in enumerate(self.train_dataloader): - img_lr, img_hr, filename = tensor - - img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") - img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") - - rec = img_lr[:, 0:1, :, :] - pre = img_lr[:, 1:2, :, :] - slice_qp = img_lr[:, 2:3, :, :] - img_rpr = img_hr[:, 0:1, :, :] - img_ori = img_hr[:, 1:2, :, :] - - img_out = self.net(rec, pre, img_rpr, slice_qp) - - # calculate distortion - shave = self.args.shave - - L1_loss_pred_Y = self.L1loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - loss_pred_Y = self.L2loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - # loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr - loss_pred = L1_loss_pred_Y - - # loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - # loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - # visualization - 
self.train_steps += 1 - if self.train_steps % 20 == 0: - psnr_pred_Y = cal_psnr(loss_pred_Y) - # psnr_pred_Cb = cal_psnr(loss_pred_Cb) - # psnr_pred_Cr = cal_psnr(loss_pred_Cr) - - # psnr_input_Y = cal_psnr(loss_rec_Y) - # psnr_input_Cb = cal_psnr(loss_rec_Cb) - # psnr_input_Cr = cal_psnr(loss_rec_Cr) - - time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") - - print( - "[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}------{}".format( - (self.train_steps % len(self.train_dataloader)), - len(self.train_dataloader), - loss_pred_Y, - psnr_pred_Y, - time, - ) - ) - self.logger.info( - "[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}".format( - (self.train_steps % len(self.train_dataloader)), - len(self.train_dataloader), - loss_pred_Y, - psnr_pred_Y, - ) - ) - - # print("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), - # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y, time)) - # self.logger.info("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), - # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y)) - - self.tensorboard_all.add_scalars( - main_tag="Train/PSNR", - tag_scalar_dict={"pred_Y": psnr_pred_Y.data}, - global_step=self.train_steps, - ) - # self.tensorboard_all.add_image("rec", rec[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("pre", pre[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("rpr", img_rpr[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("out", img_out[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - # self.tensorboard_all.add_image("ori", img_ori[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) - - # self.tensorboard_all.add_scalars(main_tag="Train/PSNR", - # tag_scalar_dict={"input_Cb": psnr_input_Cb.data, - # "pred_Cb": psnr_pred_Cb.data}, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalars(main_tag="Train/PSNR", - # tag_scalar_dict={"input_Cr": psnr_input_Cr.data, - # "pred_Cr": psnr_pred_Cr.data}, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Y", - # scalar_value = psnr_pred_Y - psnr_input_Y, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cb", - # scalar_value = psnr_pred_Cb - psnr_input_Cb, - # global_step=self.train_steps) - - # self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cr", - # scalar_value = psnr_pred_Cr - psnr_input_Cr, - # global_step=self.train_steps) - - self.tensorboard_all.add_scalar( - tag="Train/train_loss_pred", - scalar_value=loss_pred, - global_step=self.train_steps, - ) - - # backward - self.optimizer.zero_grad() - loss_pred.backward() - self.optimizer.step() - - @torch.no_grad() - def eval(self, epoch: int): - print("============>start evaluating") - eval_cnt = 0 - ave_psnr_Y = 0.000 - # ave_psnr_Cb = 0.000 - # ave_psnr_Cr = 0.000 - self.net.eval() - for _, tensor in enumerate(self.eval_dataloader): - img_lr, img_hr, filename = tensor - - img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") - img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") - - rec = img_lr[:, 0:1, :, :] - pre = img_lr[:, 1:2, :, :] - slice_qp = img_lr[:, 2:3, :, :] - img_rpr = img_hr[:, 0:1, :, :] - img_ori = img_hr[:, 1:2, :, :] - img_out = self.net(rec, pre, img_rpr, slice_qp) - - # calculate 
distortion and psnr - shave = self.args.shave - - L1_loss_pred_Y = self.L1loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - loss_pred_Y = self.L2loss( - img_out[:, 0, shave:-shave, shave:-shave], - img_ori[:, 0, shave:-shave, shave:-shave], - ) - # loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - # loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr - loss_pred = L1_loss_pred_Y - - # loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) - # loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) - # loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) - - psnr_pred_Y = cal_psnr(loss_pred_Y) - # psnr_pred_Cb = cal_psnr(loss_pred_Cb) - # psnr_pred_Cr = cal_psnr(loss_pred_Cr) - - # psnr_input_Y = cal_psnr(loss_rec_Y) - # psnr_input_Cb = cal_psnr(loss_rec_Cb) - # psnr_input_Cr = cal_psnr(loss_rec_Cr) - - ave_psnr_Y += psnr_pred_Y - # ave_psnr_Cb += psnr_pred_Cb - psnr_input_Cb - # ave_psnr_Cr += psnr_pred_Cr - psnr_input_Cr - - eval_cnt += 1 - # visualization - self.eval_steps += 1 - if self.eval_steps % 2 == 0: - self.tensorboard_all.add_scalar( - tag="Eval/PSNR_Y", - scalar_value=psnr_pred_Y, - global_step=self.eval_steps, - ) - - # self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cb", - # scalar_value = psnr_pred_Cb - psnr_input_Cb, - # global_step=self.eval_steps) - - # self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cr", - # scalar_value = psnr_pred_Cr - psnr_input_Cr, - # global_step=self.eval_steps) - - self.tensorboard_all.add_scalar( - tag="Eval/eval_loss_pred", - scalar_value=loss_pred, - global_step=self.eval_steps, - ) - - time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") - print("PSNR_Y:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, time)) - self.logger.info("PSNR_Y:{:.3f}".format(ave_psnr_Y / eval_cnt)) - - # print("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt, time)) - # self.logger.info("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt)) - - self.tensorboard.add_scalar( - tag="Eval/PSNR_Y_ave", - scalar_value=ave_psnr_Y / eval_cnt, - global_step=epoch + 1, - ) - # self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cb_ave", - # scalar_value = ave_psnr_Cb / eval_cnt, - # global_step = epoch + 1) - # self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cr_ave", - # scalar_value = ave_psnr_Cr / eval_cnt, - # global_step = epoch + 1) - - def load_checkpoints(self): - if not self.args.checkpoints: - ckpt_list = sorted(glob.glob(os.path.join(self.checkpoints_dir, "*.pth"))) - num = len(ckpt_list) - if num > 1: - if os.path.getsize(ckpt_list[-1]) == os.path.getsize(ckpt_list[-2]): - self.args.checkpoints = ckpt_list[-1] - else: - self.args.checkpoints = ckpt_list[-2] - - if self.args.checkpoints: - print( - "===========Load checkpoints {0}===========".format( - self.args.checkpoints - ) - ) - 
self.logger.info("Load checkpoints {0}".format(self.args.checkpoints)) - ckpt = torch.load(self.args.checkpoints) - # load network weights - try: - self.net.load_state_dict(ckpt["network"]) - except Exception: - print("Can not find network weights") - # load optimizer params - try: - self.optimizer.load_state_dict(ckpt["optimizer"]) - self.scheduler.load_state_dict(ckpt["scheduler"]) - except Exception: - print("Can not find some optimizers params, just ignore") - start_epoch = ckpt["epoch"] + 1 - self.train_steps = ckpt["train_step"] + 1 - self.eval_steps = ckpt["eval_step"] + 1 - elif self.args.pretrained: - ckpt = torch.load(self.args.pretrained) - print( - "===========Load network weights {0}===========".format( - self.args.checkpoints - ) - ) - self.logger.info("Load network weights {0}".format(self.args.checkpoints)) - # load codec weights - try: - self.net.load_state_dict(ckpt["network"]) - except Exception: - print("Can not find network weights") - start_epoch = 0 - else: - print("===========Training from scratch===========") - self.logger.info("Training from scratch") - start_epoch = 0 - return start_epoch - - def save_ckpt(self, epoch: int): - checkpoint = { - "network": self.net.state_dict(), - "epoch": epoch, - "train_step": self.train_steps, - "eval_step": self.eval_steps, - "optimizer": self.optimizer.state_dict(), - "scheduler": self.scheduler.state_dict(), - } - - torch.save(checkpoint, "%s/model_%.4d.pth" % (self.checkpoints_dir, epoch + 1)) - self.logger.info("Save model..") - print( - "======================Saving model {0}======================".format( - str(epoch) - ) - ) - - -if __name__ == "__main__": - trainer = Trainer() - trainer.train() diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/yuv10bdata.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/yuv10bdata.py deleted file mode 100644 index 8e2bd317e66f2717bfef2e7955f4c79aad1b42e1..0000000000000000000000000000000000000000 --- a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/yuv10bdata.py +++ /dev/null @@ -1,267 +0,0 @@ -""" -/* The copyright in this software is being made available under the BSD -* License, included below. This software may be subject to other third party -* and contributor rights, including patent rights, and no such rights are -* granted under this license. -* -* Copyright (c) 2010-2022, ITU/ISO/IEC -* All rights reserved. -* -* Redistribution and use in source and binary forms, with or without -* modification, are permitted provided that the following conditions are met: -* -* * Redistributions of source code must retain the above copyright notice, -* this list of conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, -* this list of conditions and the following disclaimer in the documentation -* and/or other materials provided with the distribution. -* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may -* be used to endorse or promote products derived from this software without -* specific prior written permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS -* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -* THE POSSIBILITY OF SUCH DAMAGE. -""" - -import os -import glob -from torch.utils.data import Dataset -import numpy as np -import Utils -import math - - -class YUV10bData(Dataset): - def __init__(self, args, name="YuvData", train=True): - super(YUV10bData, self).__init__() - self.args = args - self.split = "train" if train else "valid" - self.image_ext = args.ext - self.name = name - self.train = train - - data_range = [r.split("-") for r in args.data_range.split("/")] - if train: - data_range = data_range[0] - else: - data_range = data_range[1] - self.begin, self.end = list(map(lambda x: int(x), data_range)) # noqa: C417 - - data_range_tvd = [r.split("-") for r in args.data_range_tvd.split("/")] - if train: - data_range_tvd = data_range_tvd[0] - else: - data_range_tvd = data_range_tvd[1] - self.begin_tvd, self.end_tvd = list( # noqa: C417 - map(lambda x: int(x), data_range_tvd) - ) - - self._set_data() - self._get_image_list() - - if train: - n_patches = args.batch_size * args.test_every - n_images = len(self.images_yuv) - if n_images == 0: - self.repeat = 0 - else: - self.repeat = max(n_patches // n_images, 1) - print(f"repeating dataset {self.repeat} for one epoch") - else: - n_patches = args.batch_size * args.test_every // 25 - n_images = len(self.images_yuv) - if n_images == 0: - self.repeat = 0 - else: - self.repeat = 5 - print(f"repeating dataset {self.repeat} for one epoch") - - def _set_data(self): - self.dir_in = os.path.join(self.args.dir_data, "yuv") - self.dir_org = os.path.join(self.args.dir_data_ori) - - self.dir_in_tvd = os.path.join(self.args.dir_data_tvd, "yuv") - self.dir_org_tvd = os.path.join(self.args.dir_data_ori_tvd) - - def _scan_class(self, is_tvd, class_name): - QPs = ["22", "27", "32", "37", "42"] - - dir_in = self.dir_in_tvd if is_tvd else self.dir_in - list_temp = glob.glob( - os.path.join(dir_in, "*_" + class_name + "_*." + self.image_ext) - ) - file_rec_list = [] - for i in list_temp: - index = i.find("poc") - poc = int(i[index + 3 : index + 6]) - if poc % 3 == 0 and poc != 0: - file_rec_list.append(i) - - if is_tvd: - file_rec_list = sorted(file_rec_list) - else: - file_rec_list = sorted(file_rec_list, key=str.lower) - - dir_org = self.dir_org_tvd if is_tvd else self.dir_org - list_temp = glob.glob( - os.path.join(dir_org, class_name + "*/*." 
+ self.image_ext) - ) - # print(list_temp) - file_org_list = [] - for i in list_temp: - index = i.find("frame_") - poc = int(i[index + 6 : index + 9]) - if poc % 3 == 0 and poc != 0: - file_org_list.append(i) - - if is_tvd: - file_org_list = sorted(file_org_list) - else: - file_org_list = sorted(file_org_list, key=str.lower) - - frame_num = 62 - frame_num_sampled = math.ceil(frame_num / 3) - begin = self.begin_tvd if is_tvd else self.begin - end = self.end_tvd if is_tvd else self.end - - class_names_yuv = [] - class_names_yuv_org = [] - - for qp in QPs: - file_list = file_rec_list[ - (begin - 1) * frame_num_sampled * 5 : end * frame_num_sampled * 5 - ] - for filename in file_list: - idx = filename.find("qp") - if int(filename[idx + 2 : idx + 4]) == int(qp): - class_names_yuv.append(filename) - - file_list = file_org_list[ - (begin - 1) * frame_num_sampled : end * frame_num_sampled - ] - for filename in file_list: - class_names_yuv_org.append(filename) - - return class_names_yuv, class_names_yuv_org - - def _scan(self): - bvi_class_set = ["A"] - - names_yuv = [] - names_yuv_org = [] - for class_name in bvi_class_set: - class_names_yuv, class_names_yuv_org = self._scan_class(False, class_name) - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - class_names_yuv, class_names_yuv_org = self._scan_class(True, "A") - names_yuv = names_yuv + class_names_yuv - names_yuv_org = names_yuv_org + class_names_yuv_org - - print(len(names_yuv)) - print(len(names_yuv_org)) - - return names_yuv, names_yuv_org - - def _get_image_list(self): - self.images_yuv, self.images_yuv_org = self._scan() - - def __getitem__(self, idx): - patch_in, patch_org, filename = self._load_file_get_patch(idx) - pair_t = Utils.np2Tensor(patch_in, patch_org) - - return pair_t[0], pair_t[1], filename - - def __len__(self): - if self.train: - return len(self.images_yuv) * self.repeat - else: - return len(self.images_yuv) * self.repeat - - def _get_index(self, idx): - if self.train: - return idx % len(self.images_yuv) - else: - return idx % len(self.images_yuv) - - def _load_file_get_patch(self, idx): - idx = self._get_index(idx) - - # reconstruction - image_yuv_path = self.images_yuv[idx] - - slice_qp_idx = int(image_yuv_path.rfind("qp")) - slice_qp = int(image_yuv_path[slice_qp_idx + 2 : slice_qp_idx + 4]) - slice_qp_map = np.uint16( - np.ones( - ( - (self.args.patch_size + 2 * self.args.shave) // 2, - (self.args.patch_size + 2 * self.args.shave) // 2, - 1, - ) - ) - * slice_qp - ) - - # base_qp_idx = int(image_yuv_path.find("qp")) - # base_qp = int(image_yuv_path[base_qp_idx + 2 : base_qp_idx + 4]) - # base_qp_map = np.uint16( - # np.ones( - # ( - # (self.args.patch_size + 2 * self.args.shave) // 2, - # (self.args.patch_size + 2 * self.args.shave) // 2, - # 1, - # ) - # ) - # * base_qp - # ) - - # prediction - pred_str = "_prediction" - pos = image_yuv_path.find(".yuv") - image_yuv_pred_path = image_yuv_path[:pos] + pred_str + image_yuv_path[pos:] - image_yuv_pred_path = image_yuv_pred_path.replace("/yuv/", "/prediction_image/") - - # RPR - rpr_str = "_rpr" - pos = image_yuv_path.find(".yuv") - image_yuv_rpr_path = image_yuv_path[:pos] + rpr_str + image_yuv_path[pos:] - image_yuv_rpr_path = image_yuv_rpr_path.replace("/yuv/", "/rpr_image/") - - # original - image_yuv_org_path = self.images_yuv_org[idx] - org_splits = os.path.basename(os.path.dirname(image_yuv_org_path)).split("_") - wh_org = org_splits[1].split("x") - w, h = list(map(lambda x: int(x), wh_org)) # noqa: C417 - - 
patch_in, patch_rpr, patch_org = Utils.get_patch(
-            image_yuv_path,
-            image_yuv_pred_path,
-            image_yuv_rpr_path,
-            image_yuv_org_path,
-            w,
-            h,
-            self.args.patch_size,
-            self.args.shave,
-        )
-
-        patch_in = np.concatenate((patch_in, slice_qp_map), axis=2)
-
-        if self.train:
-            patch_in, patch_rpr, patch_org = Utils.augment(
-                patch_in, patch_rpr, patch_org
-            )
-
-        patch_lr = patch_in
-        patch_hr = np.concatenate((patch_rpr, patch_org), axis=2)
-
-        return patch_lr, patch_hr, image_yuv_path
diff --git a/training/training_scripts/NN_Super_Resolution/ReadMe.md b/training/training_scripts/NN_Super_Resolution/ReadMe.md
deleted file mode 100644
index 82018fc8a01bb8ca5d7bafe94905d3d2c6a236f0..0000000000000000000000000000000000000000
--- a/training/training_scripts/NN_Super_Resolution/ReadMe.md
+++ /dev/null
@@ -1,14 +0,0 @@
-## Overview
-
-Requirements:
-* One GPU with more than 25 GiB of memory.
-* Preferably more than 5 TB of disk storage.
-
-The workflow covered by these scripts proceeds in three stages:
-* Generate the raw data
-* Generate the compression data
-* Training and final model conversion
-
-
-For the best viewing experience, please open these files with [typora](https://typora.io/).
-The guidance can also be read as plain text, but the embedded figures will not be displayed; in that case, open the corresponding image from the figure folder directly.
\ No newline at end of file
diff --git a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_chroma-IB.png b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_chroma-IB.png
deleted file mode 100644
index c14898bdf23ac1d79354538c9579b1f07161cca7..0000000000000000000000000000000000000000
Binary files a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_chroma-IB.png and /dev/null differ
diff --git a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-B.png b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-B.png
deleted file mode 100644
index e97a6476b4fea3b3bb525287563faa660d9a6a30..0000000000000000000000000000000000000000
Binary files a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-B.png and /dev/null differ
diff --git a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-I.png b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-I.png
deleted file mode 100644
index c6b552699415b12db2174320524a375ee6b843f3..0000000000000000000000000000000000000000
Binary files a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-I.png and /dev/null differ
diff --git a/training/training_scripts/NN_Super_Resolution/figure/overview.png b/training/training_scripts/NN_Super_Resolution/figure/overview.png
deleted file mode 100644
index 88af1b996f999c46506d2ffe0d700b0ca158f259..0000000000000000000000000000000000000000
Binary files a/training/training_scripts/NN_Super_Resolution/figure/overview.png and /dev/null differ
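
Finally, a note on the geometric augmentation shared by every Utils.py in this patch: augment draws a single uniform sample per patch triple, so at most one transform fires, with horizontal flip, vertical flip, and transpose (labelled rot90 in the code) each applied with probability 0.2 and the patches left untouched otherwise. Using one shared draw for all three arrays keeps reconstruction, RPR, and original geometrically aligned. A compact equivalent sketch:

import random
import numpy as np

def augment_triple(*imgs: np.ndarray):
    # One shared draw keeps every HxWxC patch in the triple consistent.
    x = random.random()
    if x < 0.2:
        return [img[:, ::-1, :] for img in imgs]         # horizontal flip
    if x < 0.4:
        return [img[::-1, :, :] for img in imgs]         # vertical flip
    if x < 0.6:
        return [img.transpose(1, 0, 2) for img in imgs]  # transpose ("rot90")
    return list(imgs)                                     # 40%: unchanged
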