diff --git a/training/training_scripts/NN_Filtering_HOP/training/logger.py b/training/training_scripts/NN_Filtering_HOP/training/logger.py index a7e59f97d1453c1ff17cdbb3825fa4b9d966691c..229149eb15ca432fc9413259e8d0a635c2f1d9b3 100644 --- a/training/training_scripts/NN_Filtering_HOP/training/logger.py +++ b/training/training_scripts/NN_Filtering_HOP/training/logger.py @@ -144,7 +144,7 @@ class PrintLogger(BaseLogger): def on_train_iter_end( self, epoch: int, iteration: int, train_metrics: Dict[str, Any] ) -> None: - if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval: + if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval == 0: print( f"Epoch {epoch}, iteration {iteration}: {self.format_metrics(train_metrics)}", file=self.out_file, @@ -280,7 +280,7 @@ class TensorboardLogger(BaseLogger): def on_train_iter_end( self, epoch: int, iteration: int, train_metrics: Dict[str, Any] ) -> None: - if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval: + if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval == 0: self.global_iteration += self.log_train_interval for metric, value in train_metrics.items(): self.writer.add_scalar(