diff --git a/training/training_scripts/NN_Filtering_HOP/training/logger.py b/training/training_scripts/NN_Filtering_HOP/training/logger.py index 0d39d16132ee5ee28e2a3c864fc408f5942a3608..986ba6924c6812ec9e3ed41b3048c5df40e41c64 100644 --- a/training/training_scripts/NN_Filtering_HOP/training/logger.py +++ b/training/training_scripts/NN_Filtering_HOP/training/logger.py @@ -144,12 +144,16 @@ class PrintLogger(BaseLogger): def on_train_iter_end( self, epoch: int, iteration: int, train_metrics: Dict[str, Any] ) -> None: - if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval == 0: + if ( + self.log_train_interval > 0 + and (iteration + 1) % self.log_train_interval == 0 + ): print( f"Epoch {epoch}, iteration {iteration}: {self.format_metrics(train_metrics)}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_val_epoch_end( self, epoch: int, val_metrics: Dict[str, Dict[str, Any]] @@ -161,7 +165,8 @@ class PrintLogger(BaseLogger): f"\t{val_tag}: {self.format_metrics(val_tag_metrics)}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() class ProgressLogger(BaseLogger): @@ -197,12 +202,14 @@ class ProgressLogger(BaseLogger): def on_train_start(self): print(f"{datetime.datetime.now()}: Training started", file=self.out_file) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_train_end(self): if self.log_stage_ends: print(f"{datetime.datetime.now()}: Training finished", file=self.out_file) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_train_epoch_start(self, epoch: int): if self.log_train_epochs: @@ -210,7 +217,8 @@ class ProgressLogger(BaseLogger): f"{datetime.datetime.now()}: Training started epoch {epoch}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_train_epoch_end(self, epoch: int): if self.log_train_epochs and self.log_stage_ends: @@ -218,7 +226,8 @@ class ProgressLogger(BaseLogger): f"{datetime.datetime.now()}: Training finished epoch {epoch}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_train_iter_start(self, epoch: int, iteration: int): if self.log_train_iterations: @@ -226,7 +235,8 @@ class ProgressLogger(BaseLogger): f"{datetime.datetime.now()}: Training starting epoch {epoch}, iteration {iteration}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_train_iter_end( self, epoch: int, iteration: int, train_metrics: Dict[str, Any] @@ -236,7 +246,8 @@ class ProgressLogger(BaseLogger): f"{datetime.datetime.now()}: Training finished epoch {epoch}, iteration {iteration}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_val_epoch_start(self, epoch: int): if self.log_val_epochs: @@ -244,7 +255,8 @@ class ProgressLogger(BaseLogger): f"{datetime.datetime.now()}: Validation starting epoch {epoch}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() def on_val_epoch_end(self, epoch: int, val_metrics: Dict[str, Dict[str, Any]]): if self.log_val_epochs and self.log_stage_ends: @@ -252,7 +264,8 @@ class ProgressLogger(BaseLogger): f"{datetime.datetime.now()}: Validation finished epoch {epoch}", file=self.out_file, ) - self.out_file.flush() + if self.out_file is not None: + self.out_file.flush() from torch.utils.tensorboard import SummaryWriter # noqa: E402 @@ -290,7 +303,10 @@ class TensorboardLogger(BaseLogger): def on_train_iter_end( self, epoch: int, iteration: int, train_metrics: Dict[str, Any] ) -> None: - if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval == 0: + if ( + self.log_train_interval > 0 + and (iteration + 1) % self.log_train_interval == 0 + ): self.global_iteration += self.log_train_interval for metric, value in train_metrics.items(): self.writer.add_scalar(