diff --git a/training/training_scripts/NN_Filtering_HOP/convert/to_onnx.py b/training/training_scripts/NN_Filtering_HOP/convert/to_onnx.py
index fc5fd0aaa657d834a13b14f6c295826968e6eb52..8db1b4fe4b7f72b18998556159a4cae69e1a4814 100644
--- a/training/training_scripts/NN_Filtering_HOP/convert/to_onnx.py
+++ b/training/training_scripts/NN_Filtering_HOP/convert/to_onnx.py
@@ -68,4 +68,4 @@
 modelobj = Trainer.instantiate_from_dict(model, config["model"])
 checkpoint = torch.load(args.ckpt, map_location=torch.device("cpu"))
 modelobj.load_state_dict(checkpoint["model"])
-modelobj.SADL_model.to_onnx(args.output, args.batch_size, args.patch_size)
+modelobj.SADL_model.to_onnx(args.output, args.patch_size, args.batch_size)
diff --git a/training/training_scripts/NN_Filtering_HOP/training/trainer.py b/training/training_scripts/NN_Filtering_HOP/training/trainer.py
index d92c91ccfec202e3643e3c94023244db9216d54c..1802b2bd60576d96a78305a2fcf4946f378a37ee 100644
--- a/training/training_scripts/NN_Filtering_HOP/training/trainer.py
+++ b/training/training_scripts/NN_Filtering_HOP/training/trainer.py
@@ -65,7 +65,9 @@ class Trainer:
         self.device = self.config_training["device"] or (
             torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         )
-        print(f"[INFO] tf32 {torch.backends.cuda.matmul.allow_tf32} {torch.backends.cudnn.allow_tf32}")
+        print(
+            f"[INFO] tf32 {torch.backends.cuda.matmul.allow_tf32} {torch.backends.cudnn.allow_tf32}"
+        )
         self.base_dir = self.config_training["path"]
         self.save_dir = os.path.join(self.base_dir, self.config_training["ckpt_dir"])
         os.makedirs(self.save_dir, exist_ok=True)
@@ -232,9 +234,11 @@ class Trainer:
         self.loggers.on_train_start()
         for epoch in range(self.current_epoch, self.config_training["max_epochs"]):
             self.current_epoch = epoch
-            if epoch == self.config_training["mse_epoch"]:
+            if epoch >= self.config_training["mse_epoch"]:
+                print(f"[INFO] epoch {epoch} L2 loss, lr={self.lr_scheduler.get_last_lr()}")
                 self.loss_function = nn.MSELoss()
-
+            else:
+                print(f"[INFO] epoch {epoch} L1 loss, lr={self.lr_scheduler.get_last_lr()}")
             self.loggers.on_train_epoch_start(self.current_epoch)
             self.train_epoch()
             self.loggers.on_train_epoch_end(self.current_epoch)