diff --git a/training/data_loader/data_loader.py b/training/data_loader/data_loader.py index 5ac6de30bdfa8fda4a8ed2db43b74ac559a1ac62..3a1dc555590da821a4f88f34f00af268a1fe2e27 100644 --- a/training/data_loader/data_loader.py +++ b/training/data_loader/data_loader.py @@ -57,7 +57,7 @@ class DataLoader: # patch_size in luma sample def __init__(self, jsonfile, patch_size, poc_list, generate_type = 0, qp_filter=-1, slice_type_filter=-1): self.generate_type=generate_type - if self.generate_type: + if self.generate_type == 0: self.normalizer_rec = 1023.0 self.normalizer_pred = 1023.0 self.normalizer_bs = 1023.0 @@ -123,7 +123,7 @@ class DataLoader: self.database=dcontent['data'] # create array of patches adress - if self.generate_type: + if self.generate_type == 0: psize = self.patch_size for didx in range(len(self.database)): d=self.database[didx]