Commit 4d7a1439 authored by Yue Li

Merge branch 'flush_logger' into 'VTM-11.0_nnvc'

avoid flush with None file

See merge request jvet-ahg-nnvc/VVCSoftware_VTM!117
Parents: 9ce6e037 a2e6fd20
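The change is mechanical but worth spelling out. The print-based loggers write with print(..., file=self.out_file), and print() accepts file=None (it then falls back to sys.stdout), which suggests out_file is allowed to be None. The unconditional self.out_file.flush() that followed each print, however, raises AttributeError in that case. A minimal sketch of the failure and of the guard introduced here (the out_file variable and the message text are just for illustration):

import sys

out_file = None  # e.g. a logger that was given no explicit output file

# print() treats file=None as sys.stdout, so the message still appears.
print("Epoch 0, iteration 99: loss=0.123", file=out_file)

# Flushing unconditionally fails: None has no flush() method.
try:
    out_file.flush()
except AttributeError as err:
    print(f"unguarded flush raises: {err}", file=sys.stderr)

# The guard added in this merge request skips the flush in that case.
if out_file is not None:
    out_file.flush()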
@@ -144,12 +144,16 @@ class PrintLogger(BaseLogger):
     def on_train_iter_end(
         self, epoch: int, iteration: int, train_metrics: Dict[str, Any]
     ) -> None:
-        if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval == 0:
+        if (
+            self.log_train_interval > 0
+            and (iteration + 1) % self.log_train_interval == 0
+        ):
             print(
                 f"Epoch {epoch}, iteration {iteration}: {self.format_metrics(train_metrics)}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()

     def on_val_epoch_end(
         self, epoch: int, val_metrics: Dict[str, Dict[str, Any]]
@@ -161,7 +165,8 @@ class PrintLogger(BaseLogger):
                 f"\t{val_tag}: {self.format_metrics(val_tag_metrics)}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()


 class ProgressLogger(BaseLogger):
@@ -197,12 +202,14 @@ class ProgressLogger(BaseLogger):

     def on_train_start(self):
         print(f"{datetime.datetime.now()}: Training started", file=self.out_file)
-        self.out_file.flush()
+        if self.out_file is not None:
+            self.out_file.flush()

     def on_train_end(self):
         if self.log_stage_ends:
             print(f"{datetime.datetime.now()}: Training finished", file=self.out_file)
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()

     def on_train_epoch_start(self, epoch: int):
         if self.log_train_epochs:
@@ -210,7 +217,8 @@ class ProgressLogger(BaseLogger):
                 f"{datetime.datetime.now()}: Training started epoch {epoch}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()

     def on_train_epoch_end(self, epoch: int):
         if self.log_train_epochs and self.log_stage_ends:
@@ -218,7 +226,8 @@ class ProgressLogger(BaseLogger):
                 f"{datetime.datetime.now()}: Training finished epoch {epoch}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()

     def on_train_iter_start(self, epoch: int, iteration: int):
         if self.log_train_iterations:
@@ -226,7 +235,8 @@ class ProgressLogger(BaseLogger):
                 f"{datetime.datetime.now()}: Training starting epoch {epoch}, iteration {iteration}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()

     def on_train_iter_end(
         self, epoch: int, iteration: int, train_metrics: Dict[str, Any]
@@ -236,7 +246,8 @@ class ProgressLogger(BaseLogger):
                 f"{datetime.datetime.now()}: Training finished epoch {epoch}, iteration {iteration}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()

     def on_val_epoch_start(self, epoch: int):
         if self.log_val_epochs:
@@ -244,7 +255,8 @@ class ProgressLogger(BaseLogger):
                 f"{datetime.datetime.now()}: Validation starting epoch {epoch}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()

     def on_val_epoch_end(self, epoch: int, val_metrics: Dict[str, Dict[str, Any]]):
         if self.log_val_epochs and self.log_stage_ends:
@@ -252,7 +264,8 @@ class ProgressLogger(BaseLogger):
                 f"{datetime.datetime.now()}: Validation finished epoch {epoch}",
                 file=self.out_file,
             )
-            self.out_file.flush()
+            if self.out_file is not None:
+                self.out_file.flush()


 from torch.utils.tensorboard import SummaryWriter  # noqa: E402
@@ -290,7 +303,10 @@ class TensorboardLogger(BaseLogger):
     def on_train_iter_end(
         self, epoch: int, iteration: int, train_metrics: Dict[str, Any]
     ) -> None:
-        if self.log_train_interval > 0 and (iteration + 1) % self.log_train_interval == 0:
+        if (
+            self.log_train_interval > 0
+            and (iteration + 1) % self.log_train_interval == 0
+        ):
             self.global_iteration += self.log_train_interval
             for metric, value in train_metrics.items():
                 self.writer.add_scalar(
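Since the same two-line guard now appears in every callback of PrintLogger and ProgressLogger, a natural follow-up (not part of this merge request) would be to factor it into a small helper method, e.g. on BaseLogger or on each concrete logger. A sketch, assuming each logger keeps its possibly-None stream in self.out_file as above:

    # Hypothetical helper, not in this commit: centralizes the None check so
    # callbacks can call self._flush() instead of repeating the guard.
    def _flush(self) -> None:
        if self.out_file is not None:
            self.out_file.flush()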