diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h
index 0bf33d4e6753a3b9fbe5be7dc9c3e957220b55d2..f3802920e2f48d11d0c34d9a97d79feb1cde3502 100644
--- a/source/Lib/CommonLib/TypeDef.h
+++ b/source/Lib/CommonLib/TypeDef.h
@@ -122,7 +122,7 @@ using TypeSadl = float;
 #define JVET_AC0177_FLIP_INPUT 1 // JVET-AC0177: flip input and output of NN filter model
 #endif
-#define JVET_AC0196_NNSR 1 // JVET-AC0916: EE1-2.2: GOP Level Adaptive Resampling with CNN-based Super Resolution
+#define JVET_AC0196_NNSR 1 // JVET-AC0196: EE1-2.2: GOP Level Adaptive Resampling with CNN-based Super Resolution
 #if JVET_AC0196_NNSR
 #define ADAPTIVE_RPR 1 // JVET-AC0196: GOP Level Adaptive Resampling
 #endif
diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/1_ReadMe.md b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/1_ReadMe.md
new file mode 100644
index 0000000000000000000000000000000000000000..58d18ceeb97dbfa408aed7fe4ae2d3f5d21d3a6a
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/1_ReadMe.md
@@ -0,0 +1,61 @@
+## [TVD](https://multimedia.tencent.com/resources/tvd)
+All the sequences with 10-bit depth (74 sequences) are used. (Partial dataset)
+Generation:
+1. Run tvd_to_yuv.py (Convert mp4 files to YUV files)
+2. Run tvd_to_yuv_frame.py (Split YUV videos into frames)
+3. Obtain the raw data in **TVD_Video_YUV** to be used for the subsequent compression and as the ground truth
+
+## [BVI-DVC](https://vilab.blogs.bristol.ac.uk/2020/02/2375/)
+All the 3840x2176 sequences (200 sequences) are used. (Partial dataset)
+Generation:
+1. Run bvi_dvc_to_yuv.py (Convert mp4 files to YUV files)
+2. Run bvi_dvc_to_yuv_frame.py (Split YUV videos into frames)
+3. Obtain the raw data in **bvi_dvc_YUV** to be used for the subsequent compression and as the ground truth (a sketch for inspecting the generated frames is shown below)
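+
+The split frames are headerless yuv420p10le data: planar Y, U, and V, two bytes per sample, little-endian. As a quick sanity check, the following minimal sketch (illustrative only; the helper name and the example frame path are placeholders, not part of the provided scripts) loads one frame into numpy:
+```
+import numpy as np
+
+def read_yuv420p10(path, width, height):
+    # one frame: full-resolution Y plane, then U and V at quarter resolution
+    y_size = width * height
+    c_size = (width // 2) * (height // 2)
+    data = np.fromfile(path, dtype='<u2', count=y_size + 2 * c_size)
+    y = data[:y_size].reshape(height, width)
+    u = data[y_size:y_size + c_size].reshape(height // 2, width // 2)
+    v = data[y_size + c_size:].reshape(height // 2, width // 2)
+    return y, u, v
+
+y, u, v = read_yuv420p10('TVD_Video_YUV/Bamboo_3840x2160_25fps_10bit_420.yuv/frame_000.yuv', 3840, 2160)
+assert int(y.max()) <= 1023  # 10-bit samples lie in [0, 1023]
+```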
+
+## The file structure of the raw dataset
+When the raw dataset generation is finished, it should have the following file structure.
+
+### TVD
+```
+ TVD_Video_YUV
+│ Bamboo_3840x2160_25fps_10bit_420.yuv
+│ │ frame_000.yuv
+│ │ frame_001.yuv
+│ │ ...
+│ BlackBird_3840x2160_25fps_10bit_420.yuv
+│ │ frame_000.yuv
+│ │ frame_001.yuv
+│ │ ...
+│ ...
+```
+### BVI-DVC
+```
+ bvi_dvc_YUV
+│ AAdvertisingMassagesBangkokVidevo_3840x2176_25fps_10bit_420.yuv
+│ │ frame_000.yuv
+│ │ frame_001.yuv
+│ │ ...
+│ AAmericanFootballS2Harmonics_3840x2176_60fps_10bit_420.yuv
+│ │ frame_000.yuv
+│ │ frame_001.yuv
+│ │ ...
+│ ...
+```
\ No newline at end of file
diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv.py b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5944bc04c8a9b5a3959f8316ea9af5c5dcecd78
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv.py
@@ -0,0 +1,54 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2022, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import os
+
+mp4_dir = './Videos'
+yuv_dir = './bvi_dvc_YUV'
+if not os.path.exists(yuv_dir):
+    os.makedirs(yuv_dir)
+
+for mp4_file in sorted(os.listdir(mp4_dir)):
+    # the names of all BVI-DVC 3840x2176 source sequences start with 'A'
+    if mp4_file[0] != 'A':
+        continue
+
+    mp4_path = os.path.join(mp4_dir, mp4_file)
+    yuv_path = os.path.join(yuv_dir, mp4_file[:-4] + '.yuv')
+
+    # convert the mp4 into a single 10-bit 4:2:0 planar YUV file
+    cmd = f'ffmpeg -i {mp4_path} -pix_fmt yuv420p10le {yuv_path}'
+    os.system(cmd)
+
+
diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv_frame.py b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv_frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..be9d08b46f6ed1532b036cd3b48fb17a99708949
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/bvi_dvc_to_yuv_frame.py
@@ -0,0 +1,59 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2022, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import os
+
+yuv_dir = './bvi_dvc_YUV'
+yuv_dir_frame = './bvi_dvc_YUV_frame'
+if not os.path.exists(yuv_dir_frame):
+    os.makedirs(yuv_dir_frame)
+
+for yuv_file in sorted(os.listdir(yuv_dir)):
+    if yuv_file[0] != 'A':
+        continue
+
+    yuv_path = os.path.join(yuv_dir, yuv_file)
+    tar = os.path.join(yuv_dir_frame, yuv_file[:-4])
+    if not os.path.exists(tar):
+        os.makedirs(tar)
+
+    # e.g. 'AAdvertisingMassagesBangkokVidevo_3840x2176_25fps_10bit_420.yuv' -> '3840x2176'
+    size = yuv_file.split('_')[1]
+
+    # '-c copy -f segment' with a segment time shorter than one frame duration
+    # writes exactly one raw frame per output file
+    cmd = f'ffmpeg -s {size} -pix_fmt yuv420p10le -i {yuv_path} -c copy -f segment -segment_time 0.001 {tar}/frame_%03d.yuv'
+    os.system(cmd)
+
+
+""" + +import os + +mp4_dir = './Video' +yuv_dir = './TVD_Video_YUV' +if not os.path.exists(yuv_dir): + os.makedirs(yuv_dir) + +for mp4_file in sorted(os.listdir(mp4_dir)): + if '_10bit_' not in mp4_file: + continue + + mp4_path = os.path.join(mp4_dir, mp4_file) + yuv_path = os.path.join(yuv_dir, mp4_file[:-4]+'.yuv') + if not os.path.exists(tar): + os.makedirs(tar) + + size = mp4_file.split('_')[1] + + cmd = f'ffmpeg -i {mp4_path} -pix_fmt yuv420p10le {yuv_path}' + os.system(cmd) + + diff --git a/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv_frame.py b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..24203171020877b186c91caff296b47e9b04cc1a --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/1_generate_raw_data/tvd_to_yuv_frame.py @@ -0,0 +1,56 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import os + +yuv_dir = './TVD_Video_YUV' +yuv_dir_frame = './TVD_Video_YUV_frame' +if not os.path.exists(yuv_dir_frame): + os.makedirs(yuv_dir_frame) + +for yuv_file in sorted(os.listdir(yuv_dir)): + if '_10bit_' not in yuv_file: + continue + + yuv_path = os.path.join(yuv_dir, yuv_file) + tar = os.path.join(yuv_dir_frame, yuv_file) + if not os.path.exists(tar): + os.makedirs(tar) + + size = yuv_file.split('_')[1] + + cmd = f'ffmpeg -s {size} -pix_fmt yuv420p10le -i {yuv_path} -c copy -f segment -segment_time 0.001 {tar}/frame_%03d.yuv' + os.system(cmd) + + diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/2_ReadMe.md b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/2_ReadMe.md new file mode 100644 index 0000000000000000000000000000000000000000..7b4b724ece03812609570b62d35a5f6124210078 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/2_ReadMe.md @@ -0,0 +1,114 @@ +## Generate the dataset +For the convenience, the detailed codec information, including resolution, encode level and so on, is provided by the python files bvi_dvc_codec_info.py and tvd_codec_info.py, which can make it easier to build your own script on the cluster. + +The corresponding raw dataset in sequence level YUV format to be compressed is generated by the scripts in '../1_generate_raw_data/'. +Finally, the compression dataset in frame level YUV format is obtained at the decoder based on VTM-11.0_nnvc-2.0 (https://vcgit.hhi.fraunhofer.de/jvet-ahg-nnvc/VVCSoftware_VTM/-/tree/VTM-11.0_nnvc-2.0). +Two patches (generate_SR_datasets_for_I_slices.patch and generate_SR_datasets_for_B_slices.patch) are applied to VTM-11.0_nnvc-2.0 to generate I-frame datasets and B-frame datasets, respectively. + +Specifically, the compression dataset generated from decoder includes reconstruction images, prediction images and RPR images. + +### Generate the I data +TVD and BVI-DVC are both compressed under AI configuration and then all I slices are selected to build this dataset. +VTM-11.0_nnvc-2.0 with generate_SR_datasets_for_I_slices.patch is used to generate I data, and TemporalSubsampleRatio is set to 1 as follows. +``` +--TemporalSubsampleRatio=1 +``` + +The macro configurations are provided as follows. Note that this macro DATA_GEN_DEC should be turned off on the encoder, and turned on on the decoder. +``` +#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling +#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder +#define DATA_PREDICTION 1 // Prediction data +``` + +### Generate the B data +TVD and BVI-DVC are both compressed under RA configuration and then all B slices are selected to build this dataset. + +The macro configurations are provided as follows. Note that this macro DATA_GEN_DEC should be turned off on the encoder, and turned on on the decoder. +``` +#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling +#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder +#define DATA_PREDICTION 1 // Prediction data +``` + +## The file structure of the compression dataset +When the compression dataset generation is finished, it should be adjusted into the following file structure. + +### AI +``` + AI_TVD + └───yuv + │ │ Bin_T1AI_A_S01_R32_qp32_s0_f65_t1_poc000.yuv + │ │ Bin_T1AI_A_S02_R27_qp27_s0_f65_t1_poc064.yuv + │ │ Bin_T1AI_A_S03_R32_qp32_s0_f65_t1_poc015.yuv + │ │ ... 
+ │ + └───prediction_image + │ │ Bin_T1AI_A_S01_R32_qp32_s0_f65_t1_poc000_prediction.yuv + │ │ Bin_T1AI_A_S02_R27_qp27_s0_f65_t1_poc064_prediction.yuv + │ │ Bin_T1AI_A_S03_R32_qp32_s0_f65_t1_poc015_prediction.yuv + │ │ ... + └───rpr_image + │ Bin_T1AI_A_S01_R32_qp32_s0_f65_t1_poc000_rpr.yuv + │ Bin_T1AI_A_S02_R27_qp27_s0_f65_t1_poc064_rpr.yuv + │ Bin_T1AI_A_S03_R32_qp32_s0_f65_t1_poc015_rpr.yuv + │ ... + + AI_BVI_DVC + └───yuv + │ │ Bin_T1AI_A_S001_R32_qp32_s0_f64_t1_poc000.yuv + │ │ Bin_T1AI_A_S002_R27_qp27_s0_f64_t1_poc063.yuv + │ │ Bin_T1AI_A_S003_R32_qp32_s0_f64_t1_poc015.yuv + │ │ ... + │ + └───prediction_image + │ │ Bin_T1AI_A_S001_R32_qp32_s0_f64_t1_poc000_prediction.yuv + │ │ Bin_T1AI_A_S002_R27_qp27_s0_f64_t1_poc063_prediction.yuv + │ │ Bin_T1AI_A_S003_R32_qp32_s0_f64_t1_poc015_prediction.yuv + │ │ ... + │ + └───rpr_image + │ Bin_T1AI_A_S001_R32_qp32_s0_f64_t1_poc000_rpr.yuv + │ Bin_T1AI_A_S002_R27_qp27_s0_f64_t1_poc063_rpr.yuv + │ Bin_T1AI_A_S003_R32_qp32_s0_f64_t1_poc015_rpr.yuv + │ ... +``` +### RA +``` + RA_TVD + └───yuv + │ │ Bin_T2RA_A_S01_R27_qp27_s0_f65_t1_poc063_qp36.yuv + │ │ Bin_T2RA_A_S02_R32_qp32_s0_f65_t1_poc031_qp41.yuv + │ │ Bin_T2RA_A_S03_R42_qp42_s0_f65_t1_poc062_qp50.yuv + │ │ ... + │ + └───prediction_image + │ │ Bin_T2RA_A_S01_R27_qp27_s0_f65_t1_poc063_qp36_prediction.yuv + │ │ Bin_T2RA_A_S02_R32_qp32_s0_f65_t1_poc031_qp41_prediction.yuv + │ │ Bin_T2RA_A_S03_R42_qp42_s0_f65_t1_poc062_qp50_prediction.yuv + │ │ ... + └───rpr_image + │ Bin_T2RA_A_S01_R27_qp27_s0_f65_t1_poc063_qp36_rpr.yuv + │ Bin_T2RA_A_S02_R32_qp32_s0_f65_t1_poc031_qp41_rpr.yuv + │ Bin_T2RA_A_S03_R42_qp42_s0_f65_t1_poc062_qp50_rpr.yuv + │ ... + + RA_BVI_DVC + └───yuv + │ │ Bin_T2RA_A_S001_R27_qp27_s0_f64_t1_poc063_qp36.yuv + │ │ Bin_T2RA_A_S002_R32_qp32_s0_f64_t1_poc031_qp41.yuv + │ │ Bin_T2RA_A_S003_R42_qp42_s0_f65_t1_poc062_qp50.yuv + │ │ ... + │ + └───prediction_image + │ │ Bin_T2RA_A_S001_R27_qp27_s0_f64_t1_poc063_qp36_prediction.yuv + │ │ Bin_T2RA_A_S002_R32_qp32_s0_f64_t1_poc031_qp41_prediction.yuv + │ │ Bin_T2RA_A_S003_R42_qp42_s0_f65_t1_poc062_qp50_prediction.yuv + │ │ ... + │ + └───rpr_image + │ │ Bin_T2RA_A_S001_R27_qp27_s0_f64_t1_poc063_qp36_rpr.yuv + │ │ Bin_T2RA_A_S002_R32_qp32_s0_f64_t1_poc031_qp41_rpr.yuv + │ │ Bin_T2RA_A_S003_R42_qp42_s0_f65_t1_poc062_qp50_rpr.yuv + │ ... 
+``` diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_B_slices.patch b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_B_slices.patch new file mode 100644 index 0000000000000000000000000000000000000000..731bafebbf61a1d749fd8489ff380ed7766883c0 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_B_slices.patch @@ -0,0 +1,561 @@ +From 73a78fab59f91a4f30f24ab520baf70edbe855fc Mon Sep 17 00:00:00 2001 +From: renjiechang <renjiechang@tencent.com> +Date: Tue, 14 Feb 2023 15:21:58 +0800 +Subject: [PATCH] Generate SR datasets for B slices + +--- + source/App/DecoderApp/DecApp.cpp | 26 +++++++- + source/App/DecoderApp/DecAppCfg.cpp | 4 ++ + source/App/EncoderApp/EncAppCfg.cpp | 8 +++ + source/Lib/CommonLib/CodingStructure.cpp | 48 ++++++++++++++ + source/Lib/CommonLib/CodingStructure.h | 17 +++++ + source/Lib/CommonLib/Picture.cpp | 19 ++++++ + source/Lib/CommonLib/Picture.h | 8 +++ + source/Lib/CommonLib/Rom.cpp | 5 ++ + source/Lib/CommonLib/Rom.h | 4 ++ + source/Lib/CommonLib/TypeDef.h | 3 + + source/Lib/DecoderLib/DecCu.cpp | 14 +++- + source/Lib/EncoderLib/EncLib.cpp | 4 ++ + source/Lib/Utilities/VideoIOYuv.cpp | 82 +++++++++++++++++++++++- + source/Lib/Utilities/VideoIOYuv.h | 6 +- + 14 files changed, 243 insertions(+), 5 deletions(-) + +diff --git a/source/App/DecoderApp/DecApp.cpp b/source/App/DecoderApp/DecApp.cpp +index 85f63bb0..1842d346 100644 +--- a/source/App/DecoderApp/DecApp.cpp ++++ b/source/App/DecoderApp/DecApp.cpp +@@ -88,6 +88,17 @@ uint32_t DecApp::decode() + EXIT( "Failed to open bitstream file " << m_bitstreamFileName.c_str() << " for reading" ) ; + } + ++#if DATA_GEN_DEC ++ strcpy(global_str_name, m_bitstreamFileName.c_str()); ++ { ++ size_t len = strlen(global_str_name); ++ for (size_t i = len - 1; i > len - 5; i--) ++ { ++ global_str_name[i] = 0; ++ } ++ } ++#endif ++ + InputByteStream bytestream(bitstreamFile); + + if (!m_outputDecodedSEIMessagesFilename.empty() && m_outputDecodedSEIMessagesFilename!="-") +@@ -678,7 +689,14 @@ void DecApp::xWriteOutput( PicList* pcListPic, uint32_t tId ) + ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); + if( m_upscaledOutput ) + { +- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); ++ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ++#if DATA_GEN_DEC ++ , pcPic ++#endif ++ ); ++#if DATA_PREDICTION ++ pcPic->m_bufs[PIC_TRUE_PREDICTION].destroy(); ++#endif + } + else + { +@@ -825,7 +843,11 @@ void DecApp::xFlushOutput( PicList* pcListPic, const int layerId ) + ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); + if( m_upscaledOutput ) + { +- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); ++ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ++#if DATA_GEN_DEC ++ , pcPic 
++#endif ++ ); + } + else + { +diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp +index d96c2049..ad912462 100644 +--- a/source/App/DecoderApp/DecAppCfg.cpp ++++ b/source/App/DecoderApp/DecAppCfg.cpp +@@ -124,7 +124,11 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) + #endif + ("MCTSCheck", m_mctsCheck, false, "If enabled, the decoder checks for violations of mc_exact_sample_value_match_flag in Temporal MCTS ") + ("targetSubPicIdx", m_targetSubPicIdx, 0, "Specify which subpicture shall be written to output, using subpic index, 0: disabled, subpicIdx=m_targetSubPicIdx-1 \n" ) ++#if DATA_GEN_DEC ++ ( "UpscaledOutput", m_upscaledOutput, 2, "Upscaled output for RPR" ) ++#else + ( "UpscaledOutput", m_upscaledOutput, 0, "Upscaled output for RPR" ) ++#endif + ; + + po::setDefaults(opts); +diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp +index b38001eb..c43cd0e1 100644 +--- a/source/App/EncoderApp/EncAppCfg.cpp ++++ b/source/App/EncoderApp/EncAppCfg.cpp +@@ -1411,11 +1411,19 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) + ( "CCALF", m_ccalf, true, "Cross-component Adaptive Loop Filter" ) + ( "CCALFQpTh", m_ccalfQpThreshold, 37, "QP threshold above which encoder reduces CCALF usage") + ( "RPR", m_rprEnabledFlag, true, "Reference Sample Resolution" ) ++#if DATA_GEN_ENC ++ ( "ScalingRatioHor", m_scalingRatioHor, 2.0, "Scaling ratio in hor direction" ) ++ ( "ScalingRatioVer", m_scalingRatioVer, 2.0, "Scaling ratio in ver direction" ) ++ ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) ++ ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) ++ ( "UpscaledOutput", m_upscaledOutput, 2, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) ++#else + ( "ScalingRatioHor", m_scalingRatioHor, 1.0, "Scaling ratio in hor direction" ) + ( "ScalingRatioVer", m_scalingRatioVer, 1.0, "Scaling ratio in ver direction" ) + ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) + ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) + ( "UpscaledOutput", m_upscaledOutput, 0, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) ++#endif + ( "MaxLayers", m_maxLayers, 1, "Max number of layers" ) + #if JVET_S0163_ON_TARGETOLS_SUBLAYERS + ( "EnableOperatingPointInformation", m_OPIEnabled, false, "Enables writing of Operating Point Information (OPI)" ) +diff --git a/source/Lib/CommonLib/CodingStructure.cpp b/source/Lib/CommonLib/CodingStructure.cpp +index b655d445..15e542ba 100644 +--- a/source/Lib/CommonLib/CodingStructure.cpp ++++ b/source/Lib/CommonLib/CodingStructure.cpp +@@ -107,6 +107,9 @@ void CodingStructure::destroy() + parent = nullptr; + + m_pred.destroy(); ++#if DATA_PREDICTION ++ m_predTrue.destroy(); ++#endif + m_resi.destroy(); + m_reco.destroy(); + m_orgr.destroy(); +@@ -895,6 +898,9 @@ void CodingStructure::create(const ChromaFormat &_chromaFormat, const Area& _are + + m_reco.create( area ); + m_pred.create( area ); ++#if DATA_PREDICTION ++ m_predTrue.create( area ); ++#endif + m_resi.create( area ); + m_orgr.create( area ); + } +@@ -910,6 +916,9 @@ void CodingStructure::create(const UnitArea& _unit, const bool isTopLayer, const + + m_reco.create( area ); + m_pred.create( area ); ++#if DATA_PREDICTION ++ m_predTrue.create( area ); 
++#endif + m_resi.create( area ); + m_orgr.create( area ); + } +@@ -1082,6 +1091,16 @@ void CodingStructure::rebindPicBufs() + { + m_pred.destroy(); + } ++#if DATA_PREDICTION ++ if (!picture->M_BUFS(0, PIC_TRUE_PREDICTION).bufs.empty()) ++ { ++ m_predTrue.createFromBuf(picture->M_BUFS(0, PIC_TRUE_PREDICTION)); ++ } ++ else ++ { ++ m_predTrue.destroy(); ++ } ++#endif + if (!picture->M_BUFS(0, PIC_RESIDUAL).bufs.empty()) + { + m_resi.createFromBuf(picture->M_BUFS(0, PIC_RESIDUAL)); +@@ -1240,12 +1259,20 @@ void CodingStructure::useSubStructure( const CodingStructure& subStruct, const C + if( parent ) + { + // copy data to picture ++#if DATA_PREDICTION ++ getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++ getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++#endif + if( cpyPred ) getPredBuf ( clippedArea ).copyFrom( subPredBuf ); + if( cpyResi ) getResiBuf ( clippedArea ).copyFrom( subResiBuf ); + if( cpyReco ) getRecoBuf ( clippedArea ).copyFrom( subRecoBuf ); + if( cpyOrgResi ) getOrgResiBuf( clippedArea ).copyFrom( subStruct.getOrgResiBuf( clippedArea ) ); + } + ++#if DATA_PREDICTION ++ picture->getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++ picture->getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++#endif + if( cpyPred ) picture->getPredBuf( clippedArea ).copyFrom( subPredBuf ); + if( cpyResi ) picture->getResiBuf( clippedArea ).copyFrom( subResiBuf ); + if( cpyReco ) picture->getRecoBuf( clippedArea ).copyFrom( subRecoBuf ); +@@ -1562,6 +1589,13 @@ const CPelBuf CodingStructure::getPredBuf(const CompArea &blk) const { r + PelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } + const CPelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } + ++#if DATA_PREDICTION ++ PelBuf CodingStructure::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } ++const CPelBuf CodingStructure::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } ++ PelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } ++const CPelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit)const { return getBuf(unit, PIC_TRUE_PREDICTION); } ++#endif ++ + PelBuf CodingStructure::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } + const CPelBuf CodingStructure::getResiBuf(const CompArea &blk) const { return getBuf(blk, PIC_RESIDUAL); } + PelUnitBuf CodingStructure::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } +@@ -1603,6 +1637,13 @@ PelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &type ) + + PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? &m_orgr : nullptr ) ) ); + ++#if DATA_PREDICTION ++ if (type == PIC_TRUE_PREDICTION) ++ { ++ buf = &m_predTrue; ++ } ++#endif ++ + CHECK( !buf, "Unknown buffer requested" ); + + CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); +@@ -1637,6 +1678,13 @@ const CPelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &t + + const PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? 
&m_orgr : nullptr ) ) ); + ++#if DATA_PREDICTION ++ if (type == PIC_TRUE_PREDICTION) ++ { ++ buf = &m_predTrue; ++ } ++#endif ++ + CHECK( !buf, "Unknown buffer requested" ); + + CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); +diff --git a/source/Lib/CommonLib/CodingStructure.h b/source/Lib/CommonLib/CodingStructure.h +index b5ae7ac6..cdd3fbf1 100644 +--- a/source/Lib/CommonLib/CodingStructure.h ++++ b/source/Lib/CommonLib/CodingStructure.h +@@ -62,6 +62,9 @@ enum PictureType + PIC_ORIGINAL_INPUT, + PIC_TRUE_ORIGINAL_INPUT, + PIC_FILTERED_ORIGINAL_INPUT, ++#if DATA_PREDICTION ++ PIC_TRUE_PREDICTION, ++#endif + NUM_PIC_TYPES + }; + extern XUCache g_globalUnitCache; +@@ -228,6 +231,9 @@ private: + std::vector<SAOBlkParam> m_sao; + + PelStorage m_pred; ++#if DATA_PREDICTION ++ PelStorage m_predTrue; ++#endif + PelStorage m_resi; + PelStorage m_reco; + PelStorage m_orgr; +@@ -268,6 +274,17 @@ public: + PelUnitBuf getPredBuf(const UnitArea &unit); + const CPelUnitBuf getPredBuf(const UnitArea &unit) const; + ++#if DATA_PREDICTION ++ PelBuf getTruePredBuf(const CompArea &blk); ++ const CPelBuf getTruePredBuf(const CompArea &blk) const; ++ PelUnitBuf getTruePredBuf(const UnitArea &unit); ++ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; ++#endif ++ ++#if DATA_PREDICTION ++ PelUnitBuf getTruePredBuf() { return m_predTrue; } ++#endif ++ + PelBuf getResiBuf(const CompArea &blk); + const CPelBuf getResiBuf(const CompArea &blk) const; + PelUnitBuf getResiBuf(const UnitArea &unit); +diff --git a/source/Lib/CommonLib/Picture.cpp b/source/Lib/CommonLib/Picture.cpp +index a7205bad..3cdc698a 100644 +--- a/source/Lib/CommonLib/Picture.cpp ++++ b/source/Lib/CommonLib/Picture.cpp +@@ -277,6 +277,12 @@ void Picture::createTempBuffers( const unsigned _maxCUSize ) + { + M_BUFS( jId, PIC_PREDICTION ).create( chromaFormat, a, _maxCUSize ); + M_BUFS( jId, PIC_RESIDUAL ).create( chromaFormat, a, _maxCUSize ); ++ ++#if DATA_PREDICTION ++ const Area a_old(Position{ 0, 0 }, lumaSize()); ++ M_BUFS(jId, PIC_TRUE_PREDICTION).create(chromaFormat, a_old, _maxCUSize); ++#endif ++ + #if ENABLE_SPLIT_PARALLELISM + if (jId > 0) + { +@@ -305,6 +311,11 @@ void Picture::destroyTempBuffers() + { + M_BUFS(jId, t).destroy(); + } ++#if DATA_PREDICTION ++#if !DATA_GEN_DEC ++ if (t == PIC_TRUE_PREDICTION) M_BUFS(jId, t).destroy(); ++#endif ++#endif + #if ENABLE_SPLIT_PARALLELISM + if (t == PIC_RECONSTRUCTION && jId > 0) + { +@@ -344,6 +355,14 @@ const CPelBuf Picture::getPredBuf(const CompArea &blk) const { return getBu + PelUnitBuf Picture::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } + const CPelUnitBuf Picture::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } + ++#if DATA_PREDICTION ++ PelBuf Picture::getTruePredBuf(const ComponentID compID, bool wrap) { return getBuf(compID, PIC_TRUE_PREDICTION); } ++ PelBuf Picture::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } ++const CPelBuf Picture::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } ++ PelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } ++const CPelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_TRUE_PREDICTION); } ++#endif ++ + PelBuf Picture::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } + const CPelBuf Picture::getResiBuf(const CompArea &blk) const { return getBuf(blk, 
PIC_RESIDUAL); } + PelUnitBuf Picture::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } +diff --git a/source/Lib/CommonLib/Picture.h b/source/Lib/CommonLib/Picture.h +index 66073bf6..b48a6099 100644 +--- a/source/Lib/CommonLib/Picture.h ++++ b/source/Lib/CommonLib/Picture.h +@@ -128,6 +128,14 @@ struct Picture : public UnitArea + PelUnitBuf getPredBuf(const UnitArea &unit); + const CPelUnitBuf getPredBuf(const UnitArea &unit) const; + ++#if DATA_PREDICTION ++ PelBuf getTruePredBuf(const ComponentID compID, bool wrap = false); ++ PelBuf getTruePredBuf(const CompArea &blk); ++ const CPelBuf getTruePredBuf(const CompArea &blk) const; ++ PelUnitBuf getTruePredBuf(const UnitArea &unit); ++ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; ++#endif ++ + PelBuf getResiBuf(const CompArea &blk); + const CPelBuf getResiBuf(const CompArea &blk) const; + PelUnitBuf getResiBuf(const UnitArea &unit); +diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp +index dc1c29ae..28ad2c4f 100644 +--- a/source/Lib/CommonLib/Rom.cpp ++++ b/source/Lib/CommonLib/Rom.cpp +@@ -53,6 +53,11 @@ CDTrace *g_trace_ctx = NULL; + #endif + bool g_mctsDecCheckEnabled = false; + ++#if DATA_GEN_DEC ++unsigned int global_cnt = 0; ++char global_str_name[200]; ++#endif ++ + //! \ingroup CommonLib + //! \{ + +diff --git a/source/Lib/CommonLib/Rom.h b/source/Lib/CommonLib/Rom.h +index e7352e3c..4d1b38a1 100644 +--- a/source/Lib/CommonLib/Rom.h ++++ b/source/Lib/CommonLib/Rom.h +@@ -44,6 +44,10 @@ + #include <stdio.h> + #include <iostream> + ++#if DATA_GEN_DEC ++extern unsigned int global_cnt; ++extern char global_str_name[200]; ++#endif + + //! \ingroup CommonLib + //! \{ +diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h +index 8af59c7f..2874459a 100644 +--- a/source/Lib/CommonLib/TypeDef.h ++++ b/source/Lib/CommonLib/TypeDef.h +@@ -50,6 +50,9 @@ + #include <assert.h> + #include <cassert> + ++#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling ++#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder ++#define DATA_PREDICTION 1 // Prediction data + // clang-format off + + //########### place macros to be removed in next cycle below this line ############### +diff --git a/source/Lib/DecoderLib/DecCu.cpp b/source/Lib/DecoderLib/DecCu.cpp +index eeec3474..844c7aac 100644 +--- a/source/Lib/DecoderLib/DecCu.cpp ++++ b/source/Lib/DecoderLib/DecCu.cpp +@@ -182,6 +182,9 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) + const ChannelType chType = toChannelType( compID ); + + PelBuf piPred = cs.getPredBuf( area ); ++#if DATA_PREDICTION ++ PelBuf piPredTrue = cs.getTruePredBuf(area); ++#endif + + const PredictionUnit &pu = *tu.cs->getPU( area.pos(), chType ); + const uint32_t uiChFinalMode = PU::getFinalIntraMode( pu, chType ); +@@ -311,10 +314,15 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) + } + #if KEEP_PRED_AND_RESI_SIGNALS + pReco.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); ++#else ++#if DATA_PREDICTION ++ piPredTrue.copyFrom(piPred); ++ pReco.reconstruct(piPred, piResi, tu.cu->cs->slice->clpRng(compID)); + #else + piPred.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); + #endif +-#if !KEEP_PRED_AND_RESI_SIGNALS ++#endif ++#if !KEEP_PRED_AND_RESI_SIGNALS && !DATA_PREDICTION + pReco.copyFrom( piPred ); + #endif + if (slice.getLmcsEnabledFlag() && (m_pcReshape->getCTUFlag() || slice.isIntra()) && 
compID == COMPONENT_Y) +@@ -684,6 +692,10 @@ void DecCu::xReconInter(CodingUnit &cu) + DTRACE ( g_trace_ctx, D_TMP, "pred " ); + DTRACE_CRC( g_trace_ctx, D_TMP, *cu.cs, cu.cs->getPredBuf( cu ), &cu.Y() ); + ++#if DATA_PREDICTION ++ cu.cs->getTruePredBuf(cu).copyFrom(cu.cs->getPredBuf(cu)); ++#endif ++ + // inter recon + xDecodeInterTexture(cu); + +diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp +index bb5e51f6..f3287686 100644 +--- a/source/Lib/EncoderLib/EncLib.cpp ++++ b/source/Lib/EncoderLib/EncLib.cpp +@@ -657,6 +657,9 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu + } + #endif + ++#if DATA_GEN_ENC ++ ppsID = ENC_PPS_ID_RPR; ++#else + if( m_resChangeInClvsEnabled && m_intraPeriod == -1 ) + { + const int poc = m_iPOCLast + ( m_compositeRefEnabled ? 2 : 1 ); +@@ -675,6 +678,7 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu + { + ppsID = m_vps->getGeneralLayerIdx( m_layerId ); + } ++#endif + + xGetNewPicBuffer( rcListPicYuvRecOut, pcPicCurr, ppsID ); + +diff --git a/source/Lib/Utilities/VideoIOYuv.cpp b/source/Lib/Utilities/VideoIOYuv.cpp +index 8a30ccc5..7a271982 100644 +--- a/source/Lib/Utilities/VideoIOYuv.cpp ++++ b/source/Lib/Utilities/VideoIOYuv.cpp +@@ -1252,7 +1252,11 @@ void VideoIOYuv::ColourSpaceConvert(const CPelUnitBuf &src, PelUnitBuf &dest, co + } + } + +-bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 ) ++bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 ++#if DATA_GEN_DEC ++ , Picture* pcPic ++#endif ++) + { + ChromaFormat chromaFormatIDC = sps.getChromaFormatIdc(); + bool ret = false; +@@ -1284,6 +1288,82 @@ bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPe + int xScale = ( ( refPicWidth << SCALE_RATIO_BITS ) + ( curPicWidth >> 1 ) ) / curPicWidth; + int yScale = ( ( refPicHeight << SCALE_RATIO_BITS ) + ( curPicHeight >> 1 ) ) / curPicHeight; + ++#if DATA_GEN_DEC ++ if (pcPic->cs->slice->getSliceType() == B_SLICE) ++ { ++ PelStorage upscaledRPR; ++ upscaledRPR.create( chromaFormatIDC, Area( Position(), Size( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples() ) ) ); ++ Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledRPR, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); ++ ++ char rec_out_name[200]; ++ strcpy(rec_out_name, global_str_name); ++ sprintf(rec_out_name + strlen(rec_out_name), "_poc%03d_qp%d.yuv", pcPic->cs->slice->getPOC(), pcPic->cs->slice->getSliceQp()); ++ FILE* fp_rec = fopen(rec_out_name, "wb"); ++ ++ char pre_out_name[200]; ++ strcpy(pre_out_name, global_str_name); ++ sprintf(pre_out_name + strlen(pre_out_name), "_poc%03d_qp%d_prediction.yuv", pcPic->cs->slice->getPOC(), pcPic->cs->slice->getSliceQp()); ++ FILE* fp_pre = fopen(pre_out_name, "wb"); ++ ++ char rpr_out_name[200]; ++ strcpy(rpr_out_name, global_str_name); ++ sprintf(rpr_out_name + strlen(rpr_out_name), "_poc%03d_qp%d_rpr.yuv", pcPic->cs->slice->getPOC(), pcPic->cs->slice->getSliceQp()); ++ FILE* fp_rpr = fopen(rpr_out_name, 
"wb"); ++ ++ int8_t temp[2]; ++ ++ uint32_t curLumaH = pps.getPicHeightInLumaSamples(); ++ uint32_t curLumaW = pps.getPicWidthInLumaSamples(); ++ ++ uint32_t oriLumaH = sps.getMaxPicHeightInLumaSamples(); ++ uint32_t oriLumaW = sps.getMaxPicWidthInLumaSamples(); ++ ++ for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) ++ { ++ ComponentID compID = ComponentID(compIdx); ++ const int chromascaleY = getComponentScaleY(compID, pic.chromaFormat); ++ const int chromascaleX = getComponentScaleX(compID, pic.chromaFormat); ++ ++ uint32_t curPicH = curLumaH >> chromascaleY; ++ uint32_t curPicW = curLumaW >> chromascaleX; ++ ++ uint32_t oriPicH = oriLumaH >> chromascaleY; ++ uint32_t oriPicW = oriLumaW >> chromascaleX; ++ ++ for (uint32_t j = 0; j < curPicH; j++) ++ { ++ for (uint32_t i = 0; i < curPicW; i++) ++ { ++ temp[0] = (pic.get(compID).at(i, j) >> 0) & 0xff; ++ temp[1] = (pic.get(compID).at(i, j) >> 8) & 0xff; ++ ::fwrite(temp, sizeof(temp[0]), 2, fp_rec); ++ ++ CHECK(pic.get(compID).at(i, j) < 0 || pic.get(compID).at(i, j) > 1023, ""); ++ ++ temp[0] = (pcPic->getTruePredBuf(compID).at(i, j) >> 0) & 0xff; ++ temp[1] = (pcPic->getTruePredBuf(compID).at(i, j) >> 8) & 0xff; ++ ::fwrite(temp, sizeof(temp[0]), 2, fp_pre); ++ ++ CHECK(pcPic->getTruePredBuf(compID).at(i, j) < 0 || pcPic->getTruePredBuf(compID).at(i, j) > 1023, ""); ++ } ++ } ++ for (uint32_t j = 0; j < oriPicH; j++) ++ { ++ for (uint32_t i = 0; i < oriPicW; i++) ++ { ++ temp[0] = (upscaledRPR.get(compID).at(i, j) >> 0) & 0xff; ++ temp[1] = (upscaledRPR.get(compID).at(i, j) >> 8) & 0xff; ++ ::fwrite(temp, sizeof(temp[0]), 2, fp_rpr); ++ ++ CHECK(upscaledRPR.get(compID).at(i, j) < 0 || upscaledRPR.get(compID).at(i, j) > 1023, ""); ++ } ++ } ++ } ++ ::fclose(fp_rec); ++ ::fclose(fp_pre); ++ ::fclose(fp_rpr); ++ } ++#endif + Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledPic, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); + + ret = write( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples(), upscaledPic, +diff --git a/source/Lib/Utilities/VideoIOYuv.h b/source/Lib/Utilities/VideoIOYuv.h +index bf2c4705..e4baec31 100644 +--- a/source/Lib/Utilities/VideoIOYuv.h ++++ b/source/Lib/Utilities/VideoIOYuv.h +@@ -101,7 +101,11 @@ public: + int getFileBitdepth( int ch ) { return m_fileBitdepth[ch]; } + + bool writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, +- const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false ); ///< write one upsaled YUV frame ++ const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false ++#if DATA_GEN_DEC ++ , Picture* pcPic = nullptr ++#endif ++ ); ///< write one upsaled YUV frame + + }; + +-- +2.34.0.windows.1 + diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_I_slices.patch b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_I_slices.patch new file mode 100644 index 0000000000000000000000000000000000000000..df34598010965b2a58d63aa5e773d39b8a4429bd --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/Generate_SR_datasets_for_I_slices.patch 
@@ -0,0 +1,569 @@ +From 10295eea930cf3be502f5c279618313443b2a5b6 Mon Sep 17 00:00:00 2001 +From: renjiechang <renjiechang@tencent.com> +Date: Tue, 14 Feb 2023 14:59:10 +0800 +Subject: [PATCH] Generate SR datasets for I slices + +--- + source/App/DecoderApp/DecApp.cpp | 26 ++++++- + source/App/DecoderApp/DecAppCfg.cpp | 4 ++ + source/App/EncoderApp/EncAppCfg.cpp | 8 +++ + source/Lib/CommonLib/CodingStructure.cpp | 48 +++++++++++++ + source/Lib/CommonLib/CodingStructure.h | 17 +++++ + source/Lib/CommonLib/Picture.cpp | 19 +++++ + source/Lib/CommonLib/Picture.h | 8 +++ + source/Lib/CommonLib/Rom.cpp | 5 ++ + source/Lib/CommonLib/Rom.h | 4 ++ + source/Lib/CommonLib/TypeDef.h | 3 + + source/Lib/DecoderLib/DecCu.cpp | 14 +++- + source/Lib/EncoderLib/EncLib.cpp | 4 ++ + source/Lib/Utilities/VideoIOYuv.cpp | 90 +++++++++++++++++++++++- + source/Lib/Utilities/VideoIOYuv.h | 6 +- + 14 files changed, 251 insertions(+), 5 deletions(-) + +diff --git a/source/App/DecoderApp/DecApp.cpp b/source/App/DecoderApp/DecApp.cpp +index 85f63bb0..1842d346 100644 +--- a/source/App/DecoderApp/DecApp.cpp ++++ b/source/App/DecoderApp/DecApp.cpp +@@ -88,6 +88,17 @@ uint32_t DecApp::decode() + EXIT( "Failed to open bitstream file " << m_bitstreamFileName.c_str() << " for reading" ) ; + } + ++#if DATA_GEN_DEC ++ strcpy(global_str_name, m_bitstreamFileName.c_str()); ++ { ++ size_t len = strlen(global_str_name); ++ for (size_t i = len - 1; i > len - 5; i--) ++ { ++ global_str_name[i] = 0; ++ } ++ } ++#endif ++ + InputByteStream bytestream(bitstreamFile); + + if (!m_outputDecodedSEIMessagesFilename.empty() && m_outputDecodedSEIMessagesFilename!="-") +@@ -678,7 +689,14 @@ void DecApp::xWriteOutput( PicList* pcListPic, uint32_t tId ) + ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); + if( m_upscaledOutput ) + { +- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); ++ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ++#if DATA_GEN_DEC ++ , pcPic ++#endif ++ ); ++#if DATA_PREDICTION ++ pcPic->m_bufs[PIC_TRUE_PREDICTION].destroy(); ++#endif + } + else + { +@@ -825,7 +843,11 @@ void DecApp::xFlushOutput( PicList* pcListPic, const int layerId ) + ChromaFormat chromaFormatIDC = sps->getChromaFormatIdc(); + if( m_upscaledOutput ) + { +- m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ); ++ m_cVideoIOYuvReconFile[pcPic->layerId].writeUpscaledPicture( *sps, *pcPic->cs->pps, pcPic->getRecoBuf(), m_outputColourSpaceConvert, m_packedYUVMode, m_upscaledOutput, NUM_CHROMA_FORMAT, m_bClipOutputVideoToRec709Range ++#if DATA_GEN_DEC ++ , pcPic ++#endif ++ ); + } + else + { +diff --git a/source/App/DecoderApp/DecAppCfg.cpp b/source/App/DecoderApp/DecAppCfg.cpp +index d96c2049..ad912462 100644 +--- a/source/App/DecoderApp/DecAppCfg.cpp ++++ b/source/App/DecoderApp/DecAppCfg.cpp +@@ -124,7 +124,11 @@ bool DecAppCfg::parseCfg( int argc, char* argv[] ) + #endif + ("MCTSCheck", m_mctsCheck, false, "If enabled, the decoder checks for violations of mc_exact_sample_value_match_flag in Temporal MCTS ") + ("targetSubPicIdx", 
m_targetSubPicIdx, 0, "Specify which subpicture shall be written to output, using subpic index, 0: disabled, subpicIdx=m_targetSubPicIdx-1 \n" ) ++#if DATA_GEN_DEC ++ ( "UpscaledOutput", m_upscaledOutput, 2, "Upscaled output for RPR" ) ++#else + ( "UpscaledOutput", m_upscaledOutput, 0, "Upscaled output for RPR" ) ++#endif + ; + + po::setDefaults(opts); +diff --git a/source/App/EncoderApp/EncAppCfg.cpp b/source/App/EncoderApp/EncAppCfg.cpp +index b38001eb..c43cd0e1 100644 +--- a/source/App/EncoderApp/EncAppCfg.cpp ++++ b/source/App/EncoderApp/EncAppCfg.cpp +@@ -1411,11 +1411,19 @@ bool EncAppCfg::parseCfg( int argc, char* argv[] ) + ( "CCALF", m_ccalf, true, "Cross-component Adaptive Loop Filter" ) + ( "CCALFQpTh", m_ccalfQpThreshold, 37, "QP threshold above which encoder reduces CCALF usage") + ( "RPR", m_rprEnabledFlag, true, "Reference Sample Resolution" ) ++#if DATA_GEN_ENC ++ ( "ScalingRatioHor", m_scalingRatioHor, 2.0, "Scaling ratio in hor direction" ) ++ ( "ScalingRatioVer", m_scalingRatioVer, 2.0, "Scaling ratio in ver direction" ) ++ ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) ++ ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) ++ ( "UpscaledOutput", m_upscaledOutput, 2, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) ++#else + ( "ScalingRatioHor", m_scalingRatioHor, 1.0, "Scaling ratio in hor direction" ) + ( "ScalingRatioVer", m_scalingRatioVer, 1.0, "Scaling ratio in ver direction" ) + ( "FractionNumFrames", m_fractionOfFrames, 1.0, "Encode a fraction of the specified in FramesToBeEncoded frames" ) + ( "SwitchPocPeriod", m_switchPocPeriod, 0, "Switch POC period for RPR" ) + ( "UpscaledOutput", m_upscaledOutput, 0, "Output upscaled (2), decoded but in full resolution buffer (1) or decoded cropped (0, default) picture for RPR" ) ++#endif + ( "MaxLayers", m_maxLayers, 1, "Max number of layers" ) + #if JVET_S0163_ON_TARGETOLS_SUBLAYERS + ( "EnableOperatingPointInformation", m_OPIEnabled, false, "Enables writing of Operating Point Information (OPI)" ) +diff --git a/source/Lib/CommonLib/CodingStructure.cpp b/source/Lib/CommonLib/CodingStructure.cpp +index b655d445..15e542ba 100644 +--- a/source/Lib/CommonLib/CodingStructure.cpp ++++ b/source/Lib/CommonLib/CodingStructure.cpp +@@ -107,6 +107,9 @@ void CodingStructure::destroy() + parent = nullptr; + + m_pred.destroy(); ++#if DATA_PREDICTION ++ m_predTrue.destroy(); ++#endif + m_resi.destroy(); + m_reco.destroy(); + m_orgr.destroy(); +@@ -895,6 +898,9 @@ void CodingStructure::create(const ChromaFormat &_chromaFormat, const Area& _are + + m_reco.create( area ); + m_pred.create( area ); ++#if DATA_PREDICTION ++ m_predTrue.create( area ); ++#endif + m_resi.create( area ); + m_orgr.create( area ); + } +@@ -910,6 +916,9 @@ void CodingStructure::create(const UnitArea& _unit, const bool isTopLayer, const + + m_reco.create( area ); + m_pred.create( area ); ++#if DATA_PREDICTION ++ m_predTrue.create( area ); ++#endif + m_resi.create( area ); + m_orgr.create( area ); + } +@@ -1082,6 +1091,16 @@ void CodingStructure::rebindPicBufs() + { + m_pred.destroy(); + } ++#if DATA_PREDICTION ++ if (!picture->M_BUFS(0, PIC_TRUE_PREDICTION).bufs.empty()) ++ { ++ m_predTrue.createFromBuf(picture->M_BUFS(0, PIC_TRUE_PREDICTION)); ++ } ++ else ++ { ++ m_predTrue.destroy(); ++ } ++#endif + if (!picture->M_BUFS(0, PIC_RESIDUAL).bufs.empty()) + { + m_resi.createFromBuf(picture->M_BUFS(0, PIC_RESIDUAL)); 
+@@ -1240,12 +1259,20 @@ void CodingStructure::useSubStructure( const CodingStructure& subStruct, const C + if( parent ) + { + // copy data to picture ++#if DATA_PREDICTION ++ getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++ getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++#endif + if( cpyPred ) getPredBuf ( clippedArea ).copyFrom( subPredBuf ); + if( cpyResi ) getResiBuf ( clippedArea ).copyFrom( subResiBuf ); + if( cpyReco ) getRecoBuf ( clippedArea ).copyFrom( subRecoBuf ); + if( cpyOrgResi ) getOrgResiBuf( clippedArea ).copyFrom( subStruct.getOrgResiBuf( clippedArea ) ); + } + ++#if DATA_PREDICTION ++ picture->getTruePredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++ picture->getPredBuf(clippedArea).copyFrom(subStruct.getPredBuf(clippedArea)); ++#endif + if( cpyPred ) picture->getPredBuf( clippedArea ).copyFrom( subPredBuf ); + if( cpyResi ) picture->getResiBuf( clippedArea ).copyFrom( subResiBuf ); + if( cpyReco ) picture->getRecoBuf( clippedArea ).copyFrom( subRecoBuf ); +@@ -1562,6 +1589,13 @@ const CPelBuf CodingStructure::getPredBuf(const CompArea &blk) const { r + PelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } + const CPelUnitBuf CodingStructure::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } + ++#if DATA_PREDICTION ++ PelBuf CodingStructure::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } ++const CPelBuf CodingStructure::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } ++ PelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } ++const CPelUnitBuf CodingStructure::getTruePredBuf(const UnitArea &unit)const { return getBuf(unit, PIC_TRUE_PREDICTION); } ++#endif ++ + PelBuf CodingStructure::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } + const CPelBuf CodingStructure::getResiBuf(const CompArea &blk) const { return getBuf(blk, PIC_RESIDUAL); } + PelUnitBuf CodingStructure::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } +@@ -1603,6 +1637,13 @@ PelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &type ) + + PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? &m_orgr : nullptr ) ) ); + ++#if DATA_PREDICTION ++ if (type == PIC_TRUE_PREDICTION) ++ { ++ buf = &m_predTrue; ++ } ++#endif ++ + CHECK( !buf, "Unknown buffer requested" ); + + CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); +@@ -1637,6 +1678,13 @@ const CPelBuf CodingStructure::getBuf( const CompArea &blk, const PictureType &t + + const PelStorage* buf = type == PIC_PREDICTION ? &m_pred : ( type == PIC_RESIDUAL ? &m_resi : ( type == PIC_RECONSTRUCTION ? &m_reco : ( type == PIC_ORG_RESI ? 
&m_orgr : nullptr ) ) ); + ++#if DATA_PREDICTION ++ if (type == PIC_TRUE_PREDICTION) ++ { ++ buf = &m_predTrue; ++ } ++#endif ++ + CHECK( !buf, "Unknown buffer requested" ); + + CHECKD( !area.blocks[compID].contains( blk ), "Buffer not contained in self requested" ); +diff --git a/source/Lib/CommonLib/CodingStructure.h b/source/Lib/CommonLib/CodingStructure.h +index b5ae7ac6..cdd3fbf1 100644 +--- a/source/Lib/CommonLib/CodingStructure.h ++++ b/source/Lib/CommonLib/CodingStructure.h +@@ -62,6 +62,9 @@ enum PictureType + PIC_ORIGINAL_INPUT, + PIC_TRUE_ORIGINAL_INPUT, + PIC_FILTERED_ORIGINAL_INPUT, ++#if DATA_PREDICTION ++ PIC_TRUE_PREDICTION, ++#endif + NUM_PIC_TYPES + }; + extern XUCache g_globalUnitCache; +@@ -228,6 +231,9 @@ private: + std::vector<SAOBlkParam> m_sao; + + PelStorage m_pred; ++#if DATA_PREDICTION ++ PelStorage m_predTrue; ++#endif + PelStorage m_resi; + PelStorage m_reco; + PelStorage m_orgr; +@@ -268,6 +274,17 @@ public: + PelUnitBuf getPredBuf(const UnitArea &unit); + const CPelUnitBuf getPredBuf(const UnitArea &unit) const; + ++#if DATA_PREDICTION ++ PelBuf getTruePredBuf(const CompArea &blk); ++ const CPelBuf getTruePredBuf(const CompArea &blk) const; ++ PelUnitBuf getTruePredBuf(const UnitArea &unit); ++ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; ++#endif ++ ++#if DATA_PREDICTION ++ PelUnitBuf getTruePredBuf() { return m_predTrue; } ++#endif ++ + PelBuf getResiBuf(const CompArea &blk); + const CPelBuf getResiBuf(const CompArea &blk) const; + PelUnitBuf getResiBuf(const UnitArea &unit); +diff --git a/source/Lib/CommonLib/Picture.cpp b/source/Lib/CommonLib/Picture.cpp +index a7205bad..d5d1400a 100644 +--- a/source/Lib/CommonLib/Picture.cpp ++++ b/source/Lib/CommonLib/Picture.cpp +@@ -277,6 +277,12 @@ void Picture::createTempBuffers( const unsigned _maxCUSize ) + { + M_BUFS( jId, PIC_PREDICTION ).create( chromaFormat, a, _maxCUSize ); + M_BUFS( jId, PIC_RESIDUAL ).create( chromaFormat, a, _maxCUSize ); ++ ++#if DATA_PREDICTION ++ const Area a_old(Position{ 0, 0 }, lumaSize()); ++ M_BUFS(jId, PIC_TRUE_PREDICTION).create(chromaFormat, a_old, _maxCUSize); ++#endif ++ + #if ENABLE_SPLIT_PARALLELISM + if (jId > 0) + { +@@ -305,6 +311,11 @@ void Picture::destroyTempBuffers() + { + M_BUFS(jId, t).destroy(); + } ++#if DATA_PREDICTION ++#if !DATA_GEN_DEC ++ if (t == PIC_TRUE_PREDICTION) M_BUFS(jId, t).destroy(); ++#endif ++#endif + #if ENABLE_SPLIT_PARALLELISM + if (t == PIC_RECONSTRUCTION && jId > 0) + { +@@ -344,6 +355,14 @@ const CPelBuf Picture::getPredBuf(const CompArea &blk) const { return getBu + PelUnitBuf Picture::getPredBuf(const UnitArea &unit) { return getBuf(unit, PIC_PREDICTION); } + const CPelUnitBuf Picture::getPredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_PREDICTION); } + ++#if DATA_PREDICTION ++ PelBuf Picture::getTruePredBuf(const ComponentID compID, bool wrap) { return getBuf(compID, PIC_TRUE_PREDICTION); } ++ PelBuf Picture::getTruePredBuf(const CompArea &blk) { return getBuf(blk, PIC_TRUE_PREDICTION); } ++const CPelBuf Picture::getTruePredBuf(const CompArea &blk) const { return getBuf(blk, PIC_TRUE_PREDICTION); } ++ PelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) { return getBuf(unit, PIC_TRUE_PREDICTION); } ++const CPelUnitBuf Picture::getTruePredBuf(const UnitArea &unit) const { return getBuf(unit, PIC_TRUE_PREDICTION); } ++#endif ++ + PelBuf Picture::getResiBuf(const CompArea &blk) { return getBuf(blk, PIC_RESIDUAL); } + const CPelBuf Picture::getResiBuf(const CompArea &blk) const { return getBuf(blk, 
PIC_RESIDUAL); } + PelUnitBuf Picture::getResiBuf(const UnitArea &unit) { return getBuf(unit, PIC_RESIDUAL); } +diff --git a/source/Lib/CommonLib/Picture.h b/source/Lib/CommonLib/Picture.h +index 66073bf6..b48a6099 100644 +--- a/source/Lib/CommonLib/Picture.h ++++ b/source/Lib/CommonLib/Picture.h +@@ -128,6 +128,14 @@ struct Picture : public UnitArea + PelUnitBuf getPredBuf(const UnitArea &unit); + const CPelUnitBuf getPredBuf(const UnitArea &unit) const; + ++#if DATA_PREDICTION ++ PelBuf getTruePredBuf(const ComponentID compID, bool wrap = false); ++ PelBuf getTruePredBuf(const CompArea &blk); ++ const CPelBuf getTruePredBuf(const CompArea &blk) const; ++ PelUnitBuf getTruePredBuf(const UnitArea &unit); ++ const CPelUnitBuf getTruePredBuf(const UnitArea &unit) const; ++#endif ++ + PelBuf getResiBuf(const CompArea &blk); + const CPelBuf getResiBuf(const CompArea &blk) const; + PelUnitBuf getResiBuf(const UnitArea &unit); +diff --git a/source/Lib/CommonLib/Rom.cpp b/source/Lib/CommonLib/Rom.cpp +index dc1c29ae..28ad2c4f 100644 +--- a/source/Lib/CommonLib/Rom.cpp ++++ b/source/Lib/CommonLib/Rom.cpp +@@ -53,6 +53,11 @@ CDTrace *g_trace_ctx = NULL; + #endif + bool g_mctsDecCheckEnabled = false; + ++#if DATA_GEN_DEC ++unsigned int global_cnt = 0; ++char global_str_name[200]; ++#endif ++ + //! \ingroup CommonLib + //! \{ + +diff --git a/source/Lib/CommonLib/Rom.h b/source/Lib/CommonLib/Rom.h +index e7352e3c..4d1b38a1 100644 +--- a/source/Lib/CommonLib/Rom.h ++++ b/source/Lib/CommonLib/Rom.h +@@ -44,6 +44,10 @@ + #include <stdio.h> + #include <iostream> + ++#if DATA_GEN_DEC ++extern unsigned int global_cnt; ++extern char global_str_name[200]; ++#endif + + //! \ingroup CommonLib + //! \{ +diff --git a/source/Lib/CommonLib/TypeDef.h b/source/Lib/CommonLib/TypeDef.h +index 8af59c7f..2874459a 100644 +--- a/source/Lib/CommonLib/TypeDef.h ++++ b/source/Lib/CommonLib/TypeDef.h +@@ -50,6 +50,9 @@ + #include <assert.h> + #include <cassert> + ++#define DATA_GEN_ENC 1 // Encode frame by RPR downsampling ++#define DATA_GEN_DEC 1 // Decode bin files to generate dataset, which should be turned off when running the encoder ++#define DATA_PREDICTION 1 // Prediction data + // clang-format off + + //########### place macros to be removed in next cycle below this line ############### +diff --git a/source/Lib/DecoderLib/DecCu.cpp b/source/Lib/DecoderLib/DecCu.cpp +index eeec3474..844c7aac 100644 +--- a/source/Lib/DecoderLib/DecCu.cpp ++++ b/source/Lib/DecoderLib/DecCu.cpp +@@ -182,6 +182,9 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) + const ChannelType chType = toChannelType( compID ); + + PelBuf piPred = cs.getPredBuf( area ); ++#if DATA_PREDICTION ++ PelBuf piPredTrue = cs.getTruePredBuf(area); ++#endif + + const PredictionUnit &pu = *tu.cs->getPU( area.pos(), chType ); + const uint32_t uiChFinalMode = PU::getFinalIntraMode( pu, chType ); +@@ -311,10 +314,15 @@ void DecCu::xIntraRecBlk( TransformUnit& tu, const ComponentID compID ) + } + #if KEEP_PRED_AND_RESI_SIGNALS + pReco.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); ++#else ++#if DATA_PREDICTION ++ piPredTrue.copyFrom(piPred); ++ pReco.reconstruct(piPred, piResi, tu.cu->cs->slice->clpRng(compID)); + #else + piPred.reconstruct( piPred, piResi, tu.cu->cs->slice->clpRng( compID ) ); + #endif +-#if !KEEP_PRED_AND_RESI_SIGNALS ++#endif ++#if !KEEP_PRED_AND_RESI_SIGNALS && !DATA_PREDICTION + pReco.copyFrom( piPred ); + #endif + if (slice.getLmcsEnabledFlag() && (m_pcReshape->getCTUFlag() || slice.isIntra()) && 
compID == COMPONENT_Y) +@@ -684,6 +692,10 @@ void DecCu::xReconInter(CodingUnit &cu) + DTRACE ( g_trace_ctx, D_TMP, "pred " ); + DTRACE_CRC( g_trace_ctx, D_TMP, *cu.cs, cu.cs->getPredBuf( cu ), &cu.Y() ); + ++#if DATA_PREDICTION ++ cu.cs->getTruePredBuf(cu).copyFrom(cu.cs->getPredBuf(cu)); ++#endif ++ + // inter recon + xDecodeInterTexture(cu); + +diff --git a/source/Lib/EncoderLib/EncLib.cpp b/source/Lib/EncoderLib/EncLib.cpp +index bb5e51f6..f3287686 100644 +--- a/source/Lib/EncoderLib/EncLib.cpp ++++ b/source/Lib/EncoderLib/EncLib.cpp +@@ -657,6 +657,9 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu + } + #endif + ++#if DATA_GEN_ENC ++ ppsID = ENC_PPS_ID_RPR; ++#else + if( m_resChangeInClvsEnabled && m_intraPeriod == -1 ) + { + const int poc = m_iPOCLast + ( m_compositeRefEnabled ? 2 : 1 ); +@@ -675,6 +678,7 @@ bool EncLib::encodePrep( bool flush, PelStorage* pcPicYuvOrg, PelStorage* cPicYu + { + ppsID = m_vps->getGeneralLayerIdx( m_layerId ); + } ++#endif + + xGetNewPicBuffer( rcListPicYuvRecOut, pcPicCurr, ppsID ); + +diff --git a/source/Lib/Utilities/VideoIOYuv.cpp b/source/Lib/Utilities/VideoIOYuv.cpp +index 8a30ccc5..3ea4d985 100644 +--- a/source/Lib/Utilities/VideoIOYuv.cpp ++++ b/source/Lib/Utilities/VideoIOYuv.cpp +@@ -1252,7 +1252,11 @@ void VideoIOYuv::ColourSpaceConvert(const CPelUnitBuf &src, PelUnitBuf &dest, co + } + } + +-bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 ) ++bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice, ChromaFormat format, const bool bClipToRec709 ++#if DATA_GEN_DEC ++ , Picture* pcPic ++#endif ++) + { + ChromaFormat chromaFormatIDC = sps.getChromaFormatIdc(); + bool ret = false; +@@ -1284,6 +1288,90 @@ bool VideoIOYuv::writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPe + int xScale = ( ( refPicWidth << SCALE_RATIO_BITS ) + ( curPicWidth >> 1 ) ) / curPicWidth; + int yScale = ( ( refPicHeight << SCALE_RATIO_BITS ) + ( curPicHeight >> 1 ) ) / curPicHeight; + ++#if DATA_GEN_DEC ++ if (pcPic->cs->slice->getSliceType() == I_SLICE) ++ { ++ PelStorage upscaledRPR; ++ upscaledRPR.create( chromaFormatIDC, Area( Position(), Size( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples() ) ) ); ++ Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledRPR, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); ++ ++ char rec_out_name[200]; ++ strcpy(rec_out_name, global_str_name); ++ sprintf(rec_out_name + strlen(rec_out_name), "_poc%03d.yuv", pcPic->cs->slice->getPOC()); ++ FILE* fp_rec = fopen(rec_out_name, "wb"); ++ ++#if DATA_PREDICTION ++ char pre_out_name[200]; ++ strcpy(pre_out_name, global_str_name); ++ sprintf(pre_out_name + strlen(pre_out_name), "_poc%03d_prediction.yuv", pcPic->cs->slice->getPOC()); ++ FILE* fp_pre = fopen(pre_out_name, "wb"); ++#endif ++ ++ char rpr_out_name[200]; ++ strcpy(rpr_out_name, global_str_name); ++ sprintf(rpr_out_name + strlen(rpr_out_name), "_poc%03d_rpr.yuv", pcPic->cs->slice->getPOC()); ++ FILE* fp_rpr = fopen(rpr_out_name, "wb"); ++ ++ int8_t temp[2]; ++ ++ uint32_t curLumaH = 
pps.getPicHeightInLumaSamples(); ++ uint32_t curLumaW = pps.getPicWidthInLumaSamples(); ++ ++ uint32_t oriLumaH = sps.getMaxPicHeightInLumaSamples(); ++ uint32_t oriLumaW = sps.getMaxPicWidthInLumaSamples(); ++ ++ for (int compIdx = 0; compIdx < MAX_NUM_COMPONENT; compIdx++) ++ { ++ ComponentID compID = ComponentID(compIdx); ++ const int chromascaleY = getComponentScaleY(compID, pic.chromaFormat); ++ const int chromascaleX = getComponentScaleX(compID, pic.chromaFormat); ++ ++ uint32_t curPicH = curLumaH >> chromascaleY; ++ uint32_t curPicW = curLumaW >> chromascaleX; ++ ++ uint32_t oriPicH = oriLumaH >> chromascaleY; ++ uint32_t oriPicW = oriLumaW >> chromascaleX; ++ ++ for (uint32_t j = 0; j < curPicH; j++) ++ { ++ for (uint32_t i = 0; i < curPicW; i++) ++ { ++ temp[0] = (pic.get(compID).at(i, j) >> 0) & 0xff; ++ temp[1] = (pic.get(compID).at(i, j) >> 8) & 0xff; ++ ::fwrite(temp, sizeof(temp[0]), 2, fp_rec); ++ ++ CHECK(pic.get(compID).at(i, j) < 0 || pic.get(compID).at(i, j) > 1023, ""); ++ ++#if DATA_PREDICTION ++ temp[0] = (pcPic->getTruePredBuf(compID).at(i, j) >> 0) & 0xff; ++ temp[1] = (pcPic->getTruePredBuf(compID).at(i, j) >> 8) & 0xff; ++ ::fwrite(temp, sizeof(temp[0]), 2, fp_pre); ++ ++ CHECK(pcPic->getTruePredBuf(compID).at(i, j) < 0 || pcPic->getTruePredBuf(compID).at(i, j) > 1023, ""); ++#endif ++ } ++ } ++ for (uint32_t j = 0; j < oriPicH; j++) ++ { ++ for (uint32_t i = 0; i < oriPicW; i++) ++ { ++ temp[0] = (upscaledRPR.get(compID).at(i, j) >> 0) & 0xff; ++ temp[1] = (upscaledRPR.get(compID).at(i, j) >> 8) & 0xff; ++ ::fwrite(temp, sizeof(temp[0]), 2, fp_rpr); ++ ++ CHECK(upscaledRPR.get(compID).at(i, j) < 0 || upscaledRPR.get(compID).at(i, j) > 1023, ""); ++ } ++ } ++ } ++ ::fclose(fp_rec); ++#if DATA_PREDICTION ++ ::fclose(fp_pre); ++#endif ++ ::fclose(fp_rpr); ++ ++ global_cnt++; ++ } ++#endif + Picture::rescalePicture( std::pair<int, int>( xScale, yScale ), pic, pps.getScalingWindow(), upscaledPic, afterScaleWindowFullResolution, chromaFormatIDC, sps.getBitDepths(), false, false, sps.getHorCollocatedChromaFlag(), sps.getVerCollocatedChromaFlag() ); + + ret = write( sps.getMaxPicWidthInLumaSamples(), sps.getMaxPicHeightInLumaSamples(), upscaledPic, +diff --git a/source/Lib/Utilities/VideoIOYuv.h b/source/Lib/Utilities/VideoIOYuv.h +index bf2c4705..e4baec31 100644 +--- a/source/Lib/Utilities/VideoIOYuv.h ++++ b/source/Lib/Utilities/VideoIOYuv.h +@@ -101,7 +101,11 @@ public: + int getFileBitdepth( int ch ) { return m_fileBitdepth[ch]; } + + bool writeUpscaledPicture( const SPS& sps, const PPS& pps, const CPelUnitBuf& pic, +- const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false ); ///< write one upsaled YUV frame ++ const InputColourSpaceConversion ipCSC, const bool bPackedYUVOutputMode, int outputChoice = 0, ChromaFormat format = NUM_CHROMA_FORMAT, const bool bClipToRec709 = false ++#if DATA_GEN_DEC ++ , Picture* pcPic = nullptr ++#endif ++ ); ///< write one upsaled YUV frame + + }; + +-- +2.34.0.windows.1 + diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/bvi_dvc_codec_info.py b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/bvi_dvc_codec_info.py new file mode 100644 index 0000000000000000000000000000000000000000..e91bee4f21b8f1fea64f4fe84ef3bbd3abd455dd --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/bvi_dvc_codec_info.py @@ -0,0 +1,217 @@ 
+SequenceTable = [ + ['B_S001', 'BAdvertisingMassagesBangkokVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S002', 'BAmericanFootballS2Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S003', 'BAmericanFootballS3Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S004', 'BAmericanFootballS4Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S005', 'BAnimalsS11Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S006', 'BAnimalsS1Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S007', 'BBangkokMarketVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S008', 'BBasketballGoalScoredS1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S009', 'BBasketballGoalScoredS2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S010', 'BBasketballS1YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S011', 'BBasketballS2YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S012', 'BBasketballS3YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S013', 'BBoatsChaoPhrayaRiverVidevo_1920x1088_23fps_10bit_420.yuv', 1920, 1088, 0, 64, 23, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S014', 'BBobbleheadBVIHFR_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S015', 'BBookcaseBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S016', 'BBoxingPracticeHarmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S017', 'BBricksBushesStaticBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S018', 'BBricksLeavesBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S019', 'BBricksTiltingBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S020', 'BBubblesPitcherS1BVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S021', 'BBuildingRoofS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S022', 'BBuildingRoofS2IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S023', 'BBuildingRoofS3IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S024', 'BBuildingRoofS4IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S025', 
'BBuntingHangingAcrossHongKongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S026', 'BBusyHongKongStreetVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S027', 'BCalmingWaterBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S028', 'BCarpetPanAverageBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S029', 'BCatchBVIHFR_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S030', 'BCeramicsandSpicesMoroccoVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S031', 'BCharactersYonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S032', 'BChristmasPresentsIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S033', 'BChristmasRoomDareful_1920x1088_29fps_10bit_420.yuv', 1920, 1088, 0, 64, 29, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S034', 'BChurchInsideMCLJCV_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S035', 'BCityScapesS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S036', 'BCityScapesS2IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S037', 'BCityScapesS3IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S038', 'BCityStreetS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S039', 'BCityStreetS3IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S040', 'BCityStreetS4IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S041', 'BCityStreetS5IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S042', 'BCityStreetS6IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S043', 'BCityStreetS7IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S044', 'BCloseUpBasketballSceneVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S045', 'BCloudsStaticBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S046', 'BColourfulDecorationWatPhoVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S047', 'BColourfulKoreanLanternsVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S048', 'BColourfulPaperLanternsVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S049', 'BColourfulRugsMoroccoVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 
42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S050', 'BConstructionS2YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S051', 'BCostaRicaS3Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S052', 'BCrosswalkHarmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S053', 'BCrosswalkHongKong2S1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S054', 'BCrosswalkHongKong2S2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S055', 'BCrosswalkHongKongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S056', 'BCrowdRunMCLV_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S057', 'BCyclistS1BVIHFR_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S058', 'BCyclistVeniceBeachBoardwalkVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S059', 'BDollsScene1YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S060', 'BDollsScene2YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S061', 'BDowntownHongKongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S062', 'BDrivingPOVHarmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S063', 'BDropsOnWaterBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S064', 'BElFuenteMaskLIVENetFlix_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S065', 'BEnteringHongKongStallS1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S066', 'BEnteringHongKongStallS2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S067', 'BFerrisWheelTurningVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S068', 'BFireS18Mitch_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S069', 'BFireS21Mitch_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S070', 'BFireS71Mitch_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S071', 'BFirewoodS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S072', 'BFirewoodS2IRIS_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S073', 'BFitnessIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S074', 'BFjordsS1Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 
10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S075', 'BFlagShootTUMSVT_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S076', 'BFlowerChapelS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S077', 'BFlowerChapelS2IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S078', 'BFlyingCountrysideDareful_1920x1088_29fps_10bit_420.yuv', 1920, 1088, 0, 64, 29, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S079', 'BFlyingMountainsDareful_1920x1088_29fps_10bit_420.yuv', 1920, 1088, 0, 64, 29, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S080', 'BFlyingThroughLAStreetVidevo_1920x1088_23fps_10bit_420.yuv', 1920, 1088, 0, 64, 23, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S081', 'BFungusZoomBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S082', 'BGrassBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S083', 'BGrazTowerIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S084', 'BHamsterBVIHFR_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S085', 'BHarleyDavidsonIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S086', 'BHongKongIslandVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S087', 'BHongKongMarket1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S088', 'BHongKongMarket2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S089', 'BHongKongMarket3S1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S090', 'BHongKongMarket3S2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S091', 'BHongKongMarket4S1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S092', 'BHongKongMarket4S2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S093', 'BHongKongS1Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S094', 'BHongKongS2Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S095', 'BHongKongS3Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S096', 'BHorseDrawnCarriagesVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S097', 'BHorseStaringS1Videvo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S098', 'BHorseStaringS2Videvo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S099', 'BJockeyHarmonics_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 
10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S100', 'BJoggersS1BVIHFR_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S101', 'BJoggersS2BVIHFR_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S102', 'BKartingIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S103', 'BKoraDrumsVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S104', 'BLakeYonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S105', 'BLampLeavesBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S106', 'BLaundryHangingOverHongKongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S107', 'BLeaves1BVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S108', 'BLeaves3BVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S109', 'BLowLevelShotAlongHongKongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S110', 'BLungshanTempleS1Videvo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S111', 'BLungshanTempleS2Videvo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S112', 'BManMoTempleVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S113', 'BManStandinginProduceTruckVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S114', 'BManWalkingThroughBangkokVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S115', 'BMaplesS1YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S116', 'BMaplesS2YonseiUniversity_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S117', 'BMirabellParkS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S118', 'BMirabellParkS2IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S119', 'BMobileHarmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S120', 'BMoroccanCeramicsShopVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S121', 'BMoroccanSlippersVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S122', 'BMuralPaintingVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S123', 'BMyanmarS4Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S124', 'BMyanmarS6Harmonics_1920x1088_60fps_10bit_420.yuv', 
1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S125', 'BMyeongDongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S126', 'BNewYorkStreetDareful_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S127', 'BOrangeBuntingoverHongKongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S128', 'BPaintingTiltingBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S129', 'BParkViolinMCLJCV_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S130', 'BPedestriansSeoulatDawnVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S131', 'BPeopleWalkingS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S132', 'BPersonRunningOutsideVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S133', 'BPillowsTransBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S134', 'BPlasmaFreeBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S135', 'BPresentsChristmasTreeDareful_1920x1088_29fps_10bit_420.yuv', 1920, 1088, 0, 64, 29, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S136', 'BReadySetGoS2TampereUniversity_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S137', 'BResidentialBuildingSJTU_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S138', 'BRollerCoaster2Netflix_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S139', 'BRunnersSJTU_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S140', 'BRuralSetupIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S141', 'BRuralSetupS2IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S142', 'BScarfSJTU_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S143', 'BSeasideWalkIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S144', 'BSeekingMCLV_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S145', 'BSeoulCanalatDawnVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S146', 'BShoppingCentreVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S147', 'BSignboardBoatLIVENetFlix_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S148', 'BSkyscraperBangkokVidevo_1920x1088_23fps_10bit_420.yuv', 1920, 1088, 0, 64, 23, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S149', 
'BSmokeClearBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S150', 'BSmokeS45Mitch_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S151', 'BSparklerBVIHFR_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S152', 'BSquareAndTimelapseHarmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S153', 'BSquareS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S154', 'BSquareS2IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S155', 'BStreetArtVidevo_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S156', 'BStreetDancerS1IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S157', 'BStreetDancerS2IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S158', 'BStreetDancerS3IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S159', 'BStreetDancerS4IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S160', 'BStreetDancerS5IRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S161', 'BStreetsOfIndiaS1Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S162', 'BStreetsOfIndiaS2Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S163', 'BStreetsOfIndiaS3Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S164', 'BTaiChiHongKongS1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S165', 'BTaiChiHongKongS2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S166', 'BTaipeiCityRooftops8Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S167', 'BTaipeiCityRooftopsS1Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S168', 'BTaipeiCityRooftopsS2Videvo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S169', 'BTaksinBridgeVidevo_1920x1088_23fps_10bit_420.yuv', 1920, 1088, 0, 64, 23, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S170', 'BTallBuildingsSJTU_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S171', 'BTennisMCLV_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S172', 'BToddlerFountain2Netflix_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S173', 'BTouristsSatOutsideVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S174', 
'BToyCalendarHarmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S175', 'BTrackingDownHongKongSideVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S176', 'BTrackingPastRestaurantVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S177', 'BTrackingPastStallHongKongVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S178', 'BTraditionalIndonesianKecakVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S179', 'BTrafficandBuildingSJTU_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S180', 'BTrafficFlowSJTU_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S181', 'BTrafficonTasksinBridgeVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S182', 'BTreeWillsBVITexture_1920x1088_120fps_10bit_420.yuv', 1920, 1088, 0, 64, 120, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S183', 'BTruckIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S184', 'BTunnelFlagS1Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S185', 'BUnloadingVegetablesVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S186', 'BVegetableMarketS1LIVENetFlix_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S187', 'BVegetableMarketS2LIVENetFlix_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S188', 'BVegetableMarketS3LIVENetFlix_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S189', 'BVegetableMarketS4LIVENetFlix_1920x1088_30fps_10bit_420.yuv', 1920, 1088, 0, 64, 30, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S190', 'BVeniceS1Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S191', 'BVeniceS2Harmonics_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S192', 'BVeniceSceneIRIS_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S193', 'BWalkingDownKhaoStreetVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S194', 'BWalkingDownNorthRodeoVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S195', 'BWalkingThroughFootbridgeVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S196', 'BWaterS65Mitch_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S197', 'BWaterS81Mitch_1920x1088_24fps_10bit_420.yuv', 1920, 1088, 0, 64, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S198', 'BWatPhoTempleVidevo_1920x1088_50fps_10bit_420.yuv', 1920, 1088, 0, 64, 
50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S199', 'BWoodSJTU_1920x1088_60fps_10bit_420.yuv', 1920, 1088, 0, 64, 60, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['B_S200', 'BWovenVidevo_1920x1088_25fps_10bit_420.yuv', 1920, 1088, 0, 64, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], +] + + +TemporalSubsampleRatio = 1 + +for seq_data in SequenceTable: + seq_key, seq_file_name, width, height, StartFrame, FramesToBeEncoded, FrameRate, InputBitDepth, QPs, RateNames, level = seq_data + + QPs = QPs.split(' ') + RateNames = RateNames.split(' ') + QPs_RateNames = [] + for QP, RateName in zip(QPs,RateNames): + commonFileName = ('T2RA_' + seq_key + '_' + RateName + '_qp' + QP + '_s' + str(StartFrame) + + '_f' + str(FramesToBeEncoded) + '_t' + str(TemporalSubsampleRatio)) + binFile = 'Bin_' + commonFileName + '.bin' + print(binFile) diff --git a/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/tvd_codec_info.py b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/tvd_codec_info.py new file mode 100644 index 0000000000000000000000000000000000000000..113635942907a1040ab906675ff1ca94d1a50c68 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/2_generate_compression_data/tvd_codec_info.py @@ -0,0 +1,91 @@ +SequenceTable = [ + ['A_S01', 'Bamboo_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S02', 'BlackBird_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S03', 'BoyDressing1_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S04', 'BoyDressing2_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S05', 'BoyMakingUp1_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S06', 'BoyMakingUp2_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S07', 'BoyWithCostume_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S08', 'BuildingTouristAttraction1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S09', 'BuildingTouristAttraction2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S10', 'BuildingTouristAttraction3_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S11', 'CableCar_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S12', 'ChefCooking1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S13', 'ChefCooking2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S14', 'ChefCooking3_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S15', 'ChefCooking4_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S16', 'ChefCooking5_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S17', 'ChefCuttingUp1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, 
'22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S18', 'ChefCuttingUp2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S19', 'DryRedPepper_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S20', 'FilmMachine_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S21', 'FlowingWater_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S22', 'Fountain_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S23', 'GirlRunningOnGrass_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S24', 'GirlWithTeaSet1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S25', 'GirlWithTeaSet2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S26', 'GirlWithTeaSet3_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S27', 'GirlsOnGrass1_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S28', 'GirlsOnGrass2_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S29', 'HotPot_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S30', 'HotelClerks_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S31', 'LyingDog_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S32', 'ManWithFilmMachine_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S33', 'MountainsAndStairs1_3840x2160_24fps_10bit_420.yuv', 3840, 2160, 0, 65, 24, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S34', 'MountainsAndStairs2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S35', 'MountainsAndStairs3_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S36', 'MountainsAndStairs4_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S37', 'MountainsView1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S38', 'MountainsView2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S39', 'MountainsView3_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S40', 'MountainsView4_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S41', 'MovingBikesAndPedestrian4_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S42', 'OilPainting1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S43', 'OilPainting2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S44', 
'PeopleNearDesk_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S45', 'PeopleOnGrass_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S46', 'Plaque_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S47', 'PressureCooker_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S48', 'RawDucks_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S49', 'RedBush_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S50', 'RedRibbonsWithLocks_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S51', 'RestaurantWaitress1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S52', 'RestaurantWaitress2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S53', 'RiverAndTrees_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S54', 'RoastedDuck_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S55', 'RoomTouristAttraction1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S56', 'RoomTouristAttraction2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S57', 'RoomTouristAttraction3_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S58', 'RoomTouristAttraction4_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S59', 'RoomTouristAttraction5_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S60', 'RoomTouristAttraction6_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S61', 'RoomTouristAttraction7_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S62', 'StampCarving1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S63', 'StampCarving2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S64', 'StaticRocks_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S65', 'StaticWaterAndBikes2_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S66', 'SunAndTrees_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S67', 'SunriseMountainHuang_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S68', 'SunsetMountainHuang1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S69', 'SunsetMountainHuang2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1], + ['A_S70', 
'TreesAndLeaves_3840x2160_50fps_10bit_420.yuv', 3840, 2160, 0, 65, 50, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1],
+ ['A_S71', 'TreesOnMountains1_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1],
+ ['A_S72', 'TreesOnMountains2_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1],
+ ['A_S73', 'TreesOnMountains3_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1],
+ ['A_S74', 'Weave_3840x2160_25fps_10bit_420.yuv', 3840, 2160, 0, 65, 25, 10, '22 27 32 37 42', 'R22 R27 R32 R37 R42', 4.1],
+]
+
+
+TemporalSubsampleRatio = 1
+
+for seq_data in SequenceTable:
+ seq_key, seq_file_name, width, height, StartFrame, FramesToBeEncoded, FrameRate, InputBitDepth, QPs, RateNames, level = seq_data
+
+ QPs = QPs.split(' ')
+ RateNames = RateNames.split(' ')
+ QPs_RateNames = []
+ for QP, RateName in zip(QPs,RateNames):
+ commonFileName = ('T2RA_' + seq_key + '_' + RateName + '_qp' + QP + '_s' + str(StartFrame) +
+ '_f' + str(FramesToBeEncoded) + '_t' + str(TemporalSubsampleRatio))
+ binFile = 'Bin_' + commonFileName + '.bin'
+ print(binFile)
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/3_ReadMe.md b/training/training_scripts/NN_Super_Resolution/3_train_tasks/3_ReadMe.md
new file mode 100644
index 0000000000000000000000000000000000000000..927be4d440e8aac76987e33eaef441d369960b9f
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/3_ReadMe.md
@@ -0,0 +1,72 @@
+## Training process
+For NNSR, a total of three networks need to be trained: two luma networks and one chroma network.
+### How to monitor the training process
+
+Tensorboard is used here to monitor the training process.
+1. Run the following command in the 'Experiments' directory generated by the training script
+```
+tensorboard --logdir=Tensorboard --port=6001
+```
+2. Access http://localhost:6001/ to view the results
+
+
+### The training stage
+
+Launch the training stage with the following command:
+```
+sh train.sh
+```
+The following lines, which set the dataset paths, may need to be revised for your local environment; a sketch of these lines is shown below, and the table after it lists the exact lines for each script.
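+
+For reference, these dataset paths are ordinary `argparse` defaults defined inside each script's Utils.py. The snippet below is a minimal sketch of the kind of lines to edit, mirroring the defaults in the provided Chroma-IB/Utils.py; the `/path/...` values are placeholders that must point to your local copies of the data, and the exact line numbers may differ slightly between the scripts.
+```
+import argparse
+
+parser = argparse.ArgumentParser()
+# compression (distorted) data and raw (ground truth) data of BVI-DVC
+parser.add_argument('--dir_data_bvi_AI', type=str, default='/path/EE1_2_2_train/AI_BVI_DVC', help='distorted dataset directory')
+parser.add_argument('--dir_data_ori_bvi_AI', type=str, default='/path/EE1_2_2_train_ori/BVI_DVC', help='raw dataset directory')
+# compression (distorted) data and raw (ground truth) data of TVD
+parser.add_argument('--dir_data_tvd_AI', type=str, default='/path/EE1_2_2_train/AI_TVD', help='distorted dataset directory')
+parser.add_argument('--dir_data_ori_tvd_AI', type=str, default='/path/EE1_2_2_train_ori/TVD', help='raw dataset directory')
+```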
+
+| path | row | dataset | format |
+| :-------------------------- | :----- | :----------------------------- | :-------------- |
+| ./training_scripts/Luma-I | line 29 | The compression data of BVI-DVC | frame level YUV |
+| ./training_scripts/Luma-I | line 30 | The raw data of BVI-DVC | frame level YUV |
+| ./training_scripts/Luma-I | line 33 | The compression data of TVD | frame level YUV |
+| ./training_scripts/Luma-I | line 34 | The raw data of TVD | frame level YUV |
+| ./training_scripts/Luma-B | line 29 | The compression data of BVI-DVC | frame level YUV |
+| ./training_scripts/Luma-B | line 30 | The raw data of BVI-DVC | frame level YUV |
+| ./training_scripts/Luma-B | line 33 | The compression data of TVD | frame level YUV |
+| ./training_scripts/Luma-B | line 34 | The raw data of TVD | frame level YUV |
+| ./training_scripts/Chroma-IB | line 29 | The compression data of BVI-DVC (AI) | frame level YUV |
+| ./training_scripts/Chroma-IB | line 30 | The raw data of BVI-DVC (AI) | frame level YUV |
+| ./training_scripts/Chroma-IB | line 33 | The compression data of TVD (AI) | frame level YUV |
+| ./training_scripts/Chroma-IB | line 34 | The raw data of TVD (AI) | frame level YUV |
+| ./training_scripts/Chroma-IB | line 37 | The compression data of BVI-DVC (RA) | frame level YUV |
+| ./training_scripts/Chroma-IB | line 38 | The raw data of BVI-DVC (RA) | frame level YUV |
+| ./training_scripts/Chroma-IB | line 41 | The compression data of TVD (RA) | frame level YUV |
+| ./training_scripts/Chroma-IB | line 42 | The raw data of TVD (RA) | frame level YUV |
+
+The convergence curves for the different models on the validation set are shown below.
+
+The convergence curves are then examined at an enlarged scale to select the optimal model. Model selection is based on the PSNR improvement of Y, Cb and Cr.
+Finally, the optimal model selected in the training stage and its training epoch for each network are listed below.
+
+| network | epoch |
+| :----- | :------------ |
+| Luma-I | model_0820.pth |
+| Luma-B | model_0925.pth |
+| Chroma-IB | model_0315.pth |
+
+### Convert to Libtorch model
+1. Select the optimal model from the training stage as the final model.
+2. Use the following command to get the converted models (Luma-I.pt, Luma-B.pt, Chroma-IB.pt).
+```
+sh conversion.sh
+```
+Please note that to successfully generate the different models, the final model path (model_conversion.py line 9), the conversion code (model_conversion.py line 21/24/27) and the generated model name (model_conversion.py line 29) must be kept consistent with one another.
+
+### Other
+Note that the models in these training scripts are generated with **pytorch-1.9.0**.
+The corresponding Libtorch or Pytorch version should be used when loading a model; otherwise an error is likely. A minimal loading sketch is given at the end of this page.
+
+Empirically, it is best to keep training until the PSNR results on the validation set begin to decline, so the provided training scripts use a fixed and relatively large budget of 2000 training epochs.
+Because of randomness in the training process, the epoch numbers given above are only a reference.
+It is suggested to train beyond the provided epochs (Luma-I: 0820, Luma-B: 0925, Chroma-IB: 0315) and then search for the optimal model around them.
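+
+As a sanity check after conversion, the generated model can be reloaded and run with dummy inputs under the matching Pytorch version. The following is a minimal sketch for Chroma-IB.pt; the six input shapes simply mirror the `torch.jit.trace` call in model_conversion.py.
+```
+import torch
+
+# load the converted Libtorch/TorchScript model
+model = torch.jit.load('Chroma-IB.pt')
+model.eval()
+
+# dummy inputs with the same shapes as used for tracing Chroma-IB
+inputs = [torch.ones(1, 1, 144, 144), torch.ones(1, 2, 72, 72),
+          torch.ones(1, 2, 144, 144), torch.ones(1, 1, 72, 72),
+          torch.ones(1, 1, 72, 72), torch.ones(1, 1, 72, 72)]
+with torch.no_grad():
+    out = model(*inputs)
+print('conversion OK, output type:', type(out))  # confirms the model runs
+```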
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/conversion.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/conversion.sh new file mode 100644 index 0000000000000000000000000000000000000000..36bb9a8fae84578d8d5af9dd61cd3cec6ef71dea --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/conversion.sh @@ -0,0 +1,32 @@ +# The copyright in this software is being made available under the BSD +# License, included below. This software may be subject to other third party +# and contributor rights, including patent rights, and no such rights are +# granted under this license. +# +# Copyright (c) 2010-2022, ITU/ISO/IEC +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +# be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +python model_conversion.py \ No newline at end of file diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/model_conversion.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/model_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..7a2654287b2f31aeef5f74e5691074bc5e0bc000 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/model_conversion/model_conversion.py @@ -0,0 +1,65 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. 
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import torch
+
+from nn_model import Net
+
+if __name__ == '__main__':
+    # Load the selected optimal checkpoint; set model_name to the final model of the target network.
+    model = Net()
+    model_name = 'model_0315.pth'
+    model.load_state_dict(torch.load(model_name, map_location=lambda storage, loc: storage)["network"])
+    model.eval()
+
+    # Example tensors covering the input shapes of the three networks.
+    example1 = torch.ones(1, 3, 144, 144)
+    example2 = torch.ones(1, 1, 144, 144)
+    example3 = torch.ones(1, 1, 72, 72)
+    example4 = torch.ones(1, 2, 72, 72)
+    example5 = torch.ones(1, 2, 144, 144)
+    example6 = torch.ones(1, 3, 72, 72)
+
+    # Luma-I
+    #traced_script_module = torch.jit.trace(model, [example3, example3, example2, example3])
+
+    # Luma-B: (rec, pre, rpr, slice_qp, base_qp)
+    #traced_script_module = torch.jit.trace(model, [example3, example3, example2, example3, example3])
+
+    # Chroma-IB: (y_rec, uv_rec, uv_rpr, slice_qp, base_qp, slice_type)
+    traced_script_module = torch.jit.trace(model, [example2, example4, example5, example3, example3, example3])
+
+    # Save the traced model under the name of the target network.
+    traced_script_module.save("Chroma-IB.pt")
+
+
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/Utils.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/Utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b918f11cb2210859d03d8e8828e9ef0f8a6a59c
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/Utils.py
@@ -0,0 +1,241 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2022, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import argparse
+import datetime
+import logging
+import math
+import random
+import struct
+from pathlib import Path
+import numpy as np
+import PIL.Image as Image
+import os
+
+
+import torch
+import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms as tfs
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    path_cur = Path(os.path.split(os.path.realpath(__file__))[0])
+    path_save = path_cur.joinpath("Experiments")
+
+    # for loading data
+    parser.add_argument("--ext", type=str, default='yuv', help="data file extension")
+
+    parser.add_argument("--data_range_bvi_AI", type=str, default='1-180/181-200', help="train/test data range")
+    parser.add_argument('--dir_data_bvi_AI', type=str, default='/path/EE1_2_2_train/AI_BVI_DVC', help='distorted dataset directory')
+    parser.add_argument('--dir_data_ori_bvi_AI', type=str, default='/path/EE1_2_2_train_ori/BVI_DVC', help='raw dataset directory')
+
+    parser.add_argument("--data_range_tvd_AI", type=str, default='1-66/67-74', help="train/test data range")
+    parser.add_argument('--dir_data_tvd_AI', type=str, default='/path/EE1_2_2_train/AI_TVD', help='distorted dataset directory')
+    parser.add_argument('--dir_data_ori_tvd_AI', type=str, default='/path/EE1_2_2_train_ori/TVD', help='raw dataset directory')
+
+    parser.add_argument("--data_range_bvi_RA", type=str, default='1-180/181-200', help="train/test data range")
+    parser.add_argument('--dir_data_bvi_RA', type=str, default='/path/EE1_2_2_train/RA_BVI_DVC', help='distorted dataset directory')
+    parser.add_argument('--dir_data_ori_bvi_RA', type=str, default='/path/EE1_2_2_train_ori/BVI_DVC', help='raw dataset directory')
+
+    parser.add_argument("--data_range_tvd_RA", type=str, default='1-66/67-74', help="train/test data range")
+    parser.add_argument('--dir_data_tvd_RA', type=str, default='/path/EE1_2_2_train/RA_TVD', help='distorted dataset directory')
+    parser.add_argument('--dir_data_ori_tvd_RA', type=str, default='/path/EE1_2_2_train_ori/TVD', help='raw dataset directory')
+
+    # for loading model
+    parser.add_argument("--checkpoints", type=str, help="checkpoints file path")
+    parser.add_argument("--pretrained", type=str, help="pretrained model path")
+
+    # batch size
+    parser.add_argument("--batch_size", type=int, default=64, help="batch size for Fusion stage")
+    # do validation
+    parser.add_argument("--test_every", type=int, default=1200, help="run validation every N batches")
+
+    # learning rate
+    parser.add_argument("--lr", type=float,
default=1e-4, help="learning rate for Fusion stage") + + parser.add_argument("--gpu", action='store_true', default=True, help="use gpu or cpu") + + # epoch + parser.add_argument("--max_epoch", type=int, default=2000, help="max training epochs") + + # patch_size + parser.add_argument("--patch_size", type=int, default=256, help="train/val patch size") + parser.add_argument("--shave", type=int, default=8, help="train/shave") + + # for recording + parser.add_argument("--verbose", action='store_true', default=True, help="use tensorboard and logger") + parser.add_argument("--save_dir", type=str, default=path_save, help="directory for recording") + parser.add_argument("--eval_epochs", type=int, default=5, help="save model after epochs") + + args = parser.parse_args() + return args + + +def init(): + # parse arguments + args = parse_args() + + # create directory for recording + experiment_dir = Path(args.save_dir) + experiment_dir.mkdir(exist_ok=True) + + ckpt_dir = experiment_dir.joinpath("Checkpoints/") + ckpt_dir.mkdir(exist_ok=True) + print(r"===========Save checkpoints to {0}===========".format(str(ckpt_dir))) + + if args.verbose: + # initialize logger + log_dir = experiment_dir.joinpath('Log/') + log_dir.mkdir(exist_ok=True) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler = logging.FileHandler(str(log_dir) + '/Log.txt') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.info('PARAMETER ...') + logger.info(args) + # initialize tensorboard + tb_dir_all = experiment_dir.joinpath('Tensorboard_all/') + tb_dir_all.mkdir(exist_ok=True) + tensorboard_all = SummaryWriter(log_dir=str(tb_dir_all), flush_secs=30) + + tb_dir = experiment_dir.joinpath('Tensorboard/') + tb_dir.mkdir(exist_ok=True) + tensorboard = SummaryWriter(log_dir=str(tb_dir), flush_secs=30) + print(r"===========Save tensorboard and logger to {0}===========".format(str(tb_dir_all))) + else: + print(r"===========Disable tensorboard and logger to accelerate training===========") + logger = None + tensorboard_all = None + tensorboard = None + + return args, logger, ckpt_dir, tensorboard_all, tensorboard + +def yuv_read(yuv_path, h, w, iy, ix, ip): + h_c = h//2 + w_c = w//2 + + ip_c = ip//2 + iy_c = iy//2 + ix_c = ix//2 + + fp = open(yuv_path, 'rb') + + # y + fp.seek(iy*w*2, 0) + patch_y = np.fromfile(fp, np.uint16, ip*w).reshape(ip, w, 1) + patch_y = patch_y[:, ix:ix + ip, :] + + # u + fp.seek(( w*h+ iy_c*w_c)*2, 0) + patch_u = np.fromfile(fp, np.uint16, ip_c*w_c).reshape(ip_c, w_c, 1) + patch_u = patch_u[:, ix_c:ix_c + ip_c, :] + + # v + fp.seek(( w*h+ w_c*h_c+ iy_c*w_c)*2, 0) + patch_v = np.fromfile(fp, np.uint16, ip_c*w_c).reshape(ip_c, w_c, 1) + patch_v = patch_v[:, ix_c:ix_c + ip_c, :] + + fp.close() + + return patch_y, patch_u, patch_v + +def upsample(img, height, width): + img = np.squeeze(img, axis = 2) + img=np.array(Image.fromarray(img.astype(np.float)).resize((width, height), Image.NEAREST)) + img = np.expand_dims(img, axis = 2) + return img + +def patch_process(yuv_path, h, w, iy, ix, ip): + y, u, v = yuv_read(yuv_path, h, w, iy, ix, ip) + #u_up = upsample(u, ip, ip) + #v_up = upsample(v, ip, ip) + #yuv = np.concatenate((y, u_up, v_up), axis=2) + return y, u, v + +def get_patch(image_yuv_path, image_yuv_rpr_path, image_yuv_org_path, w, h, patch_size, shave): + ih = h + iw = w + + ip = patch_size + ih -= ih % ip + iw -= iw % ip + iy = 
random.randrange(ip, ih-ip, ip) - shave + ix = random.randrange(ip, iw-ip, ip) - shave + + # + patch_rec_y, patch_rec_u, patch_rec_v = patch_process(image_yuv_path, h//2, w//2, iy//2, ix//2, (ip + 2*shave)//2) + _, patch_rpr_u, patch_rpr_v = patch_process(image_yuv_rpr_path, h, w, iy, ix, ip + 2*shave) + _, patch_org_u, patch_org_v = patch_process(image_yuv_org_path, h, w, iy, ix, ip + 2*shave) + + patch_in = np.concatenate((patch_rec_u, patch_rec_v), axis=2) + patch_rpr = np.concatenate((patch_rpr_u, patch_rpr_v), axis=2) + patch_org = np.concatenate((patch_org_u, patch_org_v), axis=2) + + ret = [patch_rec_y, patch_in, patch_rpr, patch_org] + + return ret + +def augment(*args): + x = random.random() + hflip = x < 0.2 + vflip = x >= 0.2 and x < 0.4 + rot90 = x >= 0.4 and x < 0.6 + + def _augment(img): + if hflip: img = img[:, ::-1, :] + if vflip: img = img[::-1, :, :] + if rot90: img = img.transpose(1, 0, 2) + + return img + + return [_augment(a) for a in args] + +def np2Tensor(*args): + def _np2Tensor(img): + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) + tensor = torch.from_numpy(np_transpose.astype(np.int32)).float() / 1023.0 + + return tensor + + return [_np2Tensor(a) for a in args] + +def cal_psnr(distortion: torch.Tensor): + psnr = -10 * torch.log10(distortion) + return psnr diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/nn_model.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/nn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2dacdd640434a973b37e4a7526c9e65c20853ed9 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/nn_model.py @@ -0,0 +1,105 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import torch
+import torch.nn as nn
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        # hyper-params
+        n_resblocks = 24
+        n_feats_k = 64
+        n_feats_m = 192
+
+        # define head module
+        self.head_rec_y = nn.Sequential(
+            nn.Conv2d(in_channels = 1, out_channels = n_feats_k, kernel_size = 3, stride = 2, padding = 1), # downsample luma to the chroma resolution by stride = 2
+            nn.PReLU()
+        )
+        self.head_rec = nn.Sequential(
+            nn.Conv2d(in_channels = 2, out_channels = n_feats_k, kernel_size = 3, stride = 1, padding = 1), # chroma is already at half resolution, so stride = 1
+            nn.PReLU()
+        )
+
+        # define fuse module
+        self.fuse = nn.Sequential(
+            nn.Conv2d(in_channels = n_feats_k*2 + 3, out_channels = n_feats_k, kernel_size = 1, stride = 1, padding = 0),
+            nn.PReLU()
+        )
+
+        # define body module
+        body = []
+        for _ in range(n_resblocks):
+            body.append(DscBlock(n_feats_k, n_feats_m))
+
+        self.body = nn.Sequential(*body)
+
+        # define tail module
+        self.tail = nn.Sequential(
+            nn.Conv2d(in_channels = n_feats_k, out_channels = 4 * 2, kernel_size = 3, padding = 1),
+            nn.PixelShuffle(2) #feature_map:(B, 2x2x2, N, N) -> (B, 2, 2N, 2N)
+        )
+
+    def forward(self, y_rec, uv_rec, uv_rpr, slice_qp, base_qp, slice_type):
+        in_0 = self.head_rec_y(y_rec)
+        in_1 = self.head_rec(uv_rec)
+
+        x = self.fuse(torch.cat((in_0, in_1, slice_qp, base_qp, slice_type), 1))
+        x = self.body(x)
+        x = self.tail(x)
+        # residual connection to the RPR-upsampled chroma
+        x[:,0:1,:,:] += uv_rpr[:,0:1,:,:]
+        x[:,1:2,:,:] += uv_rpr[:,1:2,:,:]
+
+        return x
+
+class DscBlock(nn.Module):
+    def __init__(self, n_feats_k, n_feats_m, expansion=1):
+        super(DscBlock, self).__init__()
+        self.expansion = expansion
+        self.c1 = nn.Conv2d(in_channels=n_feats_k, out_channels=n_feats_m, kernel_size=1, padding=0)
+        self.prelu = nn.PReLU()
+        self.c2 = nn.Conv2d(in_channels=n_feats_m, out_channels=n_feats_k, kernel_size=1, padding=0)
+        self.c3 = nn.Conv2d(in_channels=n_feats_k, out_channels=n_feats_k, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        i = x
+        x = self.c2(self.prelu(self.c1(x)))
+        x = self.c3(x)
+        x += i
+
+        return x
+
+
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7b026d766aefdfa8a5c92a46d1750661bc344f5a
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train.sh
@@ -0,0 +1,32 @@
+# The copyright in this software is being made available under the BSD
+# License, included below. This software may be subject to other third party
+# and contributor rights, including patent rights, and no such rights are
+# granted under this license.
+#
+# Copyright (c) 2010-2022, ITU/ISO/IEC
+# All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +# be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +python train_YUV.py \ No newline at end of file diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train_YUV.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train_YUV.py new file mode 100644 index 0000000000000000000000000000000000000000..783c80f058014d61b4ae2dc0348a2cd0de8e49da --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/train_YUV.py @@ -0,0 +1,339 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import torch.nn as nn +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data.dataloader import DataLoader +import datetime +import os, glob + +from yuv10bdata import YUV10bData +from Utils import init, cal_psnr +from nn_model import Net + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True + + +class Trainer: + def __init__(self): + self.args, self.logger, self.checkpoints_dir, self.tensorboard_all, self.tensorboard = init() + + self.net = Net().to("cuda" if self.args.gpu else "cpu") + + self.L1loss = nn.L1Loss().to("cuda" if self.args.gpu else "cpu") + self.L2loss = nn.MSELoss().to("cuda" if self.args.gpu else "cpu") + + self.optimizer = Adam(self.net.parameters(), lr = self.args.lr) + self.scheduler = MultiStepLR(optimizer=self.optimizer, milestones=[4001, 4002], gamma=0.5) + + print("============>loading data") + self.train_dataset = YUV10bData(self.args, train=True) + self.eval_dataset = YUV10bData(self.args, train=False) + + self.train_dataloader = DataLoader(dataset=self.train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=12, pin_memory=False) + self.eval_dataloader = DataLoader(dataset=self.eval_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=12, pin_memory=False) + + self.train_steps = self.eval_steps = 0 + + def train(self): + start_epoch = self.load_checkpoints() + print("============>start training") + for epoch in range(start_epoch, self.args.max_epoch): + print("Epoch {}/{}".format(epoch, self.args.max_epoch)) + self.logger.info("Epoch {}/{}".format(epoch, self.args.max_epoch)) + self.train_one_epoch() + self.scheduler.step() + if (epoch+1) % self.args.eval_epochs == 0: + self.eval(epoch=epoch) + self.save_ckpt(epoch=epoch) + + def train_one_epoch(self): + self.net.train() + for _, tensor in enumerate(self.train_dataloader): + + img_lr, img_hr, filename = tensor + + img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") + img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") + + uv_rec = img_lr[:,0:2,:,:] + slice_qp = img_lr[:,2:3,:,:] + base_qp = img_lr[:, 3:4,:,:] + slice_type = img_lr[:,4:5,:,:] + y_rec = img_hr[:,0:1,:,:] + img_rpr = img_hr[:,1:3,:,:] + img_ori = img_hr[:,3:5,:,:] + + img_out = self.net(y_rec, uv_rec, img_rpr, slice_qp, base_qp, slice_type) + + # calculate distortion + shave=self.args.shave // 2 + + #L1_loss_pred_Y = self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + L1_loss_pred_Cb = self.L1loss(img_out[:,0:1,shave:-shave,shave:-shave], img_ori[:,0:1,shave:-shave,shave:-shave]) + L1_loss_pred_Cr = self.L1loss(img_out[:,1:2,shave:-shave,shave:-shave], img_ori[:,1:2,shave:-shave,shave:-shave]) + + #loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + loss_pred_Cb = self.L2loss(img_out[:,0:1,shave:-shave,shave:-shave], img_ori[:,0:1,shave:-shave,shave:-shave]) + loss_pred_Cr = 
self.L2loss(img_out[:,1:2,shave:-shave,shave:-shave], img_ori[:,1:2,shave:-shave,shave:-shave]) + + #loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr + loss_pred= L1_loss_pred_Cb + L1_loss_pred_Cr + + #loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + # visualization + self.train_steps += 1 + if self.train_steps % 20 == 0: + #psnr_pred_Y = cal_psnr(loss_pred_Y) + psnr_pred_Cb = cal_psnr(loss_pred_Cb) + psnr_pred_Cr = cal_psnr(loss_pred_Cr) + + #psnr_input_Y = cal_psnr(loss_rec_Y) + #psnr_input_Cb = cal_psnr(loss_rec_Cb) + #psnr_input_Cr = cal_psnr(loss_rec_Cr) + + time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") + + print("[{}/{}]\tCb:{:.8f}\tCr:{:.8f}\tPSNR_Cb:{:.8f}\tPSNR_Cr:{:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + loss_pred_Cb, loss_pred_Cr, psnr_pred_Cb, psnr_pred_Cr, time)) + self.logger.info("[{}/{}]\tCb:{:.8f}\tCr:{:.8f}\tPSNR_Cb:{:.8f}\tPSNR_Cr:{:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + loss_pred_Cb, loss_pred_Cr, psnr_pred_Cb, psnr_pred_Cr)) + + #print("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y, time)) + #self.logger.info("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y)) + + self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + tag_scalar_dict={"pred_Cb": psnr_pred_Cb.data}, + global_step=self.train_steps) + self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + tag_scalar_dict={"pred_Cr": psnr_pred_Cr.data}, + global_step=self.train_steps) + self.tensorboard_all.add_image("rec", uv_rec[0:1,0:1,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("pre", pre[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + self.tensorboard_all.add_image("rpr", img_rpr[0:1,0:1,:,:].squeeze(dim=0), global_step=self.train_steps) + self.tensorboard_all.add_image("out", img_out[0:1,0:1,:,:].squeeze(dim=0), global_step=self.train_steps) + self.tensorboard_all.add_image("ori", img_ori[0:1,0:1,:,:].squeeze(dim=0), global_step=self.train_steps) + + #self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + # tag_scalar_dict={"input_Cb": psnr_input_Cb.data, + # "pred_Cb": psnr_pred_Cb.data}, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + # tag_scalar_dict={"input_Cr": psnr_input_Cr.data, + # "pred_Cr": psnr_pred_Cr.data}, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Y", + # scalar_value = psnr_pred_Y - psnr_input_Y, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cb", + # scalar_value = psnr_pred_Cb - psnr_input_Cb, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cr", + # scalar_value = psnr_pred_Cr - psnr_input_Cr, + # global_step=self.train_steps) + + self.tensorboard_all.add_scalar(tag="Train/train_loss_pred", + scalar_value 
= loss_pred,
+                                                global_step=self.train_steps)
+
+            # backward
+            self.optimizer.zero_grad()
+            loss_pred.backward()
+            self.optimizer.step()
+
+    @torch.no_grad()
+    def eval(self, epoch: int):
+        print("============>start evaluating")
+        eval_cnt = 0
+        #ave_psnr_Y = 0.000
+        ave_psnr_Cb = 0.000
+        ave_psnr_Cr = 0.000
+        self.net.eval()
+        for _, tensor in enumerate(self.eval_dataloader):
+
+            img_lr, img_hr, filename = tensor
+
+            img_lr = img_lr.to("cuda" if self.args.gpu else "cpu")
+            img_hr = img_hr.to("cuda" if self.args.gpu else "cpu")
+
+            uv_rec = img_lr[:,0:2,:,:]
+            slice_qp = img_lr[:,2:3,:,:]
+            base_qp = img_lr[:, 3:4,:,:]
+            slice_type = img_lr[:,4:5,:,:]
+            y_rec = img_hr[:,0:1,:,:]
+            img_rpr = img_hr[:,1:3,:,:]
+            img_ori = img_hr[:,3:5,:,:]
+
+            img_out = self.net(y_rec, uv_rec, img_rpr, slice_qp, base_qp, slice_type)
+
+            # calculate distortion
+            shave = self.args.shave // 2
+
+            #L1_loss_pred_Y = self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave])
+            L1_loss_pred_Cb = self.L1loss(img_out[:,0:1,shave:-shave,shave:-shave], img_ori[:,0:1,shave:-shave,shave:-shave])
+            L1_loss_pred_Cr = self.L1loss(img_out[:,1:2,shave:-shave,shave:-shave], img_ori[:,1:2,shave:-shave,shave:-shave])
+
+            #loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave])
+            loss_pred_Cb = self.L2loss(img_out[:,0:1,shave:-shave,shave:-shave], img_ori[:,0:1,shave:-shave,shave:-shave])
+            loss_pred_Cr = self.L2loss(img_out[:,1:2,shave:-shave,shave:-shave], img_ori[:,1:2,shave:-shave,shave:-shave])
+
+            #loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr
+            loss_pred = L1_loss_pred_Cb + L1_loss_pred_Cr
+
+            #loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave])
+            #loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave])
+            #loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave])
+
+            #psnr_pred_Y = cal_psnr(loss_pred_Y)
+            psnr_pred_Cb = cal_psnr(loss_pred_Cb)
+            psnr_pred_Cr = cal_psnr(loss_pred_Cr)
+
+            #psnr_input_Y = cal_psnr(loss_rec_Y)
+            #psnr_input_Cb = cal_psnr(loss_rec_Cb)
+            #psnr_input_Cr = cal_psnr(loss_rec_Cr)
+
+            #ave_psnr_Y += psnr_pred_Y
+            ave_psnr_Cb += psnr_pred_Cb
+            ave_psnr_Cr += psnr_pred_Cr
+
+            eval_cnt += 1
+            # visualization
+            self.eval_steps += 1
+            if self.eval_steps % 2 == 0:
+
+                self.tensorboard_all.add_scalar(tag="Eval/PSNR_Cb",
+                                                scalar_value = psnr_pred_Cb,
+                                                global_step=self.eval_steps)
+
+                self.tensorboard_all.add_scalar(tag="Eval/PSNR_Cr",
+                                                scalar_value = psnr_pred_Cr,
+                                                global_step=self.eval_steps)
+
+                self.tensorboard_all.add_scalar(tag="Eval/eval_loss_pred",
+                                                scalar_value = loss_pred,
+                                                global_step=self.eval_steps)
+
+        time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M")
+        #print("PSNR_Y:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, time))
+        #self.logger.info("PSNR_Y:{:.3f}".format(ave_psnr_Y / eval_cnt))
+
+        # average PSNR of the network output on the validation set
+        print("PSNR_Cb:{:.3f}\tPSNR_Cr:{:.3f}------{}".format(ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt, time))
+        self.logger.info("PSNR_Cb:{:.3f}\tPSNR_Cr:{:.3f}".format(ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt))
+
+        #self.tensorboard.add_scalar(tag = "Eval/PSNR_Y_ave",
+        #                            scalar_value = ave_psnr_Y / eval_cnt,
+        #                            global_step = epoch + 1)
+        self.tensorboard.add_scalar(tag = "Eval/PSNR_Cb_ave",
+                                    scalar_value = ave_psnr_Cb / eval_cnt,
+                                    global_step = epoch + 1)
+        self.tensorboard.add_scalar(tag = "Eval/PSNR_Cr_ave",
+                                    scalar_value = ave_psnr_Cr / eval_cnt,
+                                    global_step = epoch + 1)
+
+    def load_checkpoints(self):
+        if not self.args.checkpoints:
+            # resume from the most recent complete checkpoint; if the newest file differs
+            # in size from the previous one, it is assumed partially written and skipped
+            ckpt_list = sorted(glob.glob(os.path.join(self.checkpoints_dir, '*.pth')))
+            num = len(ckpt_list)
+            if num > 1:
+                if os.path.getsize(ckpt_list[-1]) == os.path.getsize(ckpt_list[-2]):
+                    self.args.checkpoints = ckpt_list[-1]
+                else:
+                    self.args.checkpoints = ckpt_list[-2]
+
+        if self.args.checkpoints:
+            print("===========Load checkpoints {0}===========".format(self.args.checkpoints))
+            self.logger.info("Load checkpoints {0}".format(self.args.checkpoints))
+            ckpt = torch.load(self.args.checkpoints)
+            # load network weights
+            try:
+                self.net.load_state_dict(ckpt["network"])
+            except:
+                print("Can not find network weights")
+            # load optimizer params
+            try:
+                self.optimizer.load_state_dict(ckpt["optimizer"])
+                self.scheduler.load_state_dict(ckpt["scheduler"])
+            except:
+                print("Can not find some optimizer params, just ignore")
+            start_epoch = ckpt["epoch"] + 1
+            self.train_steps = ckpt["train_step"] + 1
+            self.eval_steps = ckpt["eval_step"] + 1
+        elif self.args.pretrained:
+            ckpt = torch.load(self.args.pretrained)
+            print("===========Load network weights {0}===========".format(self.args.pretrained))
+            self.logger.info("Load network weights {0}".format(self.args.pretrained))
+            # load network weights
+            try:
+                self.net.load_state_dict(ckpt["network"])
+            except:
+                print("Can not find network weights")
+            start_epoch = 0
+        else:
+            print("===========Training from scratch===========")
+            self.logger.info("Training from scratch")
+            start_epoch = 0
+        return start_epoch
+
+    def save_ckpt(self, epoch: int):
+        checkpoint = {
+            "network": self.net.state_dict(),
+            "epoch": epoch,
+            "train_step": self.train_steps,
+            "eval_step": self.eval_steps,
+            "optimizer": self.optimizer.state_dict(),
+            "scheduler": self.scheduler.state_dict()}
+
+        torch.save(checkpoint, '%s/model_%.4d.pth' % (self.checkpoints_dir, epoch+1))
+        self.logger.info('Save model..')
+        print("======================Saving model {0}======================".format(str(epoch)))
+
+if __name__ == "__main__":
+    trainer = Trainer()
+    trainer.train()
diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/yuv10bdata.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/yuv10bdata.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cce9b7a69cc9a3a3803b37743c812d012ce9eed
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Chroma-IB/yuv10bdata.py
@@ -0,0 +1,288 @@
+"""
+/* The copyright in this software is being made available under the BSD
+* License, included below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2022, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import os +import glob +from torch.utils.data import Dataset +import torch.nn.functional as F +import numpy as np +import string +import random +import Utils +import math + + +class YUV10bData(Dataset): + def __init__(self, args, name='YuvData', train=True): + super(YUV10bData, self).__init__() + self.args = args + self.split = 'train' if train else 'valid' + self.image_ext = args.ext + self.name = name + self.train = train + + data_range_bvi_AI = [r.split('-') for r in args.data_range_bvi_AI.split('/')] + data_range_tvd_AI = [r.split('-') for r in args.data_range_tvd_AI.split('/')] + data_range_bvi_RA = [r.split('-') for r in args.data_range_bvi_RA.split('/')] + data_range_tvd_RA = [r.split('-') for r in args.data_range_tvd_RA.split('/')] + + if train: + data_range_bvi_AI = data_range_bvi_AI[0] + data_range_tvd_AI = data_range_tvd_AI[0] + data_range_bvi_RA = data_range_bvi_RA[0] + data_range_tvd_RA = data_range_tvd_RA[0] + else: + data_range_bvi_AI = data_range_bvi_AI[1] + data_range_tvd_AI = data_range_tvd_AI[1] + data_range_bvi_RA = data_range_bvi_RA[1] + data_range_tvd_RA = data_range_tvd_RA[1] + + + self.begin_bvi_AI, self.end_bvi_AI = list(map(lambda x: int(x), data_range_bvi_AI)) + self.begin_tvd_AI, self.end_tvd_AI = list(map(lambda x: int(x), data_range_tvd_AI)) + self.begin_bvi_RA, self.end_bvi_RA = list(map(lambda x: int(x), data_range_bvi_RA)) + self.begin_tvd_RA, self.end_tvd_RA = list(map(lambda x: int(x), data_range_tvd_RA)) + + self._set_data() + self._get_image_list() + + if train: + n_patches = args.batch_size * args.test_every + n_images = len(self.images_yuv) + if n_images == 0: + self.repeat = 0 + else: + self.repeat = 1 + #self.repeat = max(n_patches // n_images, 1) + print(f"repeating dataset {self.repeat} for one epoch") + else: + n_patches = args.batch_size * args.test_every // 25 + n_images = len(self.images_yuv) + if n_images == 0: + self.repeat = 0 + else: + self.repeat = 5 + print(f"repeating dataset {self.repeat} for one epoch") + + def _set_data(self): + self.dir_in_bvi_AI = os.path.join(self.args.dir_data_bvi_AI, "yuv") + self.dir_org_bvi_AI = os.path.join(self.args.dir_data_ori_bvi_AI) + + self.dir_in_tvd_AI = os.path.join(self.args.dir_data_tvd_AI, "yuv") + self.dir_org_tvd_AI = os.path.join(self.args.dir_data_ori_tvd_AI) + + self.dir_in_bvi_RA = os.path.join(self.args.dir_data_bvi_RA, "yuv") + self.dir_org_bvi_RA = os.path.join(self.args.dir_data_ori_bvi_RA) + + self.dir_in_tvd_RA = os.path.join(self.args.dir_data_tvd_RA, "yuv") + self.dir_org_tvd_RA = 
os.path.join(self.args.dir_data_ori_tvd_RA) + + def _scan_class(self, is_tvd, class_name, mode): + QPs = ['22', '27', '32', '37', '42'] + + if is_tvd: + if mode == 'AI': + dir_in = self.dir_in_tvd_AI + dir_org = self.dir_org_tvd_AI + else: + dir_in = self.dir_in_tvd_RA + dir_org = self.dir_org_tvd_RA + else: + if mode == 'AI': + dir_in = self.dir_in_bvi_AI + dir_org = self.dir_org_bvi_AI + else: + dir_in = self.dir_in_bvi_RA + dir_org = self.dir_org_bvi_RA + + list_temp = glob.glob(os.path.join(dir_in, '*_' + class_name + '_*.' + self.image_ext)) + file_rec_list = [] + for i in list_temp: + index = i.find('poc') + poc = int(i[index+3:index+6]) + if poc % 3 == 0 and poc != 0: + file_rec_list.append(i) + + if is_tvd: + file_rec_list = sorted(file_rec_list) + else: + file_rec_list = sorted(file_rec_list, key=str.lower) + + list_temp = glob.glob(os.path.join(dir_org, class_name + '*/*.' + self.image_ext)) + #print(list_temp) + file_org_list = [] + for i in list_temp: + index = i.find('frame_') + poc = int(i[index+6:index+9]) + if poc % 3 == 0 and poc != 0: + file_org_list.append(i) + + if is_tvd: + file_org_list = sorted(file_org_list) + else: + file_org_list = sorted(file_org_list, key=str.lower) + + frame_num = 62 + frame_num_sampled = math.ceil(frame_num / 3) + + if is_tvd: + if mode == 'AI': + begin = self.begin_tvd_AI + end = self.end_tvd_AI + else: + begin = self.begin_tvd_RA + end = self.end_tvd_RA + else: + if mode == 'AI': + begin = self.begin_bvi_AI + end = self.end_bvi_AI + else: + begin = self.begin_bvi_RA + end = self.end_bvi_RA + + class_names_yuv=[] + class_names_yuv_org=[] + + for qp in QPs: + file_list=file_rec_list[(begin-1)*frame_num_sampled*5:end*frame_num_sampled*5] + for filename in file_list: + idx = filename.find("qp") + if int(filename[idx+2:idx+4]) == int(qp): + class_names_yuv.append(filename) + + file_list=file_org_list[(begin-1)*frame_num_sampled:end*frame_num_sampled] + for filename in file_list: + class_names_yuv_org.append(filename) + + return class_names_yuv, class_names_yuv_org + + def _scan(self): + bvi_class_set = ['A'] + + names_yuv=[] + names_yuv_org=[] + for class_name in bvi_class_set: + class_names_yuv, class_names_yuv_org = self._scan_class(False, class_name, 'AI') + names_yuv = names_yuv + class_names_yuv + names_yuv_org = names_yuv_org + class_names_yuv_org + + class_names_yuv, class_names_yuv_org = self._scan_class(False, class_name, 'RA') + names_yuv = names_yuv + class_names_yuv + names_yuv_org = names_yuv_org + class_names_yuv_org + + class_names_yuv, class_names_yuv_org = self._scan_class(True, 'A', 'AI') + names_yuv = names_yuv + class_names_yuv + names_yuv_org = names_yuv_org + class_names_yuv_org + + class_names_yuv, class_names_yuv_org = self._scan_class(True, 'A', 'RA') + names_yuv = names_yuv + class_names_yuv + names_yuv_org = names_yuv_org + class_names_yuv_org + + print(len(names_yuv)) + print(len(names_yuv_org)) + + return names_yuv, names_yuv_org + + def _get_image_list(self): + self.images_yuv, self.images_yuv_org = self._scan() + + def __getitem__(self, idx): + patch_in, patch_org, filename = self._load_file_get_patch(idx) + pair_t = Utils.np2Tensor(patch_in, patch_org) + + return pair_t[0], pair_t[1], filename + + def __len__(self): + if self.train: + return len(self.images_yuv) * self.repeat + else: + return len(self.images_yuv) * self.repeat + + def _get_index(self, idx): + if self.train: + return idx % len(self.images_yuv) + else: + return idx % len(self.images_yuv) + + + def _load_file_get_patch(self, idx): + + idx = 
self._get_index(idx) + + # reconstruction + image_yuv_path = self.images_yuv[idx] + + slice_qp_idx = int(image_yuv_path.rfind("qp")) + slice_qp = int(image_yuv_path[slice_qp_idx+2:slice_qp_idx+4]) + slice_qp_map = np.uint16(np.ones(((self.args.patch_size + 2*self.args.shave)//4, (self.args.patch_size + 2*self.args.shave)//4, 1))*slice_qp) + + base_qp_idx = int(image_yuv_path.find("qp")) + base_qp = int(image_yuv_path[base_qp_idx+2:base_qp_idx+4]) + base_qp_map = np.uint16(np.ones(((self.args.patch_size + 2*self.args.shave)//4, (self.args.patch_size + 2*self.args.shave)//4, 1))*base_qp) + + if self.args.dir_data_bvi_AI in image_yuv_path or self.args.dir_data_tvd_AI in image_yuv_path: + is_AI = 1 + else: + is_AI = 0 + + if is_AI: + slice_type = 0 + else: + slice_type = 1023 + slice_type_map = np.uint16(np.ones(((self.args.patch_size + 2*self.args.shave)//4, (self.args.patch_size + 2*self.args.shave)//4, 1))*slice_type) + + # RPR + rpr_str = '_rpr' + pos = image_yuv_path.find('.yuv') + image_yuv_rpr_path = image_yuv_path[:pos] + rpr_str + image_yuv_path[pos:] + image_yuv_rpr_path = image_yuv_rpr_path.replace('/yuv/', '/rpr_image/') + + # original + image_yuv_org_path = self.images_yuv_org[idx] + org_splits = os.path.basename(os.path.dirname(image_yuv_org_path)).split('_') + wh_org=org_splits[1].split('x') + w, h = list(map(lambda x: int(x), wh_org)) + + patch_rec_y, patch_in, patch_rpr, patch_org = Utils.get_patch( image_yuv_path, image_yuv_rpr_path, image_yuv_org_path, w, h, self.args.patch_size, self.args.shave ) + + patch_lr = np.concatenate((patch_in, slice_qp_map, base_qp_map, slice_type_map), axis=2) + patch_hr = np.concatenate((patch_rec_y, patch_rpr, patch_org), axis=2) + + if self.train: + patch_lr, patch_hr = Utils.augment(patch_lr, patch_hr) + + return patch_lr, patch_hr, image_yuv_path + + + diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/Utils.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/Utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ddaa700a3f9d709b8243804e72296b0700b3dd --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/Utils.py @@ -0,0 +1,232 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import argparse
+import datetime
+import logging
+import math
+import random
+import struct
+from pathlib import Path
+import numpy as np
+import PIL.Image as Image
+import os
+
+
+import torch
+import torch.nn as nn
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms as tfs
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    path_cur = Path(os.path.split(os.path.realpath(__file__))[0])
+    path_save = path_cur.joinpath("Experiments")
+
+    # for loading data
+    parser.add_argument("--ext", type=str, default='yuv', help="data file extension")
+
+    parser.add_argument("--data_range", type=str, default='1-180/181-200', help="train/test data range")
+    parser.add_argument('--dir_data', type=str, default='/path/EE1_2_2_train/RA_BVI_DVC', help='distorted dataset directory')
+    parser.add_argument('--dir_data_ori', type=str, default='/path/EE1_2_2_train_ori/BVI_DVC', help='raw dataset directory')
+
+    parser.add_argument("--data_range_tvd", type=str, default='1-66/67-74', help="train/test data range")
+    parser.add_argument('--dir_data_tvd', type=str, default='/path/EE1_2_2_train/RA_TVD', help='distorted dataset directory')
+    parser.add_argument('--dir_data_ori_tvd', type=str, default='/path/EE1_2_2_train_ori/TVD', help='raw dataset directory')
+
+    # for loading model
+    parser.add_argument("--checkpoints", type=str, help="checkpoints file path")
+    parser.add_argument("--pretrained", type=str, help="pretrained model path")
+
+    # batch size
+    parser.add_argument("--batch_size", type=int, default=64, help="batch size for Fusion stage")
+    # do validation
+    parser.add_argument("--test_every", type=int, default=1200, help="run validation every N batches")
+
+    # learning rate
+    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for Fusion stage")
+
+    parser.add_argument("--gpu", action='store_true', default=True, help="use gpu or cpu")
+
+    # epoch
+    parser.add_argument("--max_epoch", type=int, default=2000, help="max training epochs")
+
+    # patch_size
+    parser.add_argument("--patch_size", type=int, default=128, help="train/val patch size")
+    parser.add_argument("--shave", type=int, default=8, help="train/shave")
+
+    # for recording
+    parser.add_argument("--verbose", action='store_true', default=True, help="use tensorboard and logger")
+    parser.add_argument("--save_dir", type=str, default=path_save, help="directory for recording")
+    parser.add_argument("--eval_epochs", type=int, default=5, help="save model after epochs")
+
+    args = parser.parse_args()
+    return args
+
+
+def init():
+    # parse arguments
+    args = parse_args()
+
+    # create directory for recording
+    experiment_dir = Path(args.save_dir)
+    experiment_dir.mkdir(exist_ok=True)
+
+    ckpt_dir = experiment_dir.joinpath("Checkpoints/")
+    ckpt_dir.mkdir(exist_ok=True)
+    print(r"===========Save checkpoints to {0}===========".format(str(ckpt_dir)))
+
+    if args.verbose:
+        # initialize logger
+        log_dir = experiment_dir.joinpath('Log/')
+
log_dir.mkdir(exist_ok=True) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler = logging.FileHandler(str(log_dir) + '/Log.txt') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.info('PARAMETER ...') + logger.info(args) + # initialize tensorboard + tb_dir_all = experiment_dir.joinpath('Tensorboard_all/') + tb_dir_all.mkdir(exist_ok=True) + tensorboard_all = SummaryWriter(log_dir=str(tb_dir_all), flush_secs=30) + + tb_dir = experiment_dir.joinpath('Tensorboard/') + tb_dir.mkdir(exist_ok=True) + tensorboard = SummaryWriter(log_dir=str(tb_dir), flush_secs=30) + print(r"===========Save tensorboard and logger to {0}===========".format(str(tb_dir_all))) + else: + print(r"===========Disable tensorboard and logger to accelerate training===========") + logger = None + tensorboard_all = None + tensorboard = None + + return args, logger, ckpt_dir, tensorboard_all, tensorboard + +def yuv_read(yuv_path, h, w, iy, ix, ip): + h_c = h//2 + w_c = w//2 + + ip_c = ip//2 + iy_c = iy//2 + ix_c = ix//2 + + fp = open(yuv_path, 'rb') + + # y + fp.seek(iy*w*2, 0) + patch_y = np.fromfile(fp, np.uint16, ip*w).reshape(ip, w, 1) + patch_y = patch_y[:, ix:ix + ip, :] + + # u + fp.seek(( w*h+ iy_c*w_c)*2, 0) + patch_u = np.fromfile(fp, np.uint16, ip_c*w_c).reshape(ip_c, w_c, 1) + patch_u = patch_u[:, ix_c:ix_c + ip_c, :] + + # v + fp.seek(( w*h+ w_c*h_c+ iy_c*w_c)*2, 0) + patch_v = np.fromfile(fp, np.uint16, ip_c*w_c).reshape(ip_c, w_c, 1) + patch_v = patch_v[:, ix_c:ix_c + ip_c, :] + + fp.close() + + return patch_y, patch_u, patch_v + +def upsample(img, height, width): + img = np.squeeze(img, axis = 2) + img=np.array(Image.fromarray(img.astype(np.float)).resize((width, height), Image.NEAREST)) + img = np.expand_dims(img, axis = 2) + return img + +def patch_process(yuv_path, h, w, iy, ix, ip): + y, u, v = yuv_read(yuv_path, h, w, iy, ix, ip) + #u_up = upsample(u, ip, ip) + #v_up = upsample(v, ip, ip) + #yuv = np.concatenate((y, u_up, v_up), axis=2) + return y + +def get_patch(image_yuv_path, image_yuv_pred_path, image_yuv_rpr_path, image_yuv_org_path, w, h, patch_size, shave): + ih = h + iw = w + + ip = patch_size + ih -= ih % ip + iw -= iw % ip + iy = random.randrange(ip, ih-ip, ip) - shave + ix = random.randrange(ip, iw-ip, ip) - shave + + # + patch_rec = patch_process(image_yuv_path, h//2, w//2, iy//2, ix//2, (ip + 2*shave)//2) + patch_pre = patch_process(image_yuv_pred_path, h//2, w//2, iy//2, ix//2, (ip + 2*shave)//2) + patch_rpr = patch_process(image_yuv_rpr_path, h, w, iy, ix, ip + 2*shave) + patch_org = patch_process(image_yuv_org_path, h, w, iy, ix, ip + 2*shave) + + patch_in = np.concatenate((patch_rec, patch_pre), axis=2) + + ret = [patch_in, patch_rpr, patch_org] + + return ret + +def augment(*args): + x = random.random() + hflip = x < 0.2 + vflip = x >= 0.2 and x < 0.4 + rot90 = x >= 0.4 and x < 0.6 + + def _augment(img): + if hflip: img = img[:, ::-1, :] + if vflip: img = img[::-1, :, :] + if rot90: img = img.transpose(1, 0, 2) + + return img + + return [_augment(a) for a in args] + +def np2Tensor(*args): + def _np2Tensor(img): + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) + tensor = torch.from_numpy(np_transpose.astype(np.int32)).float() / 1023.0 + + return tensor + + return [_np2Tensor(a) for a in args] + +def cal_psnr(distortion: torch.Tensor): + psnr = -10 * torch.log10(distortion) + return psnr diff 
--git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/nn_model.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/nn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f7060afd3386c8845881c5851c34b1bbdad207dc --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/nn_model.py @@ -0,0 +1,104 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import torch +import torch.nn as nn +from torch.nn import Parameter + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + # hyper-params + n_resblocks = 24 + n_feats_k = 64 + n_feats_m = 192 + + # define head module + self.head_rec = nn.Sequential( + nn.Conv2d(in_channels = 1, out_channels = n_feats_k, kernel_size = 3, stride = 1, padding = 1), # downsmaple by stride = 2 + nn.PReLU() + ) + self.head_pre = nn.Sequential( + nn.Conv2d(in_channels = 1, out_channels = n_feats_k, kernel_size = 3, stride = 1, padding = 1), # downsmaple by stride = 2 + nn.PReLU() + ) + + #define fuse module + self.fuse = nn.Sequential( + nn.Conv2d(in_channels = n_feats_k*2 + 2, out_channels = n_feats_k, kernel_size = 1, stride = 1, padding = 0), + nn.PReLU() + ) + + # define body module + body = [] + for _ in range(n_resblocks): + body.append(DscBlock(n_feats_k, n_feats_m)) + + self.body = nn.Sequential(*body) + + # define tail module + self.tail = nn.Sequential( + nn.Conv2d(in_channels = n_feats_k, out_channels = 4 * 1, kernel_size = 3, padding = 1), + nn.PixelShuffle(2) #feature_map:(B, 2x2x1, N, N) -> (B, 1, 2N, 2N) + ) + + def forward(self, rec, pre, rpr, slice_qp, base_qp): + in_0 = self.head_rec(rec) + in_1 = self.head_pre(pre) + + x = self.fuse(torch.cat((in_0, in_1, slice_qp, base_qp), 1)) + x = self.body(x) + x = self.tail(x) + x += rpr + + return x + +class DscBlock(nn.Module): + def __init__(self, n_feats_k, n_feats_m, expansion=1): + super(DscBlock, self).__init__() + self.expansion = expansion + self.c1 = nn.Conv2d(in_channels=n_feats_k, out_channels=n_feats_m, kernel_size=1, padding=0) + self.prelu = nn.PReLU() + self.c2 = nn.Conv2d(in_channels=n_feats_m, out_channels=n_feats_k, kernel_size=1, padding=0) + self.c3 = nn.Conv2d(in_channels=n_feats_k, out_channels=n_feats_k, kernel_size=3, padding=1) + + def forward(self, x): + i = x + x = self.c2(self.prelu(self.c1(x))) + x = self.c3(x) + x += i + + return x + + diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b026d766aefdfa8a5c92a46d1750661bc344f5a --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train.sh @@ -0,0 +1,32 @@ +# The copyright in this software is being made available under the BSD +# License, included below. This software may be subject to other third party +# and contributor rights, including patent rights, and no such rights are +# granted under this license. +# +# Copyright (c) 2010-2022, ITU/ISO/IEC +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +# be used to endorse or promote products derived from this software without +# specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +python train_YUV.py \ No newline at end of file diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train_YUV.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train_YUV.py new file mode 100644 index 0000000000000000000000000000000000000000..82dc2dd461c1a6f2a07c68692baf1dfb78ab26fc --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/train_YUV.py @@ -0,0 +1,333 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import torch +import torch.nn as nn +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data.dataloader import DataLoader +import datetime +import os, glob + +from yuv10bdata import YUV10bData +from Utils import init, cal_psnr +from nn_model import Net + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True + + +class Trainer: + def __init__(self): + self.args, self.logger, self.checkpoints_dir, self.tensorboard_all, self.tensorboard = init() + + self.net = Net().to("cuda" if self.args.gpu else "cpu") + + self.L1loss = nn.L1Loss().to("cuda" if self.args.gpu else "cpu") + self.L2loss = nn.MSELoss().to("cuda" if self.args.gpu else "cpu") + + self.optimizer = Adam(self.net.parameters(), lr = self.args.lr) + self.scheduler = MultiStepLR(optimizer=self.optimizer, milestones=[4001, 4002], gamma=0.5) + + print("============>loading data") + self.train_dataset = YUV10bData(self.args, train=True) + self.eval_dataset = YUV10bData(self.args, train=False) + + self.train_dataloader = DataLoader(dataset=self.train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=12, pin_memory=False) + self.eval_dataloader = DataLoader(dataset=self.eval_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=12, pin_memory=False) + + self.train_steps = self.eval_steps = 0 + + def train(self): + start_epoch = self.load_checkpoints() + print("============>start training") + for epoch in range(start_epoch, self.args.max_epoch): + print("Epoch {}/{}".format(epoch, self.args.max_epoch)) + self.logger.info("Epoch {}/{}".format(epoch, self.args.max_epoch)) + self.train_one_epoch() + self.scheduler.step() + if (epoch+1) % self.args.eval_epochs == 0: + self.eval(epoch=epoch) + self.save_ckpt(epoch=epoch) + + def train_one_epoch(self): + self.net.train() + for _, tensor in enumerate(self.train_dataloader): + + img_lr, img_hr, filename = tensor + + img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") + img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") + + rec = img_lr[:,0:1,:,:] + pre = img_lr[:,1:2,:,:] + slice_qp = img_lr[:,2:3,:,:] + base_qp = img_lr[:,3:4,:,:] + img_rpr = img_hr[:,0:1,:,:] + img_ori = img_hr[:,1:2,:,:] + + img_out = self.net(rec, pre, img_rpr, slice_qp, base_qp) + + # calculate distortion + shave=self.args.shave + + L1_loss_pred_Y = self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + #loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr + loss_pred= L1_loss_pred_Y + + #loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + # visualization + self.train_steps += 1 + if self.train_steps % 20 == 0: + psnr_pred_Y = 
cal_psnr(loss_pred_Y) + #psnr_pred_Cb = cal_psnr(loss_pred_Cb) + #psnr_pred_Cr = cal_psnr(loss_pred_Cr) + + #psnr_input_Y = cal_psnr(loss_rec_Y) + #psnr_input_Cb = cal_psnr(loss_rec_Cb) + #psnr_input_Cr = cal_psnr(loss_rec_Cr) + + time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") + + print("[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + loss_pred_Y, psnr_pred_Y, time)) + self.logger.info("[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + loss_pred_Y, psnr_pred_Y)) + + #print("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y, time)) + #self.logger.info("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y)) + + self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + tag_scalar_dict={"pred_Y": psnr_pred_Y.data}, + global_step=self.train_steps) + #self.tensorboard_all.add_image("rec", rec[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("pre", pre[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("rpr", img_rpr[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("out", img_out[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("ori", img_ori[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + + #self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + # tag_scalar_dict={"input_Cb": psnr_input_Cb.data, + # "pred_Cb": psnr_pred_Cb.data}, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + # tag_scalar_dict={"input_Cr": psnr_input_Cr.data, + # "pred_Cr": psnr_pred_Cr.data}, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Y", + # scalar_value = psnr_pred_Y - psnr_input_Y, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cb", + # scalar_value = psnr_pred_Cb - psnr_input_Cb, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cr", + # scalar_value = psnr_pred_Cr - psnr_input_Cr, + # global_step=self.train_steps) + + self.tensorboard_all.add_scalar(tag="Train/train_loss_pred", + scalar_value = loss_pred, + global_step=self.train_steps) + + # backward + self.optimizer.zero_grad() + loss_pred.backward() + self.optimizer.step() + + @torch.no_grad() + def eval(self, epoch: int): + print("============>start evaluating") + eval_cnt = 0 + ave_psnr_Y = 0.000 + #ave_psnr_Cb = 0.000 + #ave_psnr_Cr = 0.000 + self.net.eval() + for _, tensor in enumerate(self.eval_dataloader): + + img_lr, img_hr, filename = tensor + + img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") + img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") + + rec = img_lr[:,0:1,:,:] + pre = img_lr[:,1:2,:,:] + slice_qp = img_lr[:,2:3,:,:] + base_qp = img_lr[:,3:4,:,:] + img_rpr = img_hr[:,0:1,:,:] + img_ori = img_hr[:,1:2,:,:] + img_out = self.net(rec, pre, img_rpr, slice_qp, base_qp) + + # calculate distortion and psnr + shave=self.args.shave + + L1_loss_pred_Y = 
self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + #loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr + loss_pred = L1_loss_pred_Y + + #loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + psnr_pred_Y = cal_psnr(loss_pred_Y) + #psnr_pred_Cb = cal_psnr(loss_pred_Cb) + #psnr_pred_Cr = cal_psnr(loss_pred_Cr) + + #psnr_input_Y = cal_psnr(loss_rec_Y) + #psnr_input_Cb = cal_psnr(loss_rec_Cb) + #psnr_input_Cr = cal_psnr(loss_rec_Cr) + + ave_psnr_Y += psnr_pred_Y + #ave_psnr_Cb += psnr_pred_Cb - psnr_input_Cb + #ave_psnr_Cr += psnr_pred_Cr - psnr_input_Cr + + eval_cnt += 1 + # visualization + self.eval_steps += 1 + if self.eval_steps % 2 == 0: + + self.tensorboard_all.add_scalar(tag="Eval/PSNR_Y", + scalar_value = psnr_pred_Y, + global_step=self.eval_steps) + + #self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cb", + # scalar_value = psnr_pred_Cb - psnr_input_Cb, + # global_step=self.eval_steps) + + #self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cr", + # scalar_value = psnr_pred_Cr - psnr_input_Cr, + # global_step=self.eval_steps) + + self.tensorboard_all.add_scalar(tag="Eval/eval_loss_pred", + scalar_value = loss_pred, + global_step=self.eval_steps) + + time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") + print("PSNR_Y:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, time)) + self.logger.info("PSNR_Y:{:.3f}".format(ave_psnr_Y / eval_cnt)) + + #print("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt, time)) + #self.logger.info("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt)) + + self.tensorboard.add_scalar(tag = "Eval/PSNR_Y_ave", + scalar_value = ave_psnr_Y / eval_cnt, + global_step = epoch + 1) + #self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cb_ave", + # scalar_value = ave_psnr_Cb / eval_cnt, + # global_step = epoch + 1) + #self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cr_ave", + # scalar_value = ave_psnr_Cr / eval_cnt, + # global_step = epoch + 1) + + def load_checkpoints(self): + if not self.args.checkpoints: + ckpt_list=sorted(glob.glob(os.path.join(self.checkpoints_dir, '*.pth'))) + num = len(ckpt_list) + if(num > 1): + if os.path.getsize(ckpt_list[-1]) == os.path.getsize(ckpt_list[-2]): + self.args.checkpoints = ckpt_list[-1] + else: + self.args.checkpoints = ckpt_list[-2] + + if self.args.checkpoints: + print("===========Load checkpoints {0}===========".format(self.args.checkpoints)) + self.logger.info("Load checkpoints {0}".format(self.args.checkpoints)) + ckpt = torch.load(self.args.checkpoints) + # load network 
weights + try: + self.net.load_state_dict(ckpt["network"]) + except: + print("Can not find network weights") + # load optimizer params + try: + self.optimizer.load_state_dict(ckpt["optimizer"]) + self.scheduler.load_state_dict(ckpt["scheduler"]) + except: + print("Can not find some optimizers params, just ignore") + start_epoch = ckpt["epoch"] + 1 + self.train_steps = ckpt["train_step"] + 1 + self.eval_steps = ckpt["eval_step"] + 1 + elif self.args.pretrained: + ckpt = torch.load(self.args.pretrained) + print("===========Load network weights {0}===========".format(self.args.checkpoints)) + self.logger.info("Load network weights {0}".format(self.args.checkpoints)) + # load codec weights + try: + self.net.load_state_dict(ckpt["network"]) + except: + print("Can not find network weights") + start_epoch = 0 + else: + print("===========Training from scratch===========") + self.logger.info("Training from scratch") + start_epoch = 0 + return start_epoch + + def save_ckpt(self, epoch: int): + checkpoint = { + "network": self.net.state_dict(), + "epoch": epoch, + "train_step": self.train_steps, + "eval_step": self.eval_steps, + "optimizer": self.optimizer.state_dict(), + "scheduler": self.scheduler.state_dict()} + + torch.save(checkpoint, '%s/model_%.4d.pth' % (self.checkpoints_dir, epoch+1)) + self.logger.info('Save model..') + print("======================Saving model {0}======================".format(str(epoch))) + +if __name__ == "__main__": + trainer = Trainer() + trainer.train() diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/yuv10bdata.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/yuv10bdata.py new file mode 100644 index 0000000000000000000000000000000000000000..05ccd8483f44138a30fe3a63d0fd9c5db4f97ad7 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-B/yuv10bdata.py @@ -0,0 +1,238 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import os
+import glob
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+import numpy as np
+import string
+import random
+import Utils
+import math
+
+
+class YUV10bData(Dataset):
+    def __init__(self, args, name='YuvData', train=True):
+        super(YUV10bData, self).__init__()
+        self.args = args
+        self.split = 'train' if train else 'valid'
+        self.image_ext = args.ext
+        self.name = name
+        self.train = train
+
+        data_range = [r.split('-') for r in args.data_range.split('/')]
+        if train:
+            data_range = data_range[0]
+        else:
+            data_range = data_range[1]
+        self.begin, self.end = list(map(lambda x: int(x), data_range))
+
+        data_range_tvd = [r.split('-') for r in args.data_range_tvd.split('/')]
+        if train:
+            data_range_tvd = data_range_tvd[0]
+        else:
+            data_range_tvd = data_range_tvd[1]
+        self.begin_tvd, self.end_tvd = list(map(lambda x: int(x), data_range_tvd))
+
+        self._set_data()
+        self._get_image_list()
+
+        if train:
+            n_patches = args.batch_size * args.test_every
+            n_images = len(self.images_yuv)
+            if n_images == 0:
+                self.repeat = 0
+            else:
+                self.repeat = max(n_patches // n_images, 1)
+            print(f"repeating dataset {self.repeat} for one epoch")
+        else:
+            n_patches = args.batch_size * args.test_every // 25
+            n_images = len(self.images_yuv)
+            if n_images == 0:
+                self.repeat = 0
+            else:
+                self.repeat = 5
+            print(f"repeating dataset {self.repeat} for one epoch")
+
+    def _set_data(self):
+        self.dir_in = os.path.join(self.args.dir_data, "yuv")
+        self.dir_org = os.path.join(self.args.dir_data_ori)
+
+        self.dir_in_tvd = os.path.join(self.args.dir_data_tvd, "yuv")
+        self.dir_org_tvd = os.path.join(self.args.dir_data_ori_tvd)
+
+    def _scan_class(self, is_tvd, class_name):
+        QPs = ['22', '27', '32', '37', '42']
+
+        dir_in = self.dir_in_tvd if is_tvd else self.dir_in
+        list_temp = glob.glob(os.path.join(dir_in, '*_' + class_name + '_*.' + self.image_ext))
+        file_rec_list = []
+        for i in list_temp:
+            # temporal subsampling: keep every third frame and skip POC 0
+            index = i.find('poc')
+            poc = int(i[index+3:index+6])
+            if poc % 3 == 0 and poc != 0:
+                file_rec_list.append(i)
+
+        if is_tvd:
+            file_rec_list = sorted(file_rec_list)
+        else:
+            file_rec_list = sorted(file_rec_list, key=str.lower)
+
+        dir_org = self.dir_org_tvd if is_tvd else self.dir_org
+        list_temp = glob.glob(os.path.join(dir_org, class_name + '*/*.' 
+ self.image_ext)) + #print(list_temp) + file_org_list = [] + for i in list_temp: + index = i.find('frame_') + poc = int(i[index+6:index+9]) + if poc % 3 == 0 and poc != 0: + file_org_list.append(i) + + if is_tvd: + file_org_list = sorted(file_org_list) + else: + file_org_list = sorted(file_org_list, key=str.lower) + + frame_num = 62 + frame_num_sampled = math.ceil(frame_num / 3) + begin = self.begin_tvd if is_tvd else self.begin + end = self.end_tvd if is_tvd else self.end + + class_names_yuv=[] + class_names_yuv_org=[] + + for qp in QPs: + file_list=file_rec_list[(begin-1)*frame_num_sampled*5:end*frame_num_sampled*5] + for filename in file_list: + idx = filename.find("qp") + if int(filename[idx+2:idx+4]) == int(qp): + class_names_yuv.append(filename) + + file_list=file_org_list[(begin-1)*frame_num_sampled:end*frame_num_sampled] + for filename in file_list: + class_names_yuv_org.append(filename) + + return class_names_yuv, class_names_yuv_org + + def _scan(self): + bvi_class_set = ['A'] + + names_yuv=[] + names_yuv_org=[] + for class_name in bvi_class_set: + class_names_yuv, class_names_yuv_org = self._scan_class(False, class_name) + names_yuv = names_yuv + class_names_yuv + names_yuv_org = names_yuv_org + class_names_yuv_org + + class_names_yuv, class_names_yuv_org = self._scan_class(True, 'A') + names_yuv = names_yuv + class_names_yuv + names_yuv_org = names_yuv_org + class_names_yuv_org + + print(len(names_yuv)) + print(len(names_yuv_org)) + + return names_yuv, names_yuv_org + + def _get_image_list(self): + self.images_yuv, self.images_yuv_org = self._scan() + + def __getitem__(self, idx): + patch_in, patch_org, filename = self._load_file_get_patch(idx) + pair_t = Utils.np2Tensor(patch_in, patch_org) + + return pair_t[0], pair_t[1], filename + + def __len__(self): + if self.train: + return len(self.images_yuv) * self.repeat + else: + return len(self.images_yuv) * self.repeat + + def _get_index(self, idx): + if self.train: + return idx % len(self.images_yuv) + else: + return idx % len(self.images_yuv) + + + def _load_file_get_patch(self, idx): + + idx = self._get_index(idx) + + # reconstruction + image_yuv_path = self.images_yuv[idx] + + slice_qp_idx = int(image_yuv_path.rfind("qp")) + slice_qp = int(image_yuv_path[slice_qp_idx+2:slice_qp_idx+4]) + slice_qp_map = np.uint16(np.ones(((self.args.patch_size + 2*self.args.shave)//2, (self.args.patch_size + 2*self.args.shave)//2, 1))*slice_qp) + + base_qp_idx = int(image_yuv_path.find("qp")) + base_qp = int(image_yuv_path[base_qp_idx+2:base_qp_idx+4]) + base_qp_map = np.uint16(np.ones(((self.args.patch_size + 2*self.args.shave)//2, (self.args.patch_size + 2*self.args.shave)//2, 1))*base_qp) + + # prediction + pred_str = '_prediction' + pos = image_yuv_path.find('.yuv') + image_yuv_pred_path = image_yuv_path[:pos] + pred_str + image_yuv_path[pos:] + image_yuv_pred_path = image_yuv_pred_path.replace('/yuv/', '/prediction_image/') + + # RPR + rpr_str = '_rpr' + pos = image_yuv_path.find('.yuv') + image_yuv_rpr_path = image_yuv_path[:pos] + rpr_str + image_yuv_path[pos:] + image_yuv_rpr_path = image_yuv_rpr_path.replace('/yuv/', '/rpr_image/') + + # original + image_yuv_org_path = self.images_yuv_org[idx] + org_splits = os.path.basename(os.path.dirname(image_yuv_org_path)).split('_') + wh_org=org_splits[1].split('x') + w, h = list(map(lambda x: int(x), wh_org)) + + patch_in, patch_rpr, patch_org = Utils.get_patch( image_yuv_path, image_yuv_pred_path, image_yuv_rpr_path, image_yuv_org_path, w, h, self.args.patch_size, self.args.shave ) + + 
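+        # patch_in from get_patch holds the half-resolution reconstruction and
+        # prediction planes; the constant slice-QP and base-QP planes are
+        # appended below so the network is conditioned on both quantization
+        # parameters (all channels are later scaled by 1/1023 in np2Tensor)
+        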
patch_in = np.concatenate((patch_in, slice_qp_map, base_qp_map), axis=2) + + + if self.train: + patch_in, patch_rpr, patch_org = Utils.augment(patch_in, patch_rpr, patch_org) + + patch_lr = patch_in + patch_hr = np.concatenate((patch_rpr, patch_org), axis=2) + + return patch_lr, patch_hr, image_yuv_path + + + diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/Utils.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/Utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4850b6f42906dc91421d509de79dc52613b1c2d0 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/Utils.py @@ -0,0 +1,232 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import argparse +import datetime +import logging +import math +import random +import struct +from pathlib import Path +import numpy as np +import PIL.Image as Image +import numpy as np +import os + + +import torch +import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms as tfs + +def parse_args(): + parser = argparse.ArgumentParser() + + path_cur = Path(os.path.split(os.path.realpath(__file__))[0]) + path_save = path_cur.joinpath("Experiments") + + # for loading data + parser.add_argument("--ext", type=str, default='yuv', help="data file extension") + + parser.add_argument("--data_range", type=str, default='1-180/181-200', help="train/test data range") + parser.add_argument('--dir_data', type=str, default='/path/EE1_2_2_train/AI_BVI_DVC', help='distorted dataset directory') + parser.add_argument('--dir_data_ori', type=str, default='/path/EE1_2_2_train_ori/BVI_DVC', help='raw dataset directory') + + parser.add_argument("--data_range_tvd", type=str, default='1-66/67-74', help="train/test data range") + parser.add_argument('--dir_data_tvd', type=str, default='/path/EE1_2_2_train/AI_TVD', help='distorted dataset directory') + parser.add_argument('--dir_data_ori_tvd', type=str, default='/path/EE1_2_2_train_ori/TVD', help='raw dataset directory') + + # for loading model + parser.add_argument("--checkpoints", type=str, help="checkpoints file path") + parser.add_argument("--pretrained", type=str, help="pretrained model path") + + # batch size + parser.add_argument("--batch_size", type=int, default=64, help="batch size for Fusion stage") + # do validation + parser.add_argument("--test_every",type=int, default=1200, help="do test per every N batches") + + # learning rate + parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for Fusion stage") + + parser.add_argument("--gpu", action='store_true', default=True, help="use gpu or cpu") + + # epoch + parser.add_argument("--max_epoch", type=int, default=2000, help="max training epochs") + + # patch_size + parser.add_argument("--patch_size", type=int, default=128, help="train/val patch size") + parser.add_argument("--shave", type=int, default=8, help="train/shave") + + # for recording + parser.add_argument("--verbose", action='store_true', default=True, help="use tensorboard and logger") + parser.add_argument("--save_dir", type=str, default=path_save, help="directory for recording") + parser.add_argument("--eval_epochs", type=int, default=5, help="save model after epochs") + + args = parser.parse_args() + return args + + +def init(): + # parse arguments + args = parse_args() + + # create directory for recording + experiment_dir = Path(args.save_dir) + experiment_dir.mkdir(exist_ok=True) + + ckpt_dir = experiment_dir.joinpath("Checkpoints/") + ckpt_dir.mkdir(exist_ok=True) + print(r"===========Save checkpoints to {0}===========".format(str(ckpt_dir))) + + if args.verbose: + # initialize logger + log_dir = experiment_dir.joinpath('Log/') + log_dir.mkdir(exist_ok=True) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler = logging.FileHandler(str(log_dir) + '/Log.txt') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.info('PARAMETER ...') + logger.info(args) + # initialize tensorboard + tb_dir_all = experiment_dir.joinpath('Tensorboard_all/') + tb_dir_all.mkdir(exist_ok=True) + tensorboard_all = 
SummaryWriter(log_dir=str(tb_dir_all), flush_secs=30) + + tb_dir = experiment_dir.joinpath('Tensorboard/') + tb_dir.mkdir(exist_ok=True) + tensorboard = SummaryWriter(log_dir=str(tb_dir), flush_secs=30) + print(r"===========Save tensorboard and logger to {0}===========".format(str(tb_dir_all))) + else: + print(r"===========Disable tensorboard and logger to accelerate training===========") + logger = None + tensorboard_all = None + tensorboard = None + + return args, logger, ckpt_dir, tensorboard_all, tensorboard + +def yuv_read(yuv_path, h, w, iy, ix, ip): + h_c = h//2 + w_c = w//2 + + ip_c = ip//2 + iy_c = iy//2 + ix_c = ix//2 + + fp = open(yuv_path, 'rb') + + # y + fp.seek(iy*w*2, 0) + patch_y = np.fromfile(fp, np.uint16, ip*w).reshape(ip, w, 1) + patch_y = patch_y[:, ix:ix + ip, :] + + # u + fp.seek(( w*h+ iy_c*w_c)*2, 0) + patch_u = np.fromfile(fp, np.uint16, ip_c*w_c).reshape(ip_c, w_c, 1) + patch_u = patch_u[:, ix_c:ix_c + ip_c, :] + + # v + fp.seek(( w*h+ w_c*h_c+ iy_c*w_c)*2, 0) + patch_v = np.fromfile(fp, np.uint16, ip_c*w_c).reshape(ip_c, w_c, 1) + patch_v = patch_v[:, ix_c:ix_c + ip_c, :] + + fp.close() + + return patch_y, patch_u, patch_v + +def upsample(img, height, width): + img = np.squeeze(img, axis = 2) + img=np.array(Image.fromarray(img.astype(np.float)).resize((width, height), Image.NEAREST)) + img = np.expand_dims(img, axis = 2) + return img + +def patch_process(yuv_path, h, w, iy, ix, ip): + y, u, v = yuv_read(yuv_path, h, w, iy, ix, ip) + #u_up = upsample(u, ip, ip) + #v_up = upsample(v, ip, ip) + #yuv = np.concatenate((y, u_up, v_up), axis=2) + return y + +def get_patch(image_yuv_path, image_yuv_pred_path, image_yuv_rpr_path, image_yuv_org_path, w, h, patch_size, shave): + ih = h + iw = w + + ip = patch_size + ih -= ih % ip + iw -= iw % ip + iy = random.randrange(ip, ih-ip, ip) - shave + ix = random.randrange(ip, iw-ip, ip) - shave + + # + patch_rec = patch_process(image_yuv_path, h//2, w//2, iy//2, ix//2, (ip + 2*shave)//2) + patch_pre = patch_process(image_yuv_pred_path, h//2, w//2, iy//2, ix//2, (ip + 2*shave)//2) + patch_rpr = patch_process(image_yuv_rpr_path, h, w, iy, ix, ip + 2*shave) + patch_org = patch_process(image_yuv_org_path, h, w, iy, ix, ip + 2*shave) + + patch_in = np.concatenate((patch_rec, patch_pre), axis=2) + + ret = [patch_in, patch_rpr, patch_org] + + return ret + +def augment(*args): + x = random.random() + hflip = x < 0.2 + vflip = x >= 0.2 and x < 0.4 + rot90 = x >= 0.4 and x < 0.6 + + def _augment(img): + if hflip: img = img[:, ::-1, :] + if vflip: img = img[::-1, :, :] + if rot90: img = img.transpose(1, 0, 2) + + return img + + return [_augment(a) for a in args] + +def np2Tensor(*args): + def _np2Tensor(img): + np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) + tensor = torch.from_numpy(np_transpose.astype(np.int32)).float() / 1023.0 + + return tensor + + return [_np2Tensor(a) for a in args] + +def cal_psnr(distortion: torch.Tensor): + psnr = -10 * torch.log10(distortion) + return psnr diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/nn_model.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/nn_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b72f19c49b4ae73fe41ffb8163774445367c2fec --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/nn_model.py @@ -0,0 +1,104 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included 
below. This software may be subject to other third party
+* and contributor rights, including patent rights, and no such rights are
+* granted under this license.
+*
+* Copyright (c) 2010-2022, ITU/ISO/IEC
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* * Redistributions of source code must retain the above copyright notice,
+* this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
+* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import torch
+import torch.nn as nn
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        # hyper-params
+        n_resblocks = 24
+        n_feats_k = 64
+        n_feats_m = 192
+
+        # define head modules: one 3x3 conv (stride 1, no downsampling) per input branch
+        self.head_rec = nn.Sequential(
+            nn.Conv2d(in_channels = 1, out_channels = n_feats_k, kernel_size = 3, stride = 1, padding = 1),
+            nn.PReLU()
+        )
+        self.head_pre = nn.Sequential(
+            nn.Conv2d(in_channels = 1, out_channels = n_feats_k, kernel_size = 3, stride = 1, padding = 1),
+            nn.PReLU()
+        )
+
+        # define fuse module: the two head outputs plus the slice-QP plane
+        self.fuse = nn.Sequential(
+            nn.Conv2d(in_channels = n_feats_k*2 + 1, out_channels = n_feats_k, kernel_size = 1, stride = 1, padding = 0),
+            nn.PReLU()
+        )
+
+        # define body module
+        body = []
+        for _ in range(n_resblocks):
+            body.append(DscBlock(n_feats_k, n_feats_m))
+
+        self.body = nn.Sequential(*body)
+
+        # define tail module
+        self.tail = nn.Sequential(
+            nn.Conv2d(in_channels = n_feats_k, out_channels = 4 * 1, kernel_size = 3, padding = 1),
+            nn.PixelShuffle(2) # feature map: (B, 2x2x1, N, N) -> (B, 1, 2N, 2N)
+        )
+
+    def forward(self, rec, pre, rpr, slice_qp):
+        in_0 = self.head_rec(rec)
+        in_1 = self.head_pre(pre)
+
+        x = self.fuse(torch.cat((in_0, in_1, slice_qp), 1))
+        x = self.body(x)
+        x = self.tail(x)
+        # predict a residual on top of the RPR-upsampled input
+        x += rpr
+
+        return x
+
+class DscBlock(nn.Module):
+    def __init__(self, n_feats_k, n_feats_m, expansion=1):
+        super(DscBlock, self).__init__()
+        self.expansion = expansion
+        self.c1 = nn.Conv2d(in_channels=n_feats_k, out_channels=n_feats_m, kernel_size=1, padding=0)
+        self.prelu = nn.PReLU()
+        self.c2 = nn.Conv2d(in_channels=n_feats_m, out_channels=n_feats_k, kernel_size=1, padding=0)
+        self.c3 = nn.Conv2d(in_channels=n_feats_k, out_channels=n_feats_k, kernel_size=3, padding=1)
+
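+    # residual mapping: 1x1 expansion (k -> m), PReLU, 1x1 projection (m -> k),
+    # then a 3x3 convolution, with the block input added back as an identity skip
+    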
def forward(self, x): + i = x + x = self.c2(self.prelu(self.c1(x))) + x = self.c3(x) + x += i + + return x + + diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train.sh b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b026d766aefdfa8a5c92a46d1750661bc344f5a --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train.sh @@ -0,0 +1,32 @@ +# The copyright in this software is being made available under the BSD +# License, included below. This software may be subject to other third party +# and contributor rights, including patent rights, and no such rights are +# granted under this license. +# +# Copyright (c) 2010-2022, ITU/ISO/IEC +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +# be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +python train_YUV.py \ No newline at end of file diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train_YUV.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train_YUV.py new file mode 100644 index 0000000000000000000000000000000000000000..a181c120b688065284217a659538580c2ac12e9d --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/train_YUV.py @@ -0,0 +1,331 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. 
+* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import torch +import torch.nn as nn +from torch.optim.adam import Adam +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data.dataloader import DataLoader +import datetime +import os, glob + +from yuv10bdata import YUV10bData +from Utils import init, cal_psnr +from nn_model import Net + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True + + +class Trainer: + def __init__(self): + self.args, self.logger, self.checkpoints_dir, self.tensorboard_all, self.tensorboard = init() + + self.net = Net().to("cuda" if self.args.gpu else "cpu") + + self.L1loss = nn.L1Loss().to("cuda" if self.args.gpu else "cpu") + self.L2loss = nn.MSELoss().to("cuda" if self.args.gpu else "cpu") + + self.optimizer = Adam(self.net.parameters(), lr = self.args.lr) + self.scheduler = MultiStepLR(optimizer=self.optimizer, milestones=[4001, 4002], gamma=0.5) + + print("============>loading data") + self.train_dataset = YUV10bData(self.args, train=True) + self.eval_dataset = YUV10bData(self.args, train=False) + + self.train_dataloader = DataLoader(dataset=self.train_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=12, pin_memory=False) + self.eval_dataloader = DataLoader(dataset=self.eval_dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=12, pin_memory=False) + + self.train_steps = self.eval_steps = 0 + + def train(self): + start_epoch = self.load_checkpoints() + print("============>start training") + for epoch in range(start_epoch, self.args.max_epoch): + print("Epoch {}/{}".format(epoch, self.args.max_epoch)) + self.logger.info("Epoch {}/{}".format(epoch, self.args.max_epoch)) + self.train_one_epoch() + self.scheduler.step() + if (epoch+1) % self.args.eval_epochs == 0: + self.eval(epoch=epoch) + self.save_ckpt(epoch=epoch) + + def train_one_epoch(self): + self.net.train() + for _, tensor in enumerate(self.train_dataloader): + + img_lr, img_hr, filename = tensor + + img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") + img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") + + rec = img_lr[:,0:1,:,:] + pre = img_lr[:,1:2,:,:] + slice_qp = img_lr[:,2:3,:,:] + img_rpr = img_hr[:,0:1,:,:] + img_ori = img_hr[:,1:2,:,:] + + img_out = self.net(rec, pre, img_rpr, slice_qp) + + # calculate distortion + shave=self.args.shave + + L1_loss_pred_Y = 
self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + #loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr + loss_pred= L1_loss_pred_Y + + #loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + # visualization + self.train_steps += 1 + if self.train_steps % 20 == 0: + psnr_pred_Y = cal_psnr(loss_pred_Y) + #psnr_pred_Cb = cal_psnr(loss_pred_Cb) + #psnr_pred_Cr = cal_psnr(loss_pred_Cr) + + #psnr_input_Y = cal_psnr(loss_rec_Y) + #psnr_input_Cb = cal_psnr(loss_rec_Cb) + #psnr_input_Cr = cal_psnr(loss_rec_Cr) + + time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") + + print("[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + loss_pred_Y, psnr_pred_Y, time)) + self.logger.info("[{}/{}]\tY:{:.8f}\tPSNR_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + loss_pred_Y, psnr_pred_Y)) + + #print("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}------{}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y, time)) + #self.logger.info("[{}/{}]\tY:{:.8f}\tCb:{:.8f}\tCr:{:.8f}\tdelta_Y: {:.8f}".format((self.train_steps % len(self.train_dataloader)), len(self.train_dataloader), + # loss_pred_Y, loss_pred_Cb, loss_pred_Cr, psnr_pred_Y - psnr_input_Y)) + + self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + tag_scalar_dict={"pred_Y": psnr_pred_Y.data}, + global_step=self.train_steps) + #self.tensorboard_all.add_image("rec", rec[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("pre", pre[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("rpr", img_rpr[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("out", img_out[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + #self.tensorboard_all.add_image("ori", img_ori[0:1,:,:,:].squeeze(dim=0), global_step=self.train_steps) + + #self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + # tag_scalar_dict={"input_Cb": psnr_input_Cb.data, + # "pred_Cb": psnr_pred_Cb.data}, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalars(main_tag="Train/PSNR", + # tag_scalar_dict={"input_Cr": psnr_input_Cr.data, + # "pred_Cr": psnr_pred_Cr.data}, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Y", + # scalar_value = psnr_pred_Y - psnr_input_Y, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cb", + # scalar_value = psnr_pred_Cb - 
psnr_input_Cb, + # global_step=self.train_steps) + + #self.tensorboard_all.add_scalar(tag="Train/delta_PSNR_Cr", + # scalar_value = psnr_pred_Cr - psnr_input_Cr, + # global_step=self.train_steps) + + self.tensorboard_all.add_scalar(tag="Train/train_loss_pred", + scalar_value = loss_pred, + global_step=self.train_steps) + + # backward + self.optimizer.zero_grad() + loss_pred.backward() + self.optimizer.step() + + @torch.no_grad() + def eval(self, epoch: int): + print("============>start evaluating") + eval_cnt = 0 + ave_psnr_Y = 0.000 + #ave_psnr_Cb = 0.000 + #ave_psnr_Cr = 0.000 + self.net.eval() + for _, tensor in enumerate(self.eval_dataloader): + + img_lr, img_hr, filename = tensor + + img_lr = img_lr.to("cuda" if self.args.gpu else "cpu") + img_hr = img_hr.to("cuda" if self.args.gpu else "cpu") + + rec = img_lr[:,0:1,:,:] + pre = img_lr[:,1:2,:,:] + slice_qp = img_lr[:,2:3,:,:] + img_rpr = img_hr[:,0:1,:,:] + img_ori = img_hr[:,1:2,:,:] + img_out = self.net(rec, pre, img_rpr, slice_qp) + + # calculate distortion and psnr + shave=self.args.shave + + L1_loss_pred_Y = self.L1loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #L1_loss_pred_Cb = self.L1loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #L1_loss_pred_Cr = self.L1loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + loss_pred_Y = self.L2loss(img_out[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_pred_Cb = self.L2loss(img_out[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_pred_Cr = self.L2loss(img_out[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + #loss_pred = 10*L1_loss_pred_Y + L1_loss_pred_Cb + L1_loss_pred_Cr + loss_pred = L1_loss_pred_Y + + #loss_rec_Y = self.L2loss(img_in[:,0,shave:-shave,shave:-shave], img_ori[:, 0,shave:-shave,shave:-shave]) + #loss_rec_Cb = self.L2loss(img_in[:,1,shave:-shave,shave:-shave], img_ori[:, 1,shave:-shave,shave:-shave]) + #loss_rec_Cr = self.L2loss(img_in[:,2,shave:-shave,shave:-shave], img_ori[:, 2,shave:-shave,shave:-shave]) + + psnr_pred_Y = cal_psnr(loss_pred_Y) + #psnr_pred_Cb = cal_psnr(loss_pred_Cb) + #psnr_pred_Cr = cal_psnr(loss_pred_Cr) + + #psnr_input_Y = cal_psnr(loss_rec_Y) + #psnr_input_Cb = cal_psnr(loss_rec_Cb) + #psnr_input_Cr = cal_psnr(loss_rec_Cr) + + ave_psnr_Y += psnr_pred_Y + #ave_psnr_Cb += psnr_pred_Cb - psnr_input_Cb + #ave_psnr_Cr += psnr_pred_Cr - psnr_input_Cr + + eval_cnt += 1 + # visualization + self.eval_steps += 1 + if self.eval_steps % 2 == 0: + + self.tensorboard_all.add_scalar(tag="Eval/PSNR_Y", + scalar_value = psnr_pred_Y, + global_step=self.eval_steps) + + #self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cb", + # scalar_value = psnr_pred_Cb - psnr_input_Cb, + # global_step=self.eval_steps) + + #self.tensorboard_all.add_scalar(tag="Eval/delta_PSNR_Cr", + # scalar_value = psnr_pred_Cr - psnr_input_Cr, + # global_step=self.eval_steps) + + self.tensorboard_all.add_scalar(tag="Eval/eval_loss_pred", + scalar_value = loss_pred, + global_step=self.eval_steps) + + time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M") + print("PSNR_Y:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, time)) + self.logger.info("PSNR_Y:{:.3f}".format(ave_psnr_Y / eval_cnt)) + + #print("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}------{}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt, time)) + 
#self.logger.info("delta_Y:{:.3f}\tdelta_Cb:{:.3f}\tdelta_Cr:{:.3f}".format(ave_psnr_Y / eval_cnt, ave_psnr_Cb / eval_cnt, ave_psnr_Cr / eval_cnt)) + + self.tensorboard.add_scalar(tag = "Eval/PSNR_Y_ave", + scalar_value = ave_psnr_Y / eval_cnt, + global_step = epoch + 1) + #self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cb_ave", + # scalar_value = ave_psnr_Cb / eval_cnt, + # global_step = epoch + 1) + #self.tensorboard.add_scalar(tag = "Eval/delta_PSNR_Cr_ave", + # scalar_value = ave_psnr_Cr / eval_cnt, + # global_step = epoch + 1) + + def load_checkpoints(self): + if not self.args.checkpoints: + ckpt_list=sorted(glob.glob(os.path.join(self.checkpoints_dir, '*.pth'))) + num = len(ckpt_list) + if(num > 1): + if os.path.getsize(ckpt_list[-1]) == os.path.getsize(ckpt_list[-2]): + self.args.checkpoints = ckpt_list[-1] + else: + self.args.checkpoints = ckpt_list[-2] + + if self.args.checkpoints: + print("===========Load checkpoints {0}===========".format(self.args.checkpoints)) + self.logger.info("Load checkpoints {0}".format(self.args.checkpoints)) + ckpt = torch.load(self.args.checkpoints) + # load network weights + try: + self.net.load_state_dict(ckpt["network"]) + except: + print("Can not find network weights") + # load optimizer params + try: + self.optimizer.load_state_dict(ckpt["optimizer"]) + self.scheduler.load_state_dict(ckpt["scheduler"]) + except: + print("Can not find some optimizers params, just ignore") + start_epoch = ckpt["epoch"] + 1 + self.train_steps = ckpt["train_step"] + 1 + self.eval_steps = ckpt["eval_step"] + 1 + elif self.args.pretrained: + ckpt = torch.load(self.args.pretrained) + print("===========Load network weights {0}===========".format(self.args.checkpoints)) + self.logger.info("Load network weights {0}".format(self.args.checkpoints)) + # load codec weights + try: + self.net.load_state_dict(ckpt["network"]) + except: + print("Can not find network weights") + start_epoch = 0 + else: + print("===========Training from scratch===========") + self.logger.info("Training from scratch") + start_epoch = 0 + return start_epoch + + def save_ckpt(self, epoch: int): + checkpoint = { + "network": self.net.state_dict(), + "epoch": epoch, + "train_step": self.train_steps, + "eval_step": self.eval_steps, + "optimizer": self.optimizer.state_dict(), + "scheduler": self.scheduler.state_dict()} + + torch.save(checkpoint, '%s/model_%.4d.pth' % (self.checkpoints_dir, epoch+1)) + self.logger.info('Save model..') + print("======================Saving model {0}======================".format(str(epoch))) + +if __name__ == "__main__": + trainer = Trainer() + trainer.train() diff --git a/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/yuv10bdata.py b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/yuv10bdata.py new file mode 100644 index 0000000000000000000000000000000000000000..a5612b228f41bdc14a8640d6f7c4e96960eb4174 --- /dev/null +++ b/training/training_scripts/NN_Super_Resolution/3_train_tasks/training_scripts/Luma-I/yuv10bdata.py @@ -0,0 +1,237 @@ +""" +/* The copyright in this software is being made available under the BSD +* License, included below. This software may be subject to other third party +* and contributor rights, including patent rights, and no such rights are +* granted under this license. +* +* Copyright (c) 2010-2022, ITU/ISO/IEC +* All rights reserved. 
+* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* * Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* * Neither the name of the ITU/ISO/IEC nor the names of its contributors may +* be used to endorse or promote products derived from this software without +* specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS +* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +* THE POSSIBILITY OF SUCH DAMAGE. +""" + +import os +import glob +from torch.utils.data import Dataset +import torch.nn.functional as F +import numpy as np +import string +import random +import Utils +import math + + +class YUV10bData(Dataset): + def __init__(self, args, name='YuvData', train=True): + super(YUV10bData, self).__init__() + self.args = args + self.split = 'train' if train else 'valid' + self.image_ext = args.ext + self.name = name + self.train = train + + data_range = [r.split('-') for r in args.data_range.split('/')] + if train: + data_range = data_range[0] + else: + data_range = data_range[1] + self.begin, self.end = list(map(lambda x: int(x), data_range)) + + data_range_tvd = [r.split('-') for r in args.data_range_tvd.split('/')] + if train: + data_range_tvd = data_range_tvd[0] + else: + data_range_tvd = data_range_tvd[1] + self.begin_tvd, self.end_tvd = list(map(lambda x: int(x), data_range_tvd)) + + self._set_data() + self._get_image_list() + + if train: + n_patches = args.batch_size * args.test_every + n_images = len(self.images_yuv) + if n_images == 0: + self.repeat = 0 + else: + self.repeat = max(n_patches // n_images, 1) + print(f"repeating dataset {self.repeat} for one epoch") + else: + n_patches = args.batch_size * args.test_every // 25 + n_images = len(self.images_yuv) + if n_images == 0: + self.repeat = 0 + else: + self.repeat = 5 + print(f"repeating dataset {self.repeat} for one epoch") + + def _set_data(self): + self.dir_in = os.path.join(self.args.dir_data, "yuv") + self.dir_org = os.path.join(self.args.dir_data_ori) + + self.dir_in_tvd = os.path.join(self.args.dir_data_tvd, "yuv") + self.dir_org_tvd = os.path.join(self.args.dir_data_ori_tvd) + + def _scan_class(self, is_tvd, class_name): + QPs = ['22', '27', '32', '37', '42'] + + dir_in = self.dir_in_tvd if is_tvd else self.dir_in + list_temp = glob.glob(os.path.join(dir_in, '*_' + class_name + '_*.' 
+        # reconstructed frames: keep every third poc (3-digit field), excluding poc 0
+        file_rec_list = []
+        for i in list_temp:
+            index = i.find('poc')
+            poc = int(i[index+3:index+6])
+            if poc % 3 == 0 and poc != 0:
+                file_rec_list.append(i)
+
+        if is_tvd:
+            file_rec_list = sorted(file_rec_list)
+        else:
+            file_rec_list = sorted(file_rec_list, key=str.lower)
+
+        dir_org = self.dir_org_tvd if is_tvd else self.dir_org
+        list_temp = glob.glob(os.path.join(dir_org, class_name + '*/*.' + self.image_ext))
+        # original frames: apply the same temporal sampling as the reconstructions
+        file_org_list = []
+        for i in list_temp:
+            index = i.find('frame_')
+            poc = int(i[index+6:index+9])
+            if poc % 3 == 0 and poc != 0:
+                file_org_list.append(i)
+
+        if is_tvd:
+            file_org_list = sorted(file_org_list)
+        else:
+            file_org_list = sorted(file_org_list, key=str.lower)
+
+        # each sequence contributes 62 frames, sampled at every third poc
+        frame_num = 62
+        frame_num_sampled = math.ceil(frame_num / 3)
+        begin = self.begin_tvd if is_tvd else self.begin
+        end = self.end_tvd if is_tvd else self.end
+
+        class_names_yuv = []
+        class_names_yuv_org = []
+
+        for qp in QPs:
+            # the reconstruction side holds five QP variants per sequence
+            file_list = file_rec_list[(begin-1)*frame_num_sampled*5:end*frame_num_sampled*5]
+            for filename in file_list:
+                idx = filename.find("qp")
+                if int(filename[idx+2:idx+4]) == int(qp):
+                    class_names_yuv.append(filename)
+
+            file_list = file_org_list[(begin-1)*frame_num_sampled:end*frame_num_sampled]
+            for filename in file_list:
+                class_names_yuv_org.append(filename)
+
+        return class_names_yuv, class_names_yuv_org
+
+    def _scan(self):
+        bvi_class_set = ['A']
+
+        names_yuv = []
+        names_yuv_org = []
+        for class_name in bvi_class_set:
+            class_names_yuv, class_names_yuv_org = self._scan_class(False, class_name)
+            names_yuv = names_yuv + class_names_yuv
+            names_yuv_org = names_yuv_org + class_names_yuv_org
+
+        class_names_yuv, class_names_yuv_org = self._scan_class(True, 'A')
+        names_yuv = names_yuv + class_names_yuv
+        names_yuv_org = names_yuv_org + class_names_yuv_org
+
+        print(f"found {len(names_yuv)} reconstructed frames, {len(names_yuv_org)} original frames")
+
+        return names_yuv, names_yuv_org
+
+    def _get_image_list(self):
+        self.images_yuv, self.images_yuv_org = self._scan()
+
+    def __getitem__(self, idx):
+        patch_in, patch_org, filename = self._load_file_get_patch(idx)
+        pair_t = Utils.np2Tensor(patch_in, patch_org)
+
+        return pair_t[0], pair_t[1], filename
+
+    def __len__(self):
+        return len(self.images_yuv) * self.repeat
+
+    def _get_index(self, idx):
+        return idx % len(self.images_yuv)
+
+    def _load_file_get_patch(self, idx):
+        idx = self._get_index(idx)
+
+        # reconstruction
+        image_yuv_path = self.images_yuv[idx]
+
+        # the last "qp" tag in the file name is the slice QP, the first one the base QP;
+        # both are two-digit fields
+        slice_qp_idx = image_yuv_path.rfind("qp")
+        slice_qp = int(image_yuv_path[slice_qp_idx+2:slice_qp_idx+4])
+        slice_qp_map = np.uint16(np.ones(((self.args.patch_size + 2*self.args.shave)//2, (self.args.patch_size + 2*self.args.shave)//2, 1))*slice_qp)
+
+        base_qp_idx = image_yuv_path.find("qp")
+        base_qp = int(image_yuv_path[base_qp_idx+2:base_qp_idx+4])
+        # note: only the slice QP map is concatenated to the input below
+        base_qp_map = np.uint16(np.ones(((self.args.patch_size + 2*self.args.shave)//2, (self.args.patch_size + 2*self.args.shave)//2, 1))*base_qp)
+
+        # prediction
+        pred_str = '_prediction'
+        pos = image_yuv_path.find('.yuv')
+        image_yuv_pred_path = image_yuv_path[:pos] + pred_str + image_yuv_path[pos:]
+        image_yuv_pred_path = image_yuv_pred_path.replace('/yuv/', '/prediction_image/')
+
+        # RPR
+        rpr_str = '_rpr'
+        pos = image_yuv_path.find('.yuv')
+        image_yuv_rpr_path = image_yuv_path[:pos] + rpr_str + image_yuv_path[pos:]
+        image_yuv_rpr_path = image_yuv_rpr_path.replace('/yuv/', '/rpr_image/')
+
+        # original; the directory name carries "<name>_<W>x<H>_..."
+        image_yuv_org_path = self.images_yuv_org[idx]
+        org_splits = os.path.basename(os.path.dirname(image_yuv_org_path)).split('_')
+        wh_org = org_splits[1].split('x')
+        w, h = map(int, wh_org)
+
+        patch_in, patch_rpr, patch_org = Utils.get_patch(image_yuv_path, image_yuv_pred_path, image_yuv_rpr_path, image_yuv_org_path, w, h, self.args.patch_size, self.args.shave)
+
+        patch_in = np.concatenate((patch_in, slice_qp_map), axis=2)
+
+        if self.train:
+            patch_in, patch_rpr, patch_org = Utils.augment(patch_in, patch_rpr, patch_org)
+
+        patch_lr = patch_in
+        patch_hr = np.concatenate((patch_rpr, patch_org), axis=2)
+
+        return patch_lr, patch_hr, image_yuv_path
diff --git a/training/training_scripts/NN_Super_Resolution/ReadMe.md b/training/training_scripts/NN_Super_Resolution/ReadMe.md
new file mode 100644
index 0000000000000000000000000000000000000000..82018fc8a01bb8ca5d7bafe94905d3d2c6a236f0
--- /dev/null
+++ b/training/training_scripts/NN_Super_Resolution/ReadMe.md
@@ -0,0 +1,14 @@
+## Overview
+
+Requirements:
+* One GPU with more than 25 GiB of memory.
+* At least 5 TB of disk storage is recommended.
+
+The scripts form a three-stage pipeline, run in the following order:
+* Generate the raw data
+* Generate the compression data
+* Training and final model conversion
+
+
+For a better viewing experience, please open these files with [typora](https://typora.io/).
+The guides can also be read as plain text, but the embedded figures will not be displayed; in that case, open the corresponding image in the figure folder directly.
\ No newline at end of file
diff --git a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_chroma-IB.png b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_chroma-IB.png
new file mode 100644
index 0000000000000000000000000000000000000000..c14898bdf23ac1d79354538c9579b1f07161cca7
Binary files /dev/null and b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_chroma-IB.png differ
diff --git a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-B.png b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-B.png
new file mode 100644
index 0000000000000000000000000000000000000000..e97a6476b4fea3b3bb525287563faa660d9a6a30
Binary files /dev/null and b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-B.png differ
diff --git a/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-I.png b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-I.png
new file mode 100644
index 0000000000000000000000000000000000000000..c6b552699415b12db2174320524a375ee6b843f3
Binary files /dev/null and b/training/training_scripts/NN_Super_Resolution/figure/convergence_curve_luma-I.png differ
diff --git a/training/training_scripts/NN_Super_Resolution/figure/overview.png b/training/training_scripts/NN_Super_Resolution/figure/overview.png
new file mode 100644
index 0000000000000000000000000000000000000000..88af1b996f999c46506d2ffe0d700b0ca158f259
Binary files /dev/null and b/training/training_scripts/NN_Super_Resolution/figure/overview.png differ
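
As a reading aid for the pipeline listed in the ReadMe above, the following is a minimal sketch of how the three stages chain together, assuming the working directory is `training/training_scripts/NN_Super_Resolution/`. Only the stage-1 script names are taken from this patch; the stage-2 and stage-3 entry points are hypothetical placeholders and must be replaced with the actual scripts from the corresponding folders.

```python
# Sketch only: stage order for the NN super-resolution training pipeline.
# Assumes it is run from training/training_scripts/NN_Super_Resolution/.
import os

# Stage 1: raw data generation (script names as shipped in 1_generate_raw_data/)
for script in ("tvd_to_yuv.py", "tvd_to_yuv_frame.py",
               "bvi_dvc_to_yuv.py", "bvi_dvc_to_yuv_frame.py"):
    os.system(f"python 1_generate_raw_data/{script}")

# Stage 2: compression data generation -- the folder and script names below
# are hypothetical placeholders, not taken from this patch.
# os.system("python <2_generate_compression_data>/<encode_script>.py")

# Stage 3: training and final model conversion, e.g. the Luma-I task; the
# training entry point is a placeholder, only the folder exists in this patch.
# os.system("python 3_train_tasks/training_scripts/Luma-I/<train_script>.py")
```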