diff --git a/src/dvclive/catalyst.py b/src/dvclive/catalyst.py index eef4a32e..4d01080e 100644 --- a/src/dvclive/catalyst.py +++ b/src/dvclive/catalyst.py @@ -26,5 +26,4 @@ def on_epoch_end(self, runner) -> None: scheduler=runner.scheduler, ) utils.save_checkpoint(checkpoint, self.model_file) - self.live.make_report() self.live.next_step() diff --git a/src/dvclive/fastai.py b/src/dvclive/fastai.py index a6f5581c..b6f7cd5d 100644 --- a/src/dvclive/fastai.py +++ b/src/dvclive/fastai.py @@ -22,5 +22,4 @@ def after_epoch(self): if self.model_file: self.learn.save(self.model_file) - self.live.make_report() self.live.next_step() diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 520a5f77..4b35d3fd 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -27,7 +27,6 @@ def on_log( logs = kwargs["logs"] for key, value in logs.items(): self.live.log_metric(standardize_metric_name(key, __name__), value) - self.live.make_report() self.live.next_step() def on_epoch_end( diff --git a/src/dvclive/keras.py b/src/dvclive/keras.py index 810520e5..fb0357ad 100644 --- a/src/dvclive/keras.py +++ b/src/dvclive/keras.py @@ -53,5 +53,4 @@ def on_epoch_end( self.model.save_weights(self.model_file) else: self.model.save(self.model_file) - self.live.make_report() self.live.next_step() diff --git a/src/dvclive/lgbm.py b/src/dvclive/lgbm.py index 07e732c5..b460952b 100644 --- a/src/dvclive/lgbm.py +++ b/src/dvclive/lgbm.py @@ -17,5 +17,4 @@ def __call__(self, env): if self.model_file: env.model.save_model(self.model_file) - self.live.make_report() self.live.next_step() diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index a6cad1bf..5432342b 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -69,5 +69,4 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None): metric_val = metric_val.cpu().detach().item() metric_name = standardize_metric_name(metric_name, __name__) self.experiment.log_metric(name=metric_name, val=metric_val) - self.experiment.make_report() self.experiment.next_step() diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 91cf9d5f..014edd73 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -88,7 +88,7 @@ def __init__( else: self._cleanup() - self._latest_studio_step = self.get_step() + self._latest_studio_step = self.get_step() if resume else -1 if self.report_mode == "studio": from scmrepo.git import Git @@ -134,23 +134,16 @@ def get_step(self) -> int: return self._step or 0 def set_step(self, step: int) -> None: - if self._step is None: - self._step = 0 - self.make_summary() - - if self.report_mode == "studio": - if not post_to_studio(self, "data", logger): - logger.warning( - "`post_to_studio` `data` event failed." - " Data will be resent on next call." - ) - else: - self._latest_studio_step = step - self._step = step logger.debug(f"Step: {self._step}") def next_step(self): + if self._step is None: + self._step = 0 + + self.make_summary() + self.make_report() + self.make_checkpoint() self.set_step(self.get_step() + 1) def log_metric( @@ -169,7 +162,6 @@ def log_metric( data.dump(val, timestamp=timestamp) self.summary = nested_update(self.summary, data.to_summary(val)) - self.make_summary() logger.debug(f"Logged {name}: {val}") def log_image(self, name: str, val): @@ -233,12 +225,21 @@ def make_summary(self): dump_json(self.summary, self.metrics_file, cls=NumpyEncoder) def make_report(self): - if self.report_mode is not None: + if self.report_mode == "studio": + if not post_to_studio(self, "data", logger): + logger.warning( + "`post_to_studio` `data` event failed." + " Data will be resent on next call." + ) + else: + self._latest_studio_step = self.get_step() + elif self.report_mode is not None: make_report(self) if self.report_mode == "html" and env2bool(env.DVCLIVE_OPEN): open_file_in_browser(self.report_file) def end(self): + self.make_summary() if self.report_mode == "studio": if not post_to_studio(self, "done", logger): logger.warning("`post_to_studio` `done` event failed.") diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 8546dade..ce479f99 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -5,7 +5,7 @@ def _get_unsent_datapoints(plot, latest_step): - return [x for x in plot if int(x["step"]) >= latest_step] + return [x for x in plot if int(x["step"]) > latest_step] def _cast_to_numbers(datapoints): diff --git a/src/dvclive/xgb.py b/src/dvclive/xgb.py index e67b283c..ee421b5c 100644 --- a/src/dvclive/xgb.py +++ b/src/dvclive/xgb.py @@ -25,5 +25,4 @@ def after_iteration(self, model, epoch, evals_log): self.live.log_metric(key, latest_metric) if self.model_file: model.save_model(self.model_file) - self.live.make_report() self.live.next_step() diff --git a/tests/test_main.py b/tests/test_main.py index 5361c57b..1b0dd771 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -37,6 +37,7 @@ def test_logging_no_step(tmp_dir): dvclive = Live("logs") dvclive.log_metric("m1", 1) + dvclive.make_summary() assert not (tmp_dir / "logs" / "m1.tsv").is_file() assert (tmp_dir / dvclive.metrics_file).is_file() @@ -254,6 +255,7 @@ def test_custom_steps(tmp_dir): for step, metric in zip(steps, metrics): dvclive.set_step(step) dvclive.log_metric("m", metric) + dvclive.make_summary() assert read_history(dvclive, "m") == (steps, metrics) assert read_latest(dvclive, "m") == (last(steps), last(metrics)) @@ -265,10 +267,12 @@ def test_log_reset_with_set_step(tmp_dir): for i in range(3): dvclive.set_step(i) dvclive.log_metric("train_m", 1) + dvclive.make_summary() for i in range(3): dvclive.set_step(i) dvclive.log_metric("val_m", 1) + dvclive.make_summary() assert read_history(dvclive, "train_m") == ([0, 1, 2], [1, 1, 1]) assert read_history(dvclive, "val_m") == ([0, 1, 2], [1, 1, 1]) @@ -366,3 +370,17 @@ def test_log_metric_timestamp(timestamp): history, _ = parse_metrics(live) logged = next(iter(history.values())) assert ("timestamp" in logged[0]) == timestamp + + +def test_make_summary_is_called_on_end(tmp_dir): + live = Live() + + live.summary["foo"] = 1.0 + live.end() + + assert json.loads((tmp_dir / live.metrics_file).read_text()) == { + # no `step` + "foo": 1.0 + } + log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv" + assert not log_file.exists()