Skip to content

Commit

Permalink
Lightning 3.0 updates (#727)
Browse files Browse the repository at this point in the history
* update 3.0 args in lightning

* fix test

* revert dropping kwargs in lightning logger

* add mypy types for yaml
  • Loading branch information
Dave Berenbaum authored Oct 21, 2023
1 parent 3372cce commit a488e83
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 31 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ tests = [
]
dev = [
"dvclive[all,tests]",
"mypy>=1.1.1"
"mypy>=1.1.1",
"types-PyYAML",
]
mmcv = ["mmcv"]
tf = ["tensorflow"]
Expand Down
23 changes: 3 additions & 20 deletions src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,19 @@ def _should_call_next_step():


class DVCLiveLogger(Logger):
def __init__( # noqa: PLR0913
def __init__(
self,
run_name: Optional[str] = "dvclive_run",
prefix="",
log_model: Union[str, bool] = False,
experiment=None,
dir: Optional[str] = None, # noqa: A002
resume: bool = False,
report: Optional[str] = None,
save_dvc_exp: bool = False,
dvcyaml: bool = True,
cache_images: bool = False,
exp_name: Optional[str] = None,
**kwargs,
):
super().__init__()
self._prefix = prefix
self._live_init: Dict[str, Any] = {
"resume": resume,
"report": report,
"save_dvc_exp": save_dvc_exp,
"dvcyaml": dvcyaml,
"cache_images": cache_images,
}
if dir is not None:
self._live_init["dir"] = dir
self._live_init: Dict[str, Any] = kwargs
self._experiment = experiment
self._version = run_name
if report == "notebook":
# Force Live instantiation
self.experiment # noqa: B018
self._log_model = log_model
self._logged_model_time: Dict[str, float] = {}
self._checkpoint_callback: Optional[ModelCheckpoint] = None
Expand Down
66 changes: 56 additions & 10 deletions tests/frameworks/test_lightning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock

import pytest
import yaml

from dvclive.plots.metric import Metric
from dvclive.serialize import load_yaml
Expand All @@ -11,6 +15,8 @@
from lightning import LightningModule
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel
from torch import nn
from torch.nn import functional as F # noqa: N812
from torch.optim import SGD, Adam
Expand Down Expand Up @@ -260,7 +266,7 @@ def validation_step(self, *args, **kwargs):
return loss


def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
def test_lightning_val_updates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post):
"""Test the `self.experiment._latest_studio_step -= 1` logic."""
mocked_post, _ = mocked_studio_post

Expand All @@ -277,22 +283,62 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio

calls = mocked_post.call_args_list
# 0: start
# 1: update_train_step_metrics
# 2: update_train_step_metrics
# 3: log_eval_end_metrics
plots = calls[3][1]["json"]["plots"]
val_loss = plots["dvclive/dvc.yaml::dvclive/plots/metrics/val/loss.tsv"]
# 1: first data event
# ...: data events
# -2: last data event
# -1: done
plots = calls[-2][1]["json"]["plots"]
val_loss = plots["dvclive/plots/metrics/val/loss.tsv"]
# Without `self.experiment._latest_studio_step -= 1`
# This would be empty
assert len(val_loss["data"]) == 1


def test_lightning_force_init(tmp_dir, mocker):
"""Regression test for https://github.com/iterative/dvclive/issues/594
Only call Live.__init__ when report is notebook.
"""Related to https://github.com/iterative/dvclive/issues/594
Don't call Live.__init__ on rank-nonzero processes.
"""
init = mocker.spy(Live, "__init__")
DVCLiveLogger()
init.assert_not_called()
DVCLiveLogger(report="notebook")
init.assert_called_once()


# LightningCLI tests
# Copied from https://github.com/Lightning-AI/lightning/blob/e7afe04ee86b64c76a5446088b3b75d9c275e5bf/tests/tests_pytorch/test_cli.py
class TestModel(BoringModel):
def __init__(self, foo, bar=5):
super().__init__()
self.foo = foo
self.bar = bar


def _test_logger_init_args(logger_name, init, unresolved={}): # noqa: B006
cli_args = [f"--trainer.logger={logger_name}"]
cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
cli_args.append("--print_config")

out = StringIO()
with mock.patch(
"sys.argv", ["any.py"] + cli_args # noqa: RUF005
), redirect_stdout( # noqa: RUF100
out
), pytest.raises(
SystemExit
):
LightningCLI(TestModel, run=False)

data = yaml.safe_load(out.getvalue())["trainer"]["logger"]
assert {k: data["init_args"][k] for k in init} == init
if unresolved:
assert data["dict_kwargs"] == unresolved


def test_dvclive_logger_init_args():
_test_logger_init_args(
"dvclive.lightning.DVCLiveLogger",
{
"run_name": "test_run", # Resolve from DVCLiveLogger.__init__
"dir": "results", # Resolve from Live.__init__
},
)

0 comments on commit a488e83

Please sign in to comment.