From 66ef8b702d54ebbacd4c4b2c34dcaeb186260fd9 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 18 Oct 2023 16:47:48 -0400 Subject: [PATCH 1/4] update 3.0 args in lightning --- src/dvclive/lightning.py | 7 ++++--- tests/frameworks/test_lightning.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 9f56c5b1..10e4ad36 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -56,12 +56,13 @@ def __init__( # noqa: PLR0913 prefix="", log_model: Union[str, bool] = False, experiment=None, - dir: Optional[str] = None, # noqa: A002 + dir: str = "dvclive", # noqa: A002 resume: bool = False, report: Optional[str] = None, - save_dvc_exp: bool = False, - dvcyaml: bool = True, + save_dvc_exp: bool = True, + dvcyaml: Union[str, bool] = True, cache_images: bool = False, + exp_message: Optional[str] = None, ): super().__init__() self._prefix = prefix diff --git a/tests/frameworks/test_lightning.py b/tests/frameworks/test_lightning.py index 01a6de6a..db18419f 100644 --- a/tests/frameworks/test_lightning.py +++ b/tests/frameworks/test_lightning.py @@ -260,7 +260,7 @@ def validation_step(self, *args, **kwargs): return loss -def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): +def test_lightning_val_updates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): """Test the `self.experiment._latest_studio_step -= 1` logic.""" mocked_post, _ = mocked_studio_post @@ -281,7 +281,7 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio # 2: update_train_step_metrics # 3: log_eval_end_metrics plots = calls[3][1]["json"]["plots"] - val_loss = plots["dvclive/dvc.yaml::dvclive/plots/metrics/val/loss.tsv"] + val_loss = plots["dvclive/plots/metrics/val/loss.tsv"] # Without `self.experiment._latest_studio_step -= 1` # This would be empty assert len(val_loss["data"]) == 1 From 0f640bfd1a5abdd6ed107d755b7c38ccaa8e22be Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 18 Oct 2023 17:02:48 -0400 Subject: [PATCH 2/4] fix test --- tests/frameworks/test_lightning.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/frameworks/test_lightning.py b/tests/frameworks/test_lightning.py index db18419f..3aa07ee8 100644 --- a/tests/frameworks/test_lightning.py +++ b/tests/frameworks/test_lightning.py @@ -277,10 +277,11 @@ def test_lightning_val_updates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio calls = mocked_post.call_args_list # 0: start - # 1: update_train_step_metrics - # 2: update_train_step_metrics - # 3: log_eval_end_metrics - plots = calls[3][1]["json"]["plots"] + # 1: first data event + # ...: data events + # -2: last data event + # -1: done + plots = calls[-2][1]["json"]["plots"] val_loss = plots["dvclive/plots/metrics/val/loss.tsv"] # Without `self.experiment._latest_studio_step -= 1` # This would be empty From 332d15e911fd90c02f401ac2e99adba6dc9dbda1 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 19 Oct 2023 14:55:44 -0400 Subject: [PATCH 3/4] revert dropping kwargs in lightning logger --- src/dvclive/lightning.py | 23 ++----------- tests/frameworks/test_lightning.py | 53 +++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 10e4ad36..d621c540 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -50,36 +50,19 @@ def _should_call_next_step(): class DVCLiveLogger(Logger): - def __init__( # noqa: PLR0913 + def __init__( self, run_name: Optional[str] = "dvclive_run", prefix="", log_model: Union[str, bool] = False, experiment=None, - dir: str = "dvclive", # noqa: A002 - resume: bool = False, - report: Optional[str] = None, - save_dvc_exp: bool = True, - dvcyaml: Union[str, bool] = True, - cache_images: bool = False, - exp_message: Optional[str] = None, + **kwargs, ): super().__init__() self._prefix = prefix - self._live_init: Dict[str, Any] = { - "resume": resume, - "report": report, - "save_dvc_exp": save_dvc_exp, - "dvcyaml": dvcyaml, - "cache_images": cache_images, - } - if dir is not None: - self._live_init["dir"] = dir + self._live_init: Dict[str, Any] = kwargs self._experiment = experiment self._version = run_name - if report == "notebook": - # Force Live instantiation - self.experiment # noqa: B018 self._log_model = log_model self._logged_model_time: Dict[str, float] = {} self._checkpoint_callback: Optional[ModelCheckpoint] = None diff --git a/tests/frameworks/test_lightning.py b/tests/frameworks/test_lightning.py index 3aa07ee8..2151c78d 100644 --- a/tests/frameworks/test_lightning.py +++ b/tests/frameworks/test_lightning.py @@ -1,6 +1,10 @@ import os +from contextlib import redirect_stdout +from io import StringIO +from unittest import mock import pytest +import yaml from dvclive.plots.metric import Metric from dvclive.serialize import load_yaml @@ -11,6 +15,8 @@ from lightning import LightningModule from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint + from lightning.pytorch.cli import LightningCLI + from lightning.pytorch.demos.boring_classes import BoringModel from torch import nn from torch.nn import functional as F # noqa: N812 from torch.optim import SGD, Adam @@ -289,11 +295,50 @@ def test_lightning_val_updates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio def test_lightning_force_init(tmp_dir, mocker): - """Regression test for https://github.com/iterative/dvclive/issues/594 - Only call Live.__init__ when report is notebook. + """Related to https://github.com/iterative/dvclive/issues/594 + Don't call Live.__init__ on rank-nonzero processes. """ init = mocker.spy(Live, "__init__") DVCLiveLogger() init.assert_not_called() - DVCLiveLogger(report="notebook") - init.assert_called_once() + + +# LightningCLI tests +# Copied from https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/tests/tests_pytorch/test_cli.py +class TestModel(BoringModel): + def __init__(self, foo, bar=5): + super().__init__() + self.foo = foo + self.bar = bar + + +def _test_logger_init_args(logger_name, init, unresolved={}): # noqa: B006 + cli_args = [f"--trainer.logger={logger_name}"] + cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()] + cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()] + cli_args.append("--print_config") + + out = StringIO() + with mock.patch( + "sys.argv", ["any.py"] + cli_args # noqa: RUF005 + ), redirect_stdout( # noqa: RUF100 + out + ), pytest.raises( + SystemExit + ): + LightningCLI(TestModel, run=False) + + data = yaml.safe_load(out.getvalue())["trainer"]["logger"] + assert {k: data["init_args"][k] for k in init} == init + if unresolved: + assert data["dict_kwargs"] == unresolved + + +def test_dvclive_logger_init_args(): + _test_logger_init_args( + "dvclive.lightning.DVCLiveLogger", + { + "run_name": "test_run", # Resolve from DVCLiveLogger.__init__ + "dir": "results", # Resolve from Live.__init__ + }, + ) From a0ddedef7096f576842bb76cd6736581777427dd Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 20 Oct 2023 11:34:25 -0400 Subject: [PATCH 4/4] add mypy types for yaml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9bccf826..1d003e7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,8 @@ tests = [ ] dev = [ "dvclive[all,tests]", - "mypy>=1.1.1" + "mypy>=1.1.1", + "types-PyYAML", ] mmcv = ["mmcv"] tf = ["tensorflow"]