From 2160ae475298e7453aba62d6d5200329176c11fc Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 31 Jul 2023 16:00:31 -0400 Subject: [PATCH 01/33] post to studio even without git/dvc repo --- src/dvclive/live.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 58924cf7..9d7c68d8 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -126,17 +126,18 @@ def _init_cleanup(self): def _init_dvc(self): from dvc.scm import NoSCM - if os.getenv(env.DVC_EXP_BASELINE_REV, None): + self._dvc_repo = get_dvc_repo() + + self._exp_name = os.getenv(env.DVC_EXP_NAME, "") + self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "") + + if self._dvc_repo and self._baseline_rev and self._exp_name: # `dvc exp` execution - self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "") - self._exp_name = os.getenv(env.DVC_EXP_NAME, "") self._inside_dvc_exp = True if self._save_dvc_exp: logger.info("Ignoring `save_dvc_exp` because `dvc exp run` is running") self._save_dvc_exp = False - self._dvc_repo = get_dvc_repo() - dvc_logger = logging.getLogger("dvc") dvc_logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "WARNING").upper()) @@ -177,23 +178,6 @@ def _init_studio(self): logger.debug("Skipping `studio` report `start` and `done` events.") self._studio_events_to_skip.add("start") self._studio_events_to_skip.add("done") - elif self._dvc_repo is None: - logger.warning( - "Can't connect to Studio without a DVC Repo." - "\nYou can create a DVC Repo by calling `dvc init`." - ) - self._studio_events_to_skip.add("start") - self._studio_events_to_skip.add("data") - self._studio_events_to_skip.add("done") - elif not self._save_dvc_exp: - logger.warning( - "Can't connect to Studio without creating a DVC experiment." - "\nIf you have a DVC Pipeline, run it with `dvc exp run`." - "\nIf you are using DVCLive alone, use `save_dvc_exp=True`." - ) - self._studio_events_to_skip.add("start") - self._studio_events_to_skip.add("data") - self._studio_events_to_skip.add("done") else: response = post_live_metrics( "start", From 8653732fcd8c6d14c877b17f5b50da4b6373738f Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 31 Jul 2023 17:08:44 -0400 Subject: [PATCH 02/33] tests for no-git scenario --- src/dvclive/live.py | 4 +- tests/test_dvc.py | 9 ++-- tests/test_studio.py | 99 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 103 insertions(+), 9 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 9d7c68d8..963a834a 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -128,8 +128,8 @@ def _init_dvc(self): self._dvc_repo = get_dvc_repo() - self._exp_name = os.getenv(env.DVC_EXP_NAME, "") - self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV, "") + self._exp_name = os.getenv(env.DVC_EXP_NAME) + self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV) if self._dvc_repo and self._baseline_rev and self._exp_name: # `dvc exp` execution diff --git a/tests/test_dvc.py b/tests/test_dvc.py index c554dbc3..529c5571 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -142,15 +142,14 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): mocked_dvc_repo.experiments.save.assert_not_called() -def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker): +def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocked_dvc_repo): monkeypatch.setenv(DVC_EXP_BASELINE_REV, "foo") monkeypatch.setenv(DVC_EXP_NAME, "bar") - with mocker.patch("dvclive.live.get_dvc_repo", return_value=None): - live = Live(save_dvc_exp=True) - live.end() + live = Live(save_dvc_exp=True) + live.end() - assert live._dvc_repo is None + assert live._dvc_repo is not None assert live._baseline_rev == "foo" assert live._exp_name == "bar" assert live._inside_dvc_exp diff --git a/tests/test_studio.py b/tests/test_studio.py index a0865467..5b3aedb3 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -299,10 +299,9 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po @pytest.mark.studio() def test_post_to_studio_inside_dvc_exp( - tmp_dir, mocker, monkeypatch, mocked_studio_post + tmp_dir, mocker, monkeypatch, mocked_studio_post, mocked_dvc_repo ): mocked_post, _ = mocked_studio_post - mocker.patch("dvclive.live.get_dvc_repo", return_value=None) monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) monkeypatch.setenv(DVC_EXP_NAME, "bar") @@ -485,3 +484,99 @@ def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): }, timeout=(30, 5), ) + + +def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): + monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN") + monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL") + monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) + monkeypatch.setenv(DVC_EXP_NAME, "bar") + + live = Live(save_dvc_exp=True) + live.log_param("fooparam", 1) + + dvc_path = Path(live.dvc_file).as_posix() + metrics_path = Path(live.metrics_file).as_posix() + params_path = Path(live.params_file).as_posix() + foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() + + mocked_post, _ = mocked_studio_post + + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + json={ + "type": "start", + "repo_url": "STUDIO_REPO_URL", + "baseline_sha": "f" * 40, + "name": live._exp_name, + "client": "dvclive", + }, + headers={ + "Authorization": "token STUDIO_TOKEN", + "Content-type": "application/json", + }, + timeout=(30, 5), + ) + + live.log_metric("foo", 1) + + live.next_step() + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + json={ + "type": "data", + "repo_url": "STUDIO_REPO_URL", + "baseline_sha": "f" * 40, + "name": live._exp_name, + "step": 0, + "metrics": {metrics_path: {"data": {"step": 0, "foo": 1}}}, + "params": {params_path: {"fooparam": 1}}, + "plots": {f"{dvc_path}::{foo_path}": {"data": [{"step": 0, "foo": 1.0}]}}, + "client": "dvclive", + }, + headers={ + "Authorization": "token STUDIO_TOKEN", + "Content-type": "application/json", + }, + timeout=(30, 5), + ) + + live.log_metric("foo", 2) + + live.next_step() + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + json={ + "type": "data", + "repo_url": "STUDIO_REPO_URL", + "baseline_sha": "f" * 40, + "name": live._exp_name, + "step": 1, + "metrics": {metrics_path: {"data": {"step": 1, "foo": 2}}}, + "params": {params_path: {"fooparam": 1}}, + "plots": {f"{dvc_path}::{foo_path}": {"data": [{"step": 1, "foo": 2.0}]}}, + "client": "dvclive", + }, + headers={ + "Authorization": "token STUDIO_TOKEN", + "Content-type": "application/json", + }, + timeout=(30, 5), + ) + + live.end() + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + json={ + "type": "done", + "repo_url": "STUDIO_REPO_URL", + "baseline_sha": "f" * 40, + "name": live._exp_name, + "client": "dvclive", + }, + headers={ + "Authorization": "token STUDIO_TOKEN", + "Content-type": "application/json", + }, + timeout=(30, 5), + ) From ca57b1776d01bf1112df7c53f11d3d165b5fb6aa Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 4 Aug 2023 14:25:33 -0400 Subject: [PATCH 03/33] studio: make no-repo paths relative to cwd --- src/dvclive/studio.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 450d342b..75e0425e 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -25,18 +25,20 @@ def _cast_to_numbers(datapoints): return datapoints -def _rel_path(path, dvc_root_path): +def _rel_path(live, path): absolute_path = Path(path).resolve() - return str(absolute_path.relative_to(dvc_root_path).as_posix()) + if live._dvc_repo is not None: + root = live._dvc_repo.root_dir + else: + root = os.getcwd() + return str(absolute_path.relative_to(root).as_posix()) def _adapt_plot_name(live, name): - if live._dvc_repo is not None: - name = _rel_path(name, live._dvc_repo.root_dir) + name = _rel_path(live, name) if os.path.isfile(live.dvc_file): dvc_file = live.dvc_file - if live._dvc_repo is not None: - dvc_file = _rel_path(live.dvc_file, live._dvc_repo.root_dir) + dvc_file = _rel_path(live, live.dvc_file) name = f"{dvc_file}::{name}" return name @@ -64,8 +66,7 @@ def _adapt_images(live): def get_studio_updates(live): if os.path.isfile(live.params_file): params_file = live.params_file - if live._dvc_repo is not None: - params_file = _rel_path(params_file, live._dvc_repo.root_dir) + params_file = _rel_path(live, params_file) params = {params_file: load_yaml(live.params_file)} else: params = {} @@ -73,8 +74,7 @@ def get_studio_updates(live): plots, metrics = parse_metrics(live) metrics_file = live.metrics_file - if live._dvc_repo is not None: - metrics_file = _rel_path(metrics_file, live._dvc_repo.root_dir) + metrics_file = _rel_path(live, metrics_file) metrics = {metrics_file: {"data": metrics}} plots = { From 98ae256f984651e43a0446f098d3722e54f23ab4 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 4 Aug 2023 14:28:41 -0400 Subject: [PATCH 04/33] make ruff happy --- src/dvclive/studio.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 75e0425e..9c8bec01 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -27,10 +27,7 @@ def _cast_to_numbers(datapoints): def _rel_path(live, path): absolute_path = Path(path).resolve() - if live._dvc_repo is not None: - root = live._dvc_repo.root_dir - else: - root = os.getcwd() + root = live._dvc_repo.root_dir if live._dvc_repo is not None else os.getcwd() return str(absolute_path.relative_to(root).as_posix()) From df1a20bc6b7fc7fe5f8c206a8837cfaa0818c118 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 4 Aug 2023 15:57:58 -0400 Subject: [PATCH 05/33] don't require exp name --- src/dvclive/dvc.py | 6 +++++- src/dvclive/live.py | 5 ++++- tests/test_dvc.py | 2 +- tests/test_studio.py | 11 ++++------- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index 2e6511d3..ada2cb0d 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -147,8 +147,12 @@ def mark_dvclive_only_ended(): def get_random_exp_name(scm, baseline_rev): + from dvc.repo.experiments.utils import gen_random_name from dvc.repo.experiments.utils import ( get_random_exp_name as dvc_get_random_exp_name, ) - return dvc_get_random_exp_name(scm, baseline_rev) + if scm and baseline_rev: + return dvc_get_random_exp_name(scm, baseline_rev) + # TODO: ping studio for list of existing names to check against + return gen_random_name() diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 963a834a..d28da142 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -130,6 +130,9 @@ def _init_dvc(self): self._exp_name = os.getenv(env.DVC_EXP_NAME) self._baseline_rev = os.getenv(env.DVC_EXP_BASELINE_REV) + if not self._exp_name: + scm = self._dvc_repo.scm if self._dvc_repo else None + self._exp_name = get_random_exp_name(scm, self._baseline_rev) if self._dvc_repo and self._baseline_rev and self._exp_name: # `dvc exp` execution @@ -162,8 +165,8 @@ def _init_dvc(self): return self._baseline_rev = self._dvc_repo.scm.get_rev() + if self._save_dvc_exp: - self._exp_name = get_random_exp_name(self._dvc_repo.scm, self._baseline_rev) mark_dvclive_only_started() self._include_untracked.append(self.dir) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 529c5571..970ff4ff 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -138,7 +138,7 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): ) else: assert live._baseline_rev is not None - assert live._exp_name is None + assert live._exp_name is not None mocked_dvc_repo.experiments.save.assert_not_called() diff --git a/tests/test_studio.py b/tests/test_studio.py index 5b3aedb3..1aa0dc85 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -397,11 +397,6 @@ def test_post_to_studio_inside_subdir_dvc_exp( ) -def test_post_to_studio_requires_exp(tmp_dir, mocked_dvc_repo, mocked_studio_post): - assert Live()._studio_events_to_skip == {"start", "data", "done"} - assert not Live(save_dvc_exp=True)._studio_events_to_skip - - def test_get_dvc_studio_config_none(mocker): mocker.patch("dvclive.live.get_dvc_repo", return_value=None) live = Live() @@ -486,11 +481,13 @@ def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): ) -def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): +@pytest.mark.parametrize("exp_name", [True, False]) +def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_name): monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN") monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL") monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") + if exp_name: + monkeypatch.setenv(DVC_EXP_NAME, "bar") live = Live(save_dvc_exp=True) live.log_param("fooparam", 1) From 9ddcf07cce8d5a6d6d63d30c49fccc7fe3e9cfd4 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 4 Aug 2023 16:42:19 -0400 Subject: [PATCH 06/33] don't require baseline rev --- tests/test_studio.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/test_studio.py b/tests/test_studio.py index 1aa0dc85..d263a0ec 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -482,10 +482,16 @@ def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): @pytest.mark.parametrize("exp_name", [True, False]) -def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_name): +@pytest.mark.parametrize("baseline_rev", [True, False]) +def test_post_to_studio_no_repo( + tmp_dir, monkeypatch, mocked_studio_post, exp_name, baseline_rev +): monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN") monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL") - monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) + baseline = None + if baseline_rev: + baseline = "f" * 40 + monkeypatch.setenv(DVC_EXP_BASELINE_REV, baseline) if exp_name: monkeypatch.setenv(DVC_EXP_NAME, "bar") @@ -504,7 +510,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_na json={ "type": "start", "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, + "baseline_sha": baseline, "name": live._exp_name, "client": "dvclive", }, @@ -523,7 +529,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_na json={ "type": "data", "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, + "baseline_sha": baseline, "name": live._exp_name, "step": 0, "metrics": {metrics_path: {"data": {"step": 0, "foo": 1}}}, @@ -546,7 +552,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_na json={ "type": "data", "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, + "baseline_sha": baseline, "name": live._exp_name, "step": 1, "metrics": {metrics_path: {"data": {"step": 1, "foo": 2}}}, @@ -567,7 +573,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_na json={ "type": "done", "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, + "baseline_sha": baseline, "name": live._exp_name, "client": "dvclive", }, @@ -577,3 +583,5 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post, exp_na }, timeout=(30, 5), ) + + assert live._exp_name is not None From 3c68cb60269be34e212391f1e6eeafe7e878f1ae Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 4 Aug 2023 17:24:14 -0400 Subject: [PATCH 07/33] refactor studio path formatting --- src/dvclive/studio.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 9c8bec01..77d6c57e 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -25,17 +25,18 @@ def _cast_to_numbers(datapoints): return datapoints -def _rel_path(live, path): - absolute_path = Path(path).resolve() - root = live._dvc_repo.root_dir if live._dvc_repo is not None else os.getcwd() - return str(absolute_path.relative_to(root).as_posix()) +def _format_path(live, path): + if live._dvc_repo: + absolute_path = Path(path).resolve() + path = absolute_path.relative_to(live._dvc_repo.root_dir) + return str(Path(path).as_posix()) def _adapt_plot_name(live, name): - name = _rel_path(live, name) + name = _format_path(live, name) if os.path.isfile(live.dvc_file): dvc_file = live.dvc_file - dvc_file = _rel_path(live, live.dvc_file) + dvc_file = _format_path(live, live.dvc_file) name = f"{dvc_file}::{name}" return name @@ -63,7 +64,7 @@ def _adapt_images(live): def get_studio_updates(live): if os.path.isfile(live.params_file): params_file = live.params_file - params_file = _rel_path(live, params_file) + params_file = _format_path(live, params_file) params = {params_file: load_yaml(live.params_file)} else: params = {} @@ -71,7 +72,7 @@ def get_studio_updates(live): plots, metrics = parse_metrics(live) metrics_file = live.metrics_file - metrics_file = _rel_path(live, metrics_file) + metrics_file = _format_path(live, metrics_file) metrics = {metrics_file: {"data": metrics}} plots = { From f6a0a29029b13a2f066835c1be001fdb628a551f Mon Sep 17 00:00:00 2001 From: daavoo Date: Thu, 17 Aug 2023 11:54:04 +0200 Subject: [PATCH 08/33] live: Set new defaults `report=None` and `save_dvc_exp=True`. --- README.md | 2 +- examples/DVCLive-HuggingFace.ipynb | 2 +- examples/DVCLive-PyTorch-Lightning.ipynb | 2 +- examples/DVCLive-Quickstart.ipynb | 2 +- examples/DVCLive-scikit-learn.ipynb | 2 +- src/dvclive/live.py | 8 ++++---- src/dvclive/optuna.py | 2 +- tests/test_dvc.py | 18 +++++++++--------- tests/test_frameworks/test_lightning.py | 2 +- tests/test_log_artifact.py | 20 ++++++++++---------- tests/test_main.py | 6 +++--- tests/test_report.py | 10 +++++----- tests/test_studio.py | 24 +++++++++++++----------- 13 files changed, 51 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index a00dbf10..b3f7f41e 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ from dvclive import Live params = {"learning_rate": 0.002, "optimizer": "Adam", "epochs": 20} -with Live(save_dvc_exp=True) as live: +with Live() as live: # log a parameters for param in params: diff --git a/examples/DVCLive-HuggingFace.ipynb b/examples/DVCLive-HuggingFace.ipynb index 2e5ed571..31b4e4f4 100644 --- a/examples/DVCLive-HuggingFace.ipynb +++ b/examples/DVCLive-HuggingFace.ipynb @@ -106,7 +106,7 @@ " train_dataset=small_train_dataset,\n", " eval_dataset=small_eval_dataset,\n", " compute_metrics=compute_metrics,\n", - " callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=True)],\n", + " callbacks=[DVCLiveCallback(report=\"notebook\", \n", " )\n", " trainer.train()" ] diff --git a/examples/DVCLive-PyTorch-Lightning.ipynb b/examples/DVCLive-PyTorch-Lightning.ipynb index b4c83714..2ee1a322 100644 --- a/examples/DVCLive-PyTorch-Lightning.ipynb +++ b/examples/DVCLive-PyTorch-Lightning.ipynb @@ -173,7 +173,7 @@ " limit_train_batches=200,\n", " limit_val_batches=100,\n", " max_epochs=5,\n", - " logger=DVCLiveLogger(save_dvc_exp=True, report=\"notebook\", log_model=True),\n", + " logger=DVCLiveLogger(, log_model=True),\n", " )\n", " trainer.fit(model, train_loader, validation_loader)\n" ] diff --git a/examples/DVCLive-Quickstart.ipynb b/examples/DVCLive-Quickstart.ipynb index 6f0c7185..8c589a93 100644 --- a/examples/DVCLive-Quickstart.ipynb +++ b/examples/DVCLive-Quickstart.ipynb @@ -220,7 +220,7 @@ "\n", "best_test_acc = 0\n", "\n", - "with Live(save_dvc_exp=True, report=\"notebook\") as live:\n", + "with Live() as live:\n", "\n", " live.log_params(params)\n", "\n", diff --git a/examples/DVCLive-scikit-learn.ipynb b/examples/DVCLive-scikit-learn.ipynb index 16f6e0f7..178eefbf 100644 --- a/examples/DVCLive-scikit-learn.ipynb +++ b/examples/DVCLive-scikit-learn.ipynb @@ -96,7 +96,7 @@ "\n", "for n_estimators in (10, 50, 100):\n", "\n", - " with Live(report=None, save_dvc_exp=True) as live:\n", + " with Live() as live:\n", "\n", " live.log_param(\"n_estimators\", n_estimators)\n", "\n", diff --git a/src/dvclive/live.py b/src/dvclive/live.py index da0586c2..4c8201d7 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -56,8 +56,8 @@ def __init__( self, dir: str = "dvclive", # noqa: A002 resume: bool = False, - report: Optional[str] = "auto", - save_dvc_exp: bool = False, + report: Optional[str] = None, + save_dvc_exp: bool = True, dvcyaml: bool = True, cache_images: bool = False, exp_message: Optional[str] = None, @@ -118,8 +118,9 @@ def _init_cleanup(self): for f in ( self.metrics_file, - self.report_file, self.params_file, + os.path.join(self.dir, "report.html"), + os.path.join(self.dir, "report.md"), ): if f and os.path.exists(f): os.remove(f) @@ -202,7 +203,6 @@ def _init_studio(self): logger.warning( "Can't connect to Studio without creating a DVC experiment." "\nIf you have a DVC Pipeline, run it with `dvc exp run`." - "\nIf you are using DVCLive alone, use `save_dvc_exp=True`." ) self._studio_events_to_skip.add("start") self._studio_events_to_skip.add("data") diff --git a/src/dvclive/optuna.py b/src/dvclive/optuna.py index 9954b484..10069557 100644 --- a/src/dvclive/optuna.py +++ b/src/dvclive/optuna.py @@ -10,7 +10,7 @@ def __init__(self, metric_name="metric", **kwargs) -> None: self.live_kwargs = kwargs def __call__(self, study, trial) -> None: - with Live(save_dvc_exp=True, **self.live_kwargs) as live: + with Live(**self.live_kwargs) as live: self._log_metrics(trial.values, live) live.log_params(trial.params) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index 1e9c43b8..97452e28 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -158,7 +158,7 @@ def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch, mocker): monkeypatch.setenv(DVC_EXP_NAME, "bar") with mocker.patch("dvclive.live.get_dvc_repo", return_value=None): - live = Live(save_dvc_exp=True) + live = Live() live.end() assert live._dvc_repo is None @@ -177,7 +177,7 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker): dvc_repo.scm.no_commits = False dvc_repo.config = {} with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo): - live = Live(save_dvc_exp=True) + live = Live() assert live._save_dvc_exp assert live._baseline_rev is not None assert live._exp_name is not None @@ -219,7 +219,7 @@ def test_exp_save_with_dvc_files(tmp_dir, mocker): dvc_repo.config = {} with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo): - live = Live(save_dvc_exp=True) + live = Live() live.end() dvc_repo.experiments.save.assert_called_with( @@ -238,7 +238,7 @@ def test_exp_save_dvcexception_is_ignored(tmp_dir, mocker): dvc_repo.experiments.save.side_effect = DvcException("foo") mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) - with Live(save_dvc_exp=True): + with Live(): pass @@ -252,7 +252,7 @@ def test_untracked_dvclive_files_inside_dvc_exp_run_are_added( "dvclive/metrics.json", plot_file, ] - with Live(report=None) as live: + with Live() as live: live.log_metric("foo", 1) live.next_step() live._dvc_repo.scm.add.assert_called_with(["dvclive/metrics.json", plot_file]) @@ -269,7 +269,7 @@ def test_dvc_outs_are_not_added(tmp_dir, mocked_dvc_repo, monkeypatch): plot_file, ] - with Live(report=None) as live: + with Live() as live: live.log_metric("foo", 1) live.next_step() @@ -282,7 +282,7 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch): mocked_dvc_repo.scm.untracked_files.return_value = ["dvclive/metrics.json"] mocked_dvc_repo.scm.add.side_effect = DvcException("foo") - with Live(report=None) as live: + with Live() as live: live.summary["foo"] = 1 @@ -302,7 +302,7 @@ def test_make_dvcyaml_idempotent(tmp_dir, mocked_dvc_repo): def test_exp_save_message(tmp_dir, mocked_dvc_repo): - live = Live(save_dvc_exp=True, exp_message="Custom message") + live = Live(exp_message="Custom message") live.end() mocked_dvc_repo.experiments.save.assert_called_with( name=live._exp_name, @@ -320,7 +320,7 @@ def test_no_scm_repo(tmp_dir, mocker): live = Live() assert live._dvc_repo == dvc_repo - live = Live(save_dvc_exp=True) + live = Live() assert live._save_dvc_exp is False diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py index 2e22ea53..01a6de6a 100644 --- a/tests/test_frameworks/test_lightning.py +++ b/tests/test_frameworks/test_lightning.py @@ -265,7 +265,7 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio mocked_post, _ = mocked_studio_post model = ValLitXOR() - dvclive_logger = DVCLiveLogger(save_dvc_exp=True) + dvclive_logger = DVCLiveLogger() trainer = Trainer( logger=dvclive_logger, max_steps=4, diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py index d6a4b45d..dcea9e5f 100644 --- a/tests/test_log_artifact.py +++ b/tests/test_log_artifact.py @@ -20,7 +20,7 @@ def test_log_artifact(tmp_dir, dvc_repo, cache): data = tmp_dir / "data" data.touch() - with Live() as live: + with Live(save_dvc_exp=False) as live: live.log_artifact("data", cache=cache) assert data.with_suffix(".dvc").exists() is cache assert load_yaml(live.dvc_file) == {} @@ -29,12 +29,12 @@ def test_log_artifact(tmp_dir, dvc_repo, cache): def test_log_artifact_on_existing_dvc_file(tmp_dir, dvc_repo): data = tmp_dir / "data" data.write_text("foo") - with Live() as live: + with Live(save_dvc_exp=False) as live: live.log_artifact("data") prev_content = data.with_suffix(".dvc").read_text() - with Live() as live: + with Live(save_dvc_exp=False) as live: data.write_text("bar") live.log_artifact("data") @@ -43,7 +43,7 @@ def test_log_artifact_on_existing_dvc_file(tmp_dir, dvc_repo): def test_log_artifact_twice(tmp_dir, dvc_repo): data = tmp_dir / "data" - with Live() as live: + with Live(save_dvc_exp=False) as live: for i in range(2): data.write_text(str(i)) live.log_artifact("data") @@ -54,7 +54,7 @@ def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo): stage = mocker.MagicMock() stage.addressing = "data" mocked_dvc_repo.add.return_value = [stage] - with Live(save_dvc_exp=True) as live: + with Live() as live: live.log_artifact("data") mocked_dvc_repo.experiments.save.assert_called_with( name=live._exp_name, @@ -78,7 +78,7 @@ def test_log_artifact_type_model(tmp_dir, mocked_dvc_repo): def test_log_artifact_dvc_symlink(tmp_dir, dvc_repo): (tmp_dir / "model.pth").touch() - with Live() as live: + with Live(save_dvc_exp=False) as live: live._dvc_repo.cache.local.cache_types = ["symlink"] live.log_artifact("model.pth", type="model") @@ -90,7 +90,7 @@ def test_log_artifact_dvc_symlink(tmp_dir, dvc_repo): def test_log_artifact_copy(tmp_dir, dvc_repo): (tmp_dir / "model.pth").touch() - with Live() as live: + with Live(save_dvc_exp=False) as live: live.log_artifact("model.pth", type="model", copy=True) artifacts_dir = Path(live.artifacts_dir) @@ -105,7 +105,7 @@ def test_log_artifact_copy(tmp_dir, dvc_repo): def test_log_artifact_copy_overwrite(tmp_dir, dvc_repo): (tmp_dir / "model.pth").touch() - with Live() as live: + with Live(save_dvc_exp=False) as live: artifacts_dir = Path(live.artifacts_dir) # testing with symlink cache to make sure that DVC protected mode # does not prevent the overwrite @@ -127,7 +127,7 @@ def test_log_artifact_copy_directory_overwrite(tmp_dir, dvc_repo): model_path.mkdir() (tmp_dir / "weights" / "model-epoch-1.pth").touch() - with Live() as live: + with Live(save_dvc_exp=False) as live: artifacts_dir = Path(live.artifacts_dir) # testing with symlink cache to make sure that DVC protected mode # does not prevent the overwrite @@ -237,7 +237,7 @@ def test_log_artifact_inside_exp(tmp_dir, mocker, dvc_repo, tracked): dvcyaml_path = tmp_dir / "dvc.yaml" with open(dvcyaml_path, "w") as f: f.write(dvcyaml) - live = Live() + live = Live(save_dvc_exp=False) spy = mocker.spy(live._dvc_repo, "add") live._inside_dvc_exp = True live.log_artifact("data") diff --git a/tests/test_main.py b/tests/test_main.py index f72526c2..270d90ce 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -379,7 +379,7 @@ def test_make_summary_on_end_dont_increment_step(tmp_dir): def test_context_manager(tmp_dir): - with Live() as live: + with Live(report="html") as live: live.summary["foo"] = 1.0 assert json.loads((tmp_dir / live.metrics_file).read_text()) == { @@ -420,7 +420,7 @@ def test_vscode_dvclive_only_signal_file(tmp_dir, dvc_root, mocker): with mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo), mocker.patch( "dvclive.live.os.getpid", return_value=test_pid ): - dvclive = Live(save_dvc_exp=True) + dvclive = Live() if dvc_root: assert os.path.exists(signal_file) @@ -456,7 +456,7 @@ def test_suppress_dvc_logs(tmp_dir, mocked_dvc_repo): @pytest.mark.parametrize("cache", [False, True]) def test_cache_images(tmp_dir, dvc_repo, cache): - live = Live(cache_images=cache) + live = Live(save_dvc_exp=False, cache_images=cache) img = Image.new("RGB", (10, 10), (250, 250, 250)) live.log_image("image.png", img) live.end() diff --git a/tests/test_report.py b/tests/test_report.py index cd7832e6..0259bae8 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -72,15 +72,15 @@ def test_get_renderers(tmp_dir, mocker): def test_report_init(monkeypatch, mocker): monkeypatch.setenv("CI", "false") - live = Live() + live = Live(report="auto") assert live._report_mode == "html" monkeypatch.setenv("CI", "true") - live = Live() + live = Live(report="auto") assert live._report_mode == "md" mocker.patch("dvclive.live.matplotlib_installed", return_value=False) - live = Live() + live = Live(report="auto") assert live._report_mode == "html" for report in (None, "html", "md"): @@ -116,7 +116,7 @@ def test_make_report_open(tmp_dir, mocker, monkeypatch): assert not mocked_open.called - live = Live(report=None) + live = Live(report="html") live.log_metric("foo", 1) live.next_step() @@ -124,7 +124,7 @@ def test_make_report_open(tmp_dir, mocker, monkeypatch): monkeypatch.setenv(DVCLIVE_OPEN, True) - live = Live() + live = Live(report="html") live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [1, 0, 0, 1]) live.make_report() diff --git a/tests/test_studio.py b/tests/test_studio.py index a0865467..9b7731df 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -12,7 +12,7 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): - live = Live(save_dvc_exp=True) + live = Live() live.log_param("fooparam", 1) dvc_path = Path(live.dvc_file).as_posix() @@ -108,7 +108,7 @@ def test_post_to_studio_failed_data_request( ): mocked_post, valid_response = mocked_studio_post - live = Live(save_dvc_exp=True) + live = Live() dvc_path = Path(live.dvc_file).as_posix() metrics_path = Path(live.metrics_file).as_posix() @@ -157,7 +157,7 @@ def test_post_to_studio_failed_start_request( mocked_response.status_code = 400 mocked_post = mocker.patch("requests.post", return_value=mocked_response) - live = Live(save_dvc_exp=True) + live = Live() live.log_metric("foo", 1) live.next_step() @@ -170,7 +170,7 @@ def test_post_to_studio_failed_start_request( def test_post_to_studio_end_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post): mocked_post, _ = mocked_studio_post - with Live(save_dvc_exp=True) as live: + with Live() as live: live.log_metric("foo", 1) live.next_step() @@ -237,7 +237,9 @@ def test_post_to_studio_include_prefix_if_needed( ): mocked_post, _ = mocked_studio_post # Create dvclive/dvc.yaml - live = Live("custom_dir", save_dvc_exp=True) + live = Live( + "custom_dir", + ) live.log_metric("foo", 1) live.next_step() @@ -268,7 +270,7 @@ def test_post_to_studio_include_prefix_if_needed( def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_post): mocked_post, _ = mocked_studio_post - live = Live(save_dvc_exp=True) + live = Live() live.log_metric("eval/loss", 1) live.next_step() @@ -322,7 +324,7 @@ def test_post_to_studio_inside_subdir( subdir.mkdir() monkeypatch.chdir(subdir) - live = Live(save_dvc_exp=True) + live = Live() live.log_metric("foo", 1) live.next_step() @@ -399,8 +401,8 @@ def test_post_to_studio_inside_subdir_dvc_exp( def test_post_to_studio_requires_exp(tmp_dir, mocked_dvc_repo, mocked_studio_post): - assert Live()._studio_events_to_skip == {"start", "data", "done"} - assert not Live(save_dvc_exp=True)._studio_events_to_skip + assert Live(save_dvc_exp=False)._studio_events_to_skip == {"start", "data", "done"} + assert not Live()._studio_events_to_skip def test_get_dvc_studio_config_none(mocker): @@ -434,7 +436,7 @@ def test_get_dvc_studio_config_dvc_repo(mocked_dvc_repo): def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): mocked_post, _ = mocked_studio_post - live = Live(save_dvc_exp=True) + live = Live() live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0))) live.next_step() @@ -465,7 +467,7 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): - live = Live(save_dvc_exp=True, exp_message="Custom message") + live = Live(exp_message="Custom message") mocked_post, _ = mocked_studio_post From aa3610a1ff566c124a157e60d21b2ef8fe9fa317 Mon Sep 17 00:00:00 2001 From: daavoo Date: Tue, 22 Aug 2023 12:38:27 +0200 Subject: [PATCH 09/33] frameworks: Drop model_file. --- src/dvclive/catalyst.py | 13 +-- src/dvclive/fastai.py | 5 - src/dvclive/huggingface.py | 21 ---- src/dvclive/keras.py | 20 ---- src/dvclive/lgbm.py | 6 +- src/dvclive/xgb.py | 4 - tests/test_frameworks/test_catalyst.py | 18 ---- tests/test_frameworks/test_fastai.py | 13 --- tests/test_frameworks/test_huggingface.py | 30 ------ tests/test_frameworks/test_keras.py | 119 ---------------------- tests/test_frameworks/test_lgbm.py | 21 ---- tests/test_frameworks/test_xgboost.py | 16 --- 12 files changed, 2 insertions(+), 284 deletions(-) diff --git a/src/dvclive/catalyst.py b/src/dvclive/catalyst.py index 1f4f29c5..532d31b3 100644 --- a/src/dvclive/catalyst.py +++ b/src/dvclive/catalyst.py @@ -1,16 +1,14 @@ # ruff: noqa: ARG002 from typing import Optional -from catalyst import utils from catalyst.core.callback import Callback, CallbackOrder from dvclive import Live class DVCLiveCallback(Callback): - def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs): + def __init__(self, live: Optional[Live] = None, **kwargs): super().__init__(order=CallbackOrder.external) - self.model_file = model_file self.live = live if live is not None else Live(**kwargs) def on_epoch_end(self, runner) -> None: @@ -19,15 +17,6 @@ def on_epoch_end(self, runner) -> None: self.live.log_metric( f"{loader_key}/{key.replace('/', '_')}", float(value) ) - - if self.model_file: - checkpoint = utils.pack_checkpoint( - model=runner.model, - criterion=runner.criterion, - optimizer=runner.optimizer, - scheduler=runner.scheduler, - ) - utils.save_checkpoint(checkpoint, self.model_file) self.live.next_step() def on_experiment_end(self, runner): diff --git a/src/dvclive/fastai.py b/src/dvclive/fastai.py index a1750c2c..08028d20 100644 --- a/src/dvclive/fastai.py +++ b/src/dvclive/fastai.py @@ -27,13 +27,11 @@ def _inside_fine_tune(): class DVCLiveCallback(Callback): def __init__( self, - model_file: Optional[str] = None, with_opt: bool = False, live: Optional[Live] = None, **kwargs, ): super().__init__() - self.model_file = model_file self.with_opt = with_opt self.live = live if live is not None else Live(**kwargs) self.freeze_stage_ended = False @@ -66,9 +64,6 @@ def after_epoch(self): # When resuming (i.e. passing `start_epoch` to learner) # fast.ai calls after_epoch but we don't want to increase the step. if logged_metrics: - if self.model_file: - file = self.learn.save(self.model_file, with_opt=self.with_opt) - self.live.log_artifact(str(file)) self.live.next_step() def after_fit(self): diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 49fa47e3..bc14a814 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -26,12 +26,6 @@ def __init__( ): super().__init__() self._log_model = log_model - self.model_file = kwargs.pop("model_file", None) - if self.model_file: - logger.warning( - "model_file is deprecated and will be removed" - " in the next major version, use log_model instead" - ) self.live = live if live is not None else Live(**kwargs) def on_train_begin( @@ -65,21 +59,6 @@ def on_save( if self._log_model == "all" and state.is_world_process_zero: self.live.log_artifact(args.output_dir) - def on_epoch_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - if self.model_file: - model = kwargs["model"] - model.save_pretrained(self.model_file) - tokenizer = kwargs.get("tokenizer") - if tokenizer: - tokenizer.save_pretrained(self.model_file) - self.live.log_artifact(self.model_file) - def on_train_end( self, args: TrainingArguments, diff --git a/src/dvclive/keras.py b/src/dvclive/keras.py index c85bfa82..f7e7bffe 100644 --- a/src/dvclive/keras.py +++ b/src/dvclive/keras.py @@ -1,5 +1,4 @@ # ruff: noqa: ARG002 -import os from typing import Dict, Optional import tensorflow as tf @@ -11,37 +10,18 @@ class DVCLiveCallback(tf.keras.callbacks.Callback): def __init__( self, - model_file=None, save_weights_only: bool = False, live: Optional[Live] = None, **kwargs, ): super().__init__() - self.model_file = model_file self.save_weights_only = save_weights_only self.live = live if live is not None else Live(**kwargs) - def on_train_begin(self, logs=None): - if ( - self.live._resume # noqa: SLF001 - and self.model_file is not None - and os.path.exists(self.model_file) - ): - if self.save_weights_only: - self.model.load_weights(self.model_file) - else: - self.model = tf.keras.models.load_model(self.model_file) - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None): logs = logs or {} for metric, value in logs.items(): self.live.log_metric(standardize_metric_name(metric, __name__), value) - if self.model_file: - if self.save_weights_only: - self.model.save_weights(self.model_file) - else: - self.model.save(self.model_file) - self.live.log_artifact(self.model_file) self.live.next_step() def on_train_end(self, logs: Optional[Dict] = None): diff --git a/src/dvclive/lgbm.py b/src/dvclive/lgbm.py index 3211a693..69b9a034 100644 --- a/src/dvclive/lgbm.py +++ b/src/dvclive/lgbm.py @@ -4,9 +4,8 @@ class DVCLiveCallback: - def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs): + def __init__(self, live: Optional[Live] = None, **kwargs): super().__init__() - self.model_file = model_file self.live = live if live is not None else Live(**kwargs) def __call__(self, env): @@ -16,7 +15,4 @@ def __call__(self, env): self.live.log_metric( f"{data_name}/{eval_name}" if multi_eval else eval_name, result ) - - if self.model_file: - env.model.save_model(self.model_file) self.live.next_step() diff --git a/src/dvclive/xgb.py b/src/dvclive/xgb.py index daf30a2e..2f3079d6 100644 --- a/src/dvclive/xgb.py +++ b/src/dvclive/xgb.py @@ -11,7 +11,6 @@ class DVCLiveCallback(TrainingCallback): def __init__( self, metric_data: Optional[str] = None, - model_file=None, live: Optional[Live] = None, **kwargs, ): @@ -23,7 +22,6 @@ def __init__( stacklevel=2, ) self._metric_data = metric_data - self.model_file = model_file self.live = live if live is not None else Live(**kwargs) def after_iteration(self, model, epoch, evals_log): @@ -32,8 +30,6 @@ def after_iteration(self, model, epoch, evals_log): for subdir, data in evals_log.items(): for key, values in data.items(): self.live.log_metric(f"{subdir}/{key}" if subdir else key, values[-1]) - if self.model_file: - model.save_model(self.model_file) self.live.next_step() def after_training(self, model): diff --git a/tests/test_frameworks/test_catalyst.py b/tests/test_frameworks/test_catalyst.py index 90165faa..bb195986 100644 --- a/tests/test_frameworks/test_catalyst.py +++ b/tests/test_frameworks/test_catalyst.py @@ -85,24 +85,6 @@ def test_catalyst_callback(tmp_dir, runner, runner_params, mocker): assert any("accuracy" in x.name for x in valid_path.iterdir()) -def test_catalyst_model_file(tmp_dir, runner, runner_params): - runner.train( - **runner_params, - num_epochs=2, - callbacks=[ - dl.AccuracyCallback(input_key="logits", target_key="targets"), - DVCLiveCallback("model.pth"), - ], - logdir="./logs", - valid_loader="valid", - valid_metric="loss", - minimize_valid_metric=True, - verbose=True, - load_best_on_end=True, - ) - assert (tmp_dir / "model.pth").is_file() - - def test_catalyst_pass_logger(): logger = Live("train_logs") diff --git a/tests/test_frameworks/test_fastai.py b/tests/test_frameworks/test_fastai.py index 5ae28455..452e54cd 100644 --- a/tests/test_frameworks/test_fastai.py +++ b/tests/test_frameworks/test_fastai.py @@ -68,19 +68,6 @@ def test_fastai_callback(tmp_dir, data_loader, mocker): assert not (metrics_path / "epoch.tsv").exists() -def test_fastai_model_file(tmp_dir, data_loader, mocker): - learn = tabular_learner(data_loader, metrics=accuracy) - learn.remove_cb(ProgressCallback) - learn.model_dir = os.path.abspath("./") - save = mocker.spy(learn, "save") - live_callback = DVCLiveCallback("model", with_opt=True) - log_artifact = mocker.patch.object(live_callback.live, "log_artifact") - learn.fit_one_cycle(2, cbs=[live_callback]) - assert (tmp_dir / "model.pth").is_file() - save.assert_called_with("model", with_opt=True) - log_artifact.assert_called_with(str(tmp_dir / "model.pth")) - - def test_fastai_pass_logger(): logger = Live("train_logs") diff --git a/tests/test_frameworks/test_huggingface.py b/tests/test_frameworks/test_huggingface.py index bc65e44a..b6d4fc04 100644 --- a/tests/test_frameworks/test_huggingface.py +++ b/tests/test_frameworks/test_huggingface.py @@ -176,33 +176,3 @@ def test_huggingface_pass_logger(): assert DVCLiveCallback().live is not logger assert DVCLiveCallback(live=logger).live is logger - - -def test_huggingface_model_file(tmp_dir, model, args, data, mocker): - logger = mocker.patch("dvclive.huggingface.logger") - - model_path = tmp_dir / "model_hf" - - live_callback = DVCLiveCallback(model_file=model_path) - log_artifact = mocker.patch.object(live_callback.live, "log_artifact") - - trainer = Trainer( - model, - args, - train_dataset=data[0], - eval_dataset=data[1], - compute_metrics=compute_metrics, - ) - trainer.add_callback(live_callback) - trainer.train() - - assert model_path.is_dir() - - assert (model_path / "pytorch_model.bin").exists() - assert (model_path / "config.json").exists() - log_artifact.assert_called_with(model_path) - - logger.warning.assert_called_with( - "model_file is deprecated and will be removed" - " in the next major version, use log_model instead" - ) diff --git a/tests/test_frameworks/test_keras.py b/tests/test_frameworks/test_keras.py index e12e10d6..46239091 100644 --- a/tests/test_frameworks/test_keras.py +++ b/tests/test_frameworks/test_keras.py @@ -63,122 +63,3 @@ def test_keras_callback_pass_logger(): assert DVCLiveCallback().live is not logger assert DVCLiveCallback(live=logger).live is logger - - -@pytest.mark.parametrize("save_weights_only", [True, False]) -def test_keras_model_file(tmp_dir, xor_model, mocker, save_weights_only): - model, x, y = xor_model() - save = mocker.spy(model, "save") - save_weights = mocker.spy(model, "save_weights") - - live_callback = DVCLiveCallback( - model_file="model.h5", save_weights_only=save_weights_only - ) - log_artifact = mocker.patch.object(live_callback.live, "log_artifact") - model.fit( - x, - y, - epochs=1, - batch_size=1, - callbacks=[live_callback], - ) - assert save.call_count != save_weights_only - assert save_weights.call_count == save_weights_only - log_artifact.assert_called_with(live_callback.model_file) - - -@pytest.mark.parametrize("save_weights_only", [True, False]) -def test_keras_load_model_on_resume(tmp_dir, xor_model, mocker, save_weights_only): - model, x, y = xor_model() - - if save_weights_only: - model.save_weights("model.h5") - else: - model.save("model.h5") - - load_weights = mocker.spy(model, "load_weights") - - model.fit( - x, - y, - epochs=1, - batch_size=1, - callbacks=[ - DVCLiveCallback( - model_file="model.h5", - save_weights_only=save_weights_only, - resume=True, - ) - ], - ) - - assert load_weights.call_count == save_weights_only - - -def test_keras_no_resume_skip_load(tmp_dir, xor_model, mocker): - model, x, y = xor_model() - - model.save_weights("model.h5") - - load_weights = mocker.spy(model, "load_weights") - - model.fit( - x, - y, - epochs=1, - batch_size=1, - callbacks=[ - DVCLiveCallback( - model_file="model.h5", - save_weights_only=True, - resume=False, - ) - ], - ) - - assert load_weights.call_count == 0 - - -def test_keras_no_existing_model_file_skip_load(tmp_dir, xor_model, mocker): - model, x, y = xor_model() - - load_weights = mocker.spy(model, "load_weights") - - model.fit( - x, - y, - epochs=1, - batch_size=1, - callbacks=[ - DVCLiveCallback( - model_file="model.h5", - save_weights_only=True, - resume=True, - ) - ], - ) - - assert load_weights.call_count == 0 - - -def test_keras_none_model_file_skip_load(tmp_dir, xor_model, mocker): - model, x, y = xor_model() - - model.save_weights("model.h5") - - load_weights = mocker.spy(model, "load_weights") - - model.fit( - x, - y, - epochs=1, - batch_size=1, - callbacks=[ - DVCLiveCallback( - save_weights_only=True, - resume=True, - ) - ], - ) - - assert load_weights.call_count == 0 diff --git a/tests/test_frameworks/test_lgbm.py b/tests/test_frameworks/test_lgbm.py index 7ef39a48..250f355a 100644 --- a/tests/test_frameworks/test_lgbm.py +++ b/tests/test_frameworks/test_lgbm.py @@ -8,7 +8,6 @@ try: import lightgbm as lgbm - import numpy as np import pandas as pd from sklearn import datasets from sklearn.model_selection import train_test_split @@ -82,26 +81,6 @@ def test_lgbm_integration_multi_eval(tmp_dir, model_params, iris_data): assert len(next(iter(logs.values()))) == 5 -@pytest.mark.skipif(platform == "darwin", reason="LIBOMP Segmentation fault on MacOS") -def test_lgbm_model_file(tmp_dir, model_params, iris_data): - model = lgbm.LGBMClassifier() - model.set_params(**model_params) - - model.fit( - iris_data[0][0], - iris_data[0][1], - eval_set=(iris_data[1][0], iris_data[1][1]), - eval_metric=["multi_logloss"], - callbacks=[DVCLiveCallback("lgbm_model")], - ) - - preds = model.predict(iris_data[1][0]) - model2 = lgbm.Booster(model_file="lgbm_model") - preds2 = model2.predict(iris_data[1][0]) - preds2 = np.argmax(preds2, axis=1) - assert np.sum(np.abs(preds2 - preds)) == 0 - - def test_lgbm_pass_logger(): logger = Live("train_logs") diff --git a/tests/test_frameworks/test_xgboost.py b/tests/test_frameworks/test_xgboost.py index 43c8edee..0b375450 100644 --- a/tests/test_frameworks/test_xgboost.py +++ b/tests/test_frameworks/test_xgboost.py @@ -8,7 +8,6 @@ from dvclive.utils import parse_metrics try: - import numpy as np import pandas as pd import xgboost as xgb from sklearn import datasets @@ -80,21 +79,6 @@ def test_xgb_integration( ) -def test_xgb_model_file(tmp_dir, train_params, iris_data): - model = xgb.train( - train_params, - iris_data, - callbacks=[DVCLiveCallback("eval_data", model_file="model_xgb.json")], - num_boost_round=5, - evals=[(iris_data, "eval_data")], - ) - - preds = model.predict(iris_data) - model2 = xgb.Booster(model_file="model_xgb.json") - preds2 = model2.predict(iris_data) - assert np.sum(np.abs(preds2 - preds)) == 0 - - def test_xgb_pass_logger(): logger = Live("train_logs") From 2e538ce80cf44d03e19ef03c371c47aa9b2bba4f Mon Sep 17 00:00:00 2001 From: daavoo Date: Tue, 22 Aug 2023 12:42:14 +0200 Subject: [PATCH 10/33] update examples --- examples/DVCLive-HuggingFace.ipynb | 2 +- examples/DVCLive-PyTorch-Lightning.ipynb | 2 +- examples/DVCLive-Quickstart.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/DVCLive-HuggingFace.ipynb b/examples/DVCLive-HuggingFace.ipynb index 31b4e4f4..f1d7462d 100644 --- a/examples/DVCLive-HuggingFace.ipynb +++ b/examples/DVCLive-HuggingFace.ipynb @@ -106,7 +106,7 @@ " train_dataset=small_train_dataset,\n", " eval_dataset=small_eval_dataset,\n", " compute_metrics=compute_metrics,\n", - " callbacks=[DVCLiveCallback(report=\"notebook\", \n", + " callbacks=[DVCLiveCallback(log_model=True, report=\"notebook\")]\n", " )\n", " trainer.train()" ] diff --git a/examples/DVCLive-PyTorch-Lightning.ipynb b/examples/DVCLive-PyTorch-Lightning.ipynb index 2ee1a322..1e140e48 100644 --- a/examples/DVCLive-PyTorch-Lightning.ipynb +++ b/examples/DVCLive-PyTorch-Lightning.ipynb @@ -173,7 +173,7 @@ " limit_train_batches=200,\n", " limit_val_batches=100,\n", " max_epochs=5,\n", - " logger=DVCLiveLogger(, log_model=True),\n", + " logger=DVCLiveLogger(log_model=True, report=\"notebook\"),\n", " )\n", " trainer.fit(model, train_loader, validation_loader)\n" ] diff --git a/examples/DVCLive-Quickstart.ipynb b/examples/DVCLive-Quickstart.ipynb index 8c589a93..00026152 100644 --- a/examples/DVCLive-Quickstart.ipynb +++ b/examples/DVCLive-Quickstart.ipynb @@ -220,7 +220,7 @@ "\n", "best_test_acc = 0\n", "\n", - "with Live() as live:\n", + "with Live(report=\"notebook\") as live:\n", "\n", " live.log_params(params)\n", "\n", From 00ac8871e648b32bbafcfd7edae02a5930f7cd95 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Thu, 7 Sep 2023 03:16:34 -0400 Subject: [PATCH 11/33] Write to root dvc.yaml (#687) * add dvcyaml to root * clean up dvcyaml implementation * fix existing tests * add new tests * add unit tests for updating dvcyaml * use posix paths * don't resolve symlinks * drop entire dvclive dir on cleanup * fix studio tests * revert cleanup changes * unify rel_path util func * cleanup test * refactor tests * add test for multiple dvclive instances * put dvc_file logic into _init_dvc_file --------- Co-authored-by: daavoo --- src/dvclive/dvc.py | 76 +++++-- src/dvclive/error.py | 5 + src/dvclive/live.py | 39 +++- src/dvclive/serialize.py | 9 + src/dvclive/studio.py | 16 +- src/dvclive/utils.py | 5 + tests/test_dvc.py | 146 +------------ tests/test_log_artifact.py | 42 ++-- tests/test_main.py | 46 +++- tests/test_make_dvcyaml.py | 418 +++++++++++++++++++++++++++++++++++++ tests/test_studio.py | 24 +-- tests/test_vscode.py | 20 -- 12 files changed, 599 insertions(+), 247 deletions(-) create mode 100644 tests/test_make_dvcyaml.py diff --git a/src/dvclive/dvc.py b/src/dvclive/dvc.py index f363100e..4aafb1e9 100644 --- a/src/dvclive/dvc.py +++ b/src/dvclive/dvc.py @@ -6,7 +6,7 @@ from dvclive.plots import Image, Metric from dvclive.serialize import dump_yaml -from dvclive.utils import StrPath +from dvclive.utils import StrPath, rel_path if TYPE_CHECKING: from dvc.repo import Repo @@ -51,38 +51,78 @@ def get_dvc_repo() -> Optional["Repo"]: return None -def make_dvcyaml(live) -> None: +def make_dvcyaml(live) -> None: # noqa: C901 + dvcyaml_dir = Path(live.dvc_file).parent.absolute().as_posix() + dvcyaml = {} if live._params: - dvcyaml["params"] = [os.path.relpath(live.params_file, live.dir)] + dvcyaml["params"] = [rel_path(live.params_file, dvcyaml_dir)] if live._metrics or live.summary: - dvcyaml["metrics"] = [os.path.relpath(live.metrics_file, live.dir)] + dvcyaml["metrics"] = [rel_path(live.metrics_file, dvcyaml_dir)] plots: List[Any] = [] plots_path = Path(live.plots_dir) - metrics_path = plots_path / Metric.subfolder - if metrics_path.exists(): - metrics_relpath = metrics_path.relative_to(live.dir).as_posix() - metrics_config = {metrics_relpath: {"x": "step"}} + plots_metrics_path = plots_path / Metric.subfolder + if plots_metrics_path.exists(): + metrics_config = {rel_path(plots_metrics_path, dvcyaml_dir): {"x": "step"}} plots.append(metrics_config) if live._images: - images_path = (plots_path / Image.subfolder).relative_to(live.dir) - plots.append(images_path.as_posix()) + images_path = rel_path(plots_path / Image.subfolder, dvcyaml_dir) + plots.append(images_path) if live._plots: for plot in live._plots.values(): - plot_path = plot.output_path.relative_to(live.dir) - plots.append({plot_path.as_posix(): plot.plot_config}) + plot_path = rel_path(plot.output_path, dvcyaml_dir) + plots.append({plot_path: plot.plot_config}) if plots: dvcyaml["plots"] = plots if live._artifacts: dvcyaml["artifacts"] = copy.deepcopy(live._artifacts) for artifact in dvcyaml["artifacts"].values(): # type: ignore - abs_path = os.path.abspath(artifact["path"]) - abs_dir = os.path.realpath(live.dir) - relative_path = os.path.relpath(abs_path, abs_dir) - artifact["path"] = Path(relative_path).as_posix() - - dump_yaml(dvcyaml, live.dvc_file) + artifact["path"] = rel_path(artifact["path"], dvcyaml_dir) + + if not os.path.exists(live.dvc_file): + dump_yaml(dvcyaml, live.dvc_file) + else: + update_dvcyaml(live, dvcyaml) + + +def update_dvcyaml(live, updates): # noqa: C901 + from dvc.utils.serialize import modify_yaml + + dvcyaml_dir = os.path.abspath(os.path.dirname(live.dvc_file)) + dvclive_dir = os.path.relpath(live.dir, dvcyaml_dir) + "/" + + def _drop_stale_dvclive_entries(entries): + non_dvclive = [] + for e in entries: + if isinstance(e, str): + if dvclive_dir not in e: + non_dvclive.append(e) + elif isinstance(e, dict) and len(e) == 1: + if dvclive_dir not in next(iter(e.keys())): + non_dvclive.append(e) + else: + non_dvclive.append(e) + return non_dvclive + + def _update_entries(old, new, key): + keepers = _drop_stale_dvclive_entries(old.get(key, [])) + old[key] = keepers + new.get(key, []) + if not old[key]: + del old[key] + return old + + with modify_yaml(live.dvc_file) as orig: + orig = _update_entries(orig, updates, "params") # noqa: PLW2901 + orig = _update_entries(orig, updates, "metrics") # noqa: PLW2901 + orig = _update_entries(orig, updates, "plots") # noqa: PLW2901 + old_artifacts = {} + for name, meta in orig.get("artifacts", {}).items(): + if dvclive_dir not in meta.get("path", dvclive_dir): + old_artifacts[name] = meta + orig["artifacts"] = {**old_artifacts, **updates.get("artifacts", {})} + if not orig["artifacts"]: + del orig["artifacts"] def get_random_exp_name(scm, baseline_rev) -> str: diff --git a/src/dvclive/error.py b/src/dvclive/error.py index 3b795b03..790e64c3 100644 --- a/src/dvclive/error.py +++ b/src/dvclive/error.py @@ -12,6 +12,11 @@ def __init__(self, name, val): super().__init__(f"Data '{name}' has not supported type {val}") +class InvalidDvcyamlError(DvcLiveError): + def __init__(self): + super().__init__("`dvcyaml` path must have filename 'dvc.yaml'") + + class InvalidPlotTypeError(DvcLiveError): def __init__(self, name): from .plots import SKLEARN_PLOTS diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 76303a74..679b5601 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -1,3 +1,4 @@ +import glob import json import logging import math @@ -21,6 +22,7 @@ ) from .error import ( InvalidDataTypeError, + InvalidDvcyamlError, InvalidParameterTypeError, InvalidPlotTypeError, InvalidReportModeError, @@ -62,7 +64,7 @@ def __init__( resume: bool = False, report: Optional[str] = None, save_dvc_exp: bool = True, - dvcyaml: bool = True, + dvcyaml: Union[str, bool] = True, cache_images: bool = False, exp_message: Optional[str] = None, ): @@ -87,11 +89,6 @@ def __init__( self._report_notebook = None self._init_report() - if self._resume: - self._init_resume() - else: - self._init_cleanup() - self._baseline_rev: Optional[str] = None self._exp_name: Optional[str] = None self._exp_message: Optional[str] = exp_message @@ -101,6 +98,11 @@ def __init__( self._include_untracked: List[str] = [] self._init_dvc() + if self._resume: + self._init_resume() + else: + self._init_cleanup() + self._latest_studio_step = self.step if resume else -1 self._studio_events_to_skip: Set[str] = set() self._dvc_studio_config: Dict[str, Any] = {} @@ -129,8 +131,8 @@ def _init_cleanup(self): if f and os.path.exists(f): os.remove(f) - if self.dvc_file and os.path.exists(self.dvc_file): - os.remove(self.dvc_file) + for dvc_file in glob.glob(os.path.join(self.dir, "**dvc.yaml")): + os.remove(dvc_file) @catch_and_warn(DvcException, logger) def _init_dvc(self): @@ -150,6 +152,8 @@ def _init_dvc(self): dvc_logger = logging.getLogger("dvc") dvc_logger.setLevel(os.getenv(env.DVCLIVE_LOGLEVEL, "WARNING").upper()) + self._dvc_file = self._init_dvc_file() + if (self._dvc_repo is None) or isinstance(self._dvc_repo.scm, NoSCM): if self._save_dvc_exp: logger.warning( @@ -184,6 +188,19 @@ def _init_dvc(self): mark_dvclive_only_started(self._exp_name) self._include_untracked.append(self.dir) + def _init_dvc_file(self) -> str: + if isinstance(self._dvcyaml, str): + if os.path.basename(self._dvcyaml) == "dvc.yaml": + return self._dvcyaml + raise InvalidDvcyamlError + if self._dvc_repo is not None: + return os.path.join(self._dvc_repo.root_dir, "dvc.yaml") + logger.warning( + "Can't infer dvcyaml path without a DVC repo. " + "`dvc.yaml` file will not be written." + ) + return "" + def _init_studio(self): self._dvc_studio_config = get_dvc_studio_config(self) if not self._dvc_studio_config: @@ -263,7 +280,7 @@ def metrics_file(self) -> str: @property def dvc_file(self) -> str: - return os.path.join(self.dir, "dvc.yaml") + return self._dvc_file @property def plots_dir(self) -> str: @@ -536,8 +553,10 @@ def make_report(self): if self._report_mode == "html" and env2bool(env.DVCLIVE_OPEN): open_file_in_browser(self.report_file) + @catch_and_warn(DvcException, logger) def make_dvcyaml(self): - make_dvcyaml(self) + if self.dvc_file: + make_dvcyaml(self) def end(self): if self._inside_with: diff --git a/src/dvclive/serialize.py b/src/dvclive/serialize.py index e97dbe4f..5e1ed221 100644 --- a/src/dvclive/serialize.py +++ b/src/dvclive/serialize.py @@ -1,4 +1,5 @@ import json +import os from collections import OrderedDict from dvclive.error import DvcLiveError @@ -39,11 +40,19 @@ def get_yaml(): def dump_yaml(content, output_file): yaml = get_yaml() + make_dir(output_file) with open(output_file, "w", encoding="utf-8") as fd: yaml.dump(content, fd) def dump_json(content, output_file, indent=4, **kwargs): + make_dir(output_file) with open(output_file, "w", encoding="utf-8") as f: json.dump(content, f, indent=indent, **kwargs) f.write("\n") + + +def make_dir(output_file): + output_dir = os.path.dirname(output_file) + if output_dir: + os.makedirs(output_dir, exist_ok=True) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 1ad14b1a..fd6ab66a 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -2,12 +2,11 @@ import base64 import math import os -from pathlib import Path from dvc_studio_client.post_live_metrics import get_studio_config from dvclive.serialize import load_yaml -from dvclive.utils import parse_metrics +from dvclive.utils import parse_metrics, rel_path def _get_unsent_datapoints(plot, latest_step): @@ -30,18 +29,13 @@ def _cast_to_numbers(datapoints): return datapoints -def _rel_path(path, dvc_root_path): - absolute_path = Path(path).resolve() - return str(absolute_path.relative_to(dvc_root_path).as_posix()) - - def _adapt_plot_name(live, name): if live._dvc_repo is not None: - name = _rel_path(name, live._dvc_repo.root_dir) + name = rel_path(name, live._dvc_repo.root_dir) if os.path.isfile(live.dvc_file): dvc_file = live.dvc_file if live._dvc_repo is not None: - dvc_file = _rel_path(live.dvc_file, live._dvc_repo.root_dir) + dvc_file = rel_path(live.dvc_file, live._dvc_repo.root_dir) name = f"{dvc_file}::{name}" return name @@ -70,7 +64,7 @@ def get_studio_updates(live): if os.path.isfile(live.params_file): params_file = live.params_file if live._dvc_repo is not None: - params_file = _rel_path(params_file, live._dvc_repo.root_dir) + params_file = rel_path(params_file, live._dvc_repo.root_dir) params = {params_file: load_yaml(live.params_file)} else: params = {} @@ -79,7 +73,7 @@ def get_studio_updates(live): metrics_file = live.metrics_file if live._dvc_repo is not None: - metrics_file = _rel_path(metrics_file, live._dvc_repo.root_dir) + metrics_file = rel_path(metrics_file, live._dvc_repo.root_dir) metrics = {metrics_file: {"data": metrics}} plots = { diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 383b6f23..a5cfd319 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -171,3 +171,8 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def rel_path(path, dvc_root_path): + absolute_path = Path(path).absolute() + return str(Path(os.path.relpath(absolute_path, dvc_root_path)).as_posix()) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index adf6ef04..754360e1 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -4,16 +4,11 @@ from dvc.exceptions import DvcException from dvc.repo import Repo from dvc.scm import NoSCM -from PIL import Image -from ruamel.yaml import YAML from scmrepo.git import Git from dvclive import Live -from dvclive.dvc import get_dvc_repo, make_dvcyaml +from dvclive.dvc import get_dvc_repo from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME -from dvclive.serialize import load_yaml - -YAML_LOADER = YAML(typ="safe") def test_get_dvc_repo(tmp_dir): @@ -30,110 +25,6 @@ def test_get_dvc_repo_subdir(tmp_dir): assert get_dvc_repo().root_dir == str(tmp_dir) -def test_make_dvcyaml_empty(tmp_dir): - live = Live() - make_dvcyaml(live) - - assert load_yaml(live.dvc_file) == {} - - -def test_make_dvcyaml_param(tmp_dir): - live = Live() - live.log_param("foo", 1) - make_dvcyaml(live) - - assert load_yaml(live.dvc_file) == { - "params": ["params.yaml"], - } - - -def test_make_dvcyaml_metrics(tmp_dir): - live = Live() - live.log_metric("bar", 2) - make_dvcyaml(live) - - assert load_yaml(live.dvc_file) == { - "metrics": ["metrics.json"], - "plots": [{"plots/metrics": {"x": "step"}}], - } - - -def test_make_dvcyaml_metrics_no_plots(tmp_dir): - live = Live() - live.log_metric("bar", 2, plot=False) - make_dvcyaml(live) - - assert load_yaml(live.dvc_file) == { - "metrics": ["metrics.json"], - } - - -def test_make_dvcyaml_summary(tmp_dir): - live = Live() - live.summary["bar"] = 2 - make_dvcyaml(live) - - assert load_yaml(live.dvc_file) == { - "metrics": ["metrics.json"], - } - - -def test_make_dvcyaml_all_plots(tmp_dir): - live = Live() - live.log_param("foo", 1) - live.log_metric("bar", 2) - live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250))) - live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0]) - live.log_sklearn_plot( - "confusion_matrix", - [0, 0, 1, 1], - [0, 1, 1, 0], - name="confusion_matrix_normalized", - normalized=True, - ) - live.log_sklearn_plot("roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc") - make_dvcyaml(live) - - assert load_yaml(live.dvc_file) == { - "metrics": ["metrics.json"], - "params": ["params.yaml"], - "plots": [ - {"plots/metrics": {"x": "step"}}, - "plots/images", - { - "plots/sklearn/confusion_matrix.json": { - "template": "confusion", - "x": "actual", - "y": "predicted", - "title": "Confusion Matrix", - "x_label": "True Label", - "y_label": "Predicted Label", - }, - }, - { - "plots/sklearn/confusion_matrix_normalized.json": { - "template": "confusion_normalized", - "title": "Confusion Matrix", - "x": "actual", - "x_label": "True Label", - "y": "predicted", - "y_label": "Predicted Label", - } - }, - { - "plots/sklearn/custom_name_roc.json": { - "template": "simple", - "x": "fpr", - "y": "tpr", - "title": "Receiver operating characteristic (ROC)", - "x_label": "False Positive Rate", - "y_label": "True Positive Rate", - } - }, - ], - } - - @pytest.mark.parametrize("save", [True, False]) def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): live = Live(save_dvc_exp=save) @@ -188,26 +79,6 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker): ) -@pytest.mark.parametrize("dvcyaml", [True, False]) -def test_dvcyaml_on_next_step(tmp_dir, dvcyaml, mocked_dvc_repo): - live = Live(dvcyaml=dvcyaml) - live.next_step() - if dvcyaml: - assert (tmp_dir / live.dvc_file).exists() - else: - assert not (tmp_dir / live.dvc_file).exists() - - -@pytest.mark.parametrize("dvcyaml", [True, False]) -def test_dvcyaml_on_end(tmp_dir, dvcyaml, mocked_dvc_repo): - live = Live(dvcyaml=dvcyaml) - live.end() - if dvcyaml: - assert (tmp_dir / live.dvc_file).exists() - else: - assert not (tmp_dir / live.dvc_file).exists() - - def test_exp_save_with_dvc_files(tmp_dir, mocker): dvc_repo = mocker.MagicMock() dvc_file = mocker.MagicMock() @@ -286,21 +157,6 @@ def test_errors_on_git_add_are_catched(tmp_dir, mocked_dvc_repo, monkeypatch): live.summary["foo"] = 1 -def test_make_dvcyaml_idempotent(tmp_dir, mocked_dvc_repo): - (tmp_dir / "model.pth").touch() - - with Live() as live: - live.log_artifact("model.pth", type="model") - - live.make_dvcyaml() - - assert load_yaml(live.dvc_file) == { - "artifacts": { - "model": {"path": "../model.pth", "type": "model"}, - } - } - - def test_exp_save_message(tmp_dir, mocked_dvc_repo): live = Live(exp_message="Custom message") live.end() diff --git a/tests/test_log_artifact.py b/tests/test_log_artifact.py index 05c50bad..3622d4b3 100644 --- a/tests/test_log_artifact.py +++ b/tests/test_log_artifact.py @@ -71,26 +71,26 @@ def test_log_artifact_type_model(tmp_dir, mocked_dvc_repo): live.log_artifact("model.pth", type="model") assert load_yaml(live.dvc_file) == { - "artifacts": {"model": {"path": "../model.pth", "type": "model"}} + "artifacts": {"model": {"path": "model.pth", "type": "model"}} } def test_log_artifact_dvc_symlink(tmp_dir, dvc_repo): (tmp_dir / "model.pth").touch() - with Live(save_dvc_exp=False) as live: + with Live(save_dvc_exp=False, dvcyaml="dvc.yaml") as live: live._dvc_repo.cache.local.cache_types = ["symlink"] live.log_artifact("model.pth", type="model") assert load_yaml(live.dvc_file) == { - "artifacts": {"model": {"path": "../model.pth", "type": "model"}} + "artifacts": {"model": {"path": "model.pth", "type": "model"}} } def test_log_artifact_copy(tmp_dir, dvc_repo): (tmp_dir / "model.pth").touch() - with Live(save_dvc_exp=False) as live: + with Live(save_dvc_exp=False, dvcyaml="dvc.yaml") as live: live.log_artifact("model.pth", type="model", copy=True) artifacts_dir = Path(live.artifacts_dir) @@ -98,14 +98,14 @@ def test_log_artifact_copy(tmp_dir, dvc_repo): assert (artifacts_dir / "model.pth.dvc").exists() assert load_yaml(live.dvc_file) == { - "artifacts": {"model": {"path": "artifacts/model.pth", "type": "model"}} + "artifacts": {"model": {"path": "dvclive/artifacts/model.pth", "type": "model"}} } def test_log_artifact_copy_overwrite(tmp_dir, dvc_repo): (tmp_dir / "model.pth").touch() - with Live(save_dvc_exp=False) as live: + with Live(save_dvc_exp=False, dvcyaml="dvc.yaml") as live: artifacts_dir = Path(live.artifacts_dir) # testing with symlink cache to make sure that DVC protected mode # does not prevent the overwrite @@ -118,7 +118,7 @@ def test_log_artifact_copy_overwrite(tmp_dir, dvc_repo): assert (artifacts_dir / "model.pth.dvc").exists() assert load_yaml(live.dvc_file) == { - "artifacts": {"model": {"path": "artifacts/model.pth", "type": "model"}} + "artifacts": {"model": {"path": "dvclive/artifacts/model.pth", "type": "model"}} } @@ -127,7 +127,7 @@ def test_log_artifact_copy_directory_overwrite(tmp_dir, dvc_repo): model_path.mkdir() (tmp_dir / "weights" / "model-epoch-1.pth").touch() - with Live(save_dvc_exp=False) as live: + with Live(save_dvc_exp=False, dvcyaml="dvc.yaml") as live: artifacts_dir = Path(live.artifacts_dir) # testing with symlink cache to make sure that DVC protected mode # does not prevent the overwrite @@ -148,50 +148,50 @@ def test_log_artifact_copy_directory_overwrite(tmp_dir, dvc_repo): assert len(list((artifacts_dir / "weights").iterdir())) == 2 assert load_yaml(live.dvc_file) == { - "artifacts": {"weights": {"path": "artifacts/weights", "type": "model"}} + "artifacts": {"weights": {"path": "dvclive/artifacts/weights", "type": "model"}} } def test_log_artifact_type_model_provided_name(tmp_dir, mocked_dvc_repo): (tmp_dir / "model.pth").touch() - with Live() as live: + with Live(dvcyaml="dvc.yaml") as live: live.log_artifact("model.pth", type="model", name="custom") assert load_yaml(live.dvc_file) == { - "artifacts": {"custom": {"path": "../model.pth", "type": "model"}} + "artifacts": {"custom": {"path": "model.pth", "type": "model"}} } def test_log_artifact_type_model_on_step_and_final(tmp_dir, mocked_dvc_repo): (tmp_dir / "model.pth").touch() - with Live() as live: + with Live(dvcyaml="dvc.yaml") as live: for _ in range(3): live.log_artifact("model.pth", type="model") live.next_step() live.log_artifact("model.pth", type="model", labels=["final"]) assert load_yaml(live.dvc_file) == { "artifacts": { - "model": {"path": "../model.pth", "type": "model", "labels": ["final"]}, + "model": {"path": "model.pth", "type": "model", "labels": ["final"]}, }, - "metrics": ["metrics.json"], + "metrics": ["dvclive/metrics.json"], } def test_log_artifact_type_model_on_step(tmp_dir, mocked_dvc_repo): (tmp_dir / "model.pth").touch() - with Live() as live: + with Live(dvcyaml="dvc.yaml") as live: for _ in range(3): live.log_artifact("model.pth", type="model") live.next_step() assert load_yaml(live.dvc_file) == { "artifacts": { - "model": {"path": "../model.pth", "type": "model"}, + "model": {"path": "model.pth", "type": "model"}, }, - "metrics": ["metrics.json"], + "metrics": ["dvclive/metrics.json"], } @@ -205,12 +205,12 @@ def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo): "labels": ["foo"], "meta": {"foo": "bar"}, } - with Live() as live: + with Live(dvcyaml="dvc.yaml") as live: live.log_artifact("model.pth", **attrs) attrs.pop("name") assert load_yaml(live.dvc_file) == { "artifacts": { - "foo": {"path": "../model.pth", **attrs}, + "foo": {"path": "model.pth", **attrs}, } } @@ -218,11 +218,11 @@ def test_log_artifact_attrs(tmp_dir, mocked_dvc_repo): def test_log_artifact_type_model_when_dvc_add_fails(tmp_dir, mocker, mocked_dvc_repo): (tmp_dir / "model.pth").touch() mocked_dvc_repo.add.side_effect = DvcException("foo") - with Live(save_dvc_exp=True) as live: + with Live(save_dvc_exp=True, dvcyaml="dvc.yaml") as live: live.log_artifact("model.pth", type="model") assert load_yaml(live.dvc_file) == { - "artifacts": {"model": {"path": "../model.pth", "type": "model"}} + "artifacts": {"model": {"path": "model.pth", "type": "model"}} } diff --git a/tests/test_main.py b/tests/test_main.py index c859740f..d076a5e7 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -6,7 +6,7 @@ from PIL import Image from dvclive import Live, env -from dvclive.error import InvalidParameterTypeError +from dvclive.error import InvalidDvcyamlError, InvalidParameterTypeError from dvclive.plots import Metric from dvclive.serialize import load_yaml from dvclive.utils import parse_metrics, parse_tsv @@ -158,8 +158,12 @@ def test_nested_logging(tmp_dir): "html", [True, False], ) -def test_cleanup(tmp_dir, html): - dvclive = Live("logs", report="html" if html else None) +@pytest.mark.parametrize( + "dvcyaml", + ["dvc.yaml", "logs/dvc.yaml"], +) +def test_cleanup(tmp_dir, html, dvcyaml): + dvclive = Live("logs", report="html" if html else None, dvcyaml=dvcyaml) dvclive.log_metric("m1", 1) dvclive.next_step() @@ -168,6 +172,7 @@ def test_cleanup(tmp_dir, html): html_path.touch() (tmp_dir / "logs" / "some_user_file.txt").touch() + (tmp_dir / "dvc.yaml").touch() assert (tmp_dir / dvclive.plots_dir / Metric.subfolder / "m1.tsv").is_file() assert (tmp_dir / dvclive.metrics_file).is_file() @@ -179,7 +184,10 @@ def test_cleanup(tmp_dir, html): assert (tmp_dir / "logs" / "some_user_file.txt").is_file() assert not (tmp_dir / dvclive.plots_dir / Metric.subfolder).exists() assert not (tmp_dir / dvclive.metrics_file).is_file() - assert not (tmp_dir / dvclive.dvc_file).is_file() + if dvcyaml == "dvc.yaml": + assert (tmp_dir / dvcyaml).is_file() + if dvcyaml == "logs/dvc.yaml": + assert not (tmp_dir / dvcyaml).is_file() assert not (html_path).is_file() @@ -402,16 +410,38 @@ def test_context_manager_skips_end_calls(tmp_dir): @pytest.mark.parametrize( "dvcyaml", - [True, False], + [True, False, "dvc.yaml"], ) -def test_make_dvcyaml(tmp_dir, dvcyaml): +def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml): dvclive = Live("logs", dvcyaml=dvcyaml) dvclive.log_metric("m1", 1) + dvclive.next_step() + + if dvcyaml: + assert "metrics" in load_yaml(dvclive.dvc_file) + else: + assert not os.path.exists(dvclive.dvc_file) + dvclive.make_dvcyaml() + assert "metrics" in load_yaml(dvclive.dvc_file) + + +def test_make_dvcyaml_no_repo(tmp_dir, mocker): + logger = mocker.patch("dvclive.live.logger") + dvclive = Live("logs") + dvclive.make_dvcyaml() + + assert not os.path.exists("dvc.yaml") + assert not dvclive.dvc_file + logger.warning.assert_any_call( + "Can't infer dvcyaml path without a DVC repo. " + "`dvc.yaml` file will not be written." + ) - dvcyaml_path = tmp_dir / dvclive.dir / "dvc.yaml" - assert dvcyaml_path.is_file() +def test_make_dvcyaml_invalid(tmp_dir, mocker): + with pytest.raises(InvalidDvcyamlError): + Live("logs", dvcyaml="invalid") def test_suppress_dvc_logs(tmp_dir, mocked_dvc_repo): diff --git a/tests/test_make_dvcyaml.py b/tests/test_make_dvcyaml.py new file mode 100644 index 00000000..465e0a24 --- /dev/null +++ b/tests/test_make_dvcyaml.py @@ -0,0 +1,418 @@ +import pytest +from PIL import Image + +from dvclive import Live +from dvclive.dvc import make_dvcyaml +from dvclive.serialize import dump_yaml, load_yaml + + +def test_make_dvcyaml_empty(tmp_dir): + live = Live(dvcyaml="dvc.yaml") + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == {} + + +def test_make_dvcyaml_param(tmp_dir): + live = Live(dvcyaml="dvc.yaml") + live.log_param("foo", 1) + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == { + "params": ["dvclive/params.yaml"], + } + + +def test_make_dvcyaml_metrics(tmp_dir): + live = Live(dvcyaml="dvc.yaml") + live.log_metric("bar", 2) + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == { + "metrics": ["dvclive/metrics.json"], + "plots": [{"dvclive/plots/metrics": {"x": "step"}}], + } + + +def test_make_dvcyaml_metrics_no_plots(tmp_dir): + live = Live(dvcyaml="dvc.yaml") + live.log_metric("bar", 2, plot=False) + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == { + "metrics": ["dvclive/metrics.json"], + } + + +def test_make_dvcyaml_summary(tmp_dir): + live = Live(dvcyaml="dvc.yaml") + live.summary["bar"] = 2 + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == { + "metrics": ["dvclive/metrics.json"], + } + + +def test_make_dvcyaml_all_plots(tmp_dir): + live = Live(dvcyaml="dvc.yaml") + live.log_param("foo", 1) + live.log_metric("bar", 2) + live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250))) + live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0]) + live.log_sklearn_plot( + "confusion_matrix", + [0, 0, 1, 1], + [0, 1, 1, 0], + name="confusion_matrix_normalized", + normalized=True, + ) + live.log_sklearn_plot("roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc") + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == { + "metrics": ["dvclive/metrics.json"], + "params": ["dvclive/params.yaml"], + "plots": [ + {"dvclive/plots/metrics": {"x": "step"}}, + "dvclive/plots/images", + { + "dvclive/plots/sklearn/confusion_matrix.json": { + "template": "confusion", + "x": "actual", + "y": "predicted", + "title": "Confusion Matrix", + "x_label": "True Label", + "y_label": "Predicted Label", + }, + }, + { + "dvclive/plots/sklearn/confusion_matrix_normalized.json": { + "template": "confusion_normalized", + "title": "Confusion Matrix", + "x": "actual", + "x_label": "True Label", + "y": "predicted", + "y_label": "Predicted Label", + } + }, + { + "dvclive/plots/sklearn/custom_name_roc.json": { + "template": "simple", + "x": "fpr", + "y": "tpr", + "title": "Receiver operating characteristic (ROC)", + "x_label": "False Positive Rate", + "y_label": "True Positive Rate", + } + }, + ], + } + + +def test_make_dvcyaml_relpath(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + live = Live(dvcyaml="dir/dvc.yaml") + live.log_metric("foo", 1) + live.log_artifact("model.pth", type="model") + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == { + "metrics": ["../dvclive/metrics.json"], + "plots": [{"../dvclive/plots/metrics": {"x": "step"}}], + "artifacts": { + "model": {"path": "../model.pth", "type": "model"}, + }, + } + + +@pytest.mark.parametrize( + ("orig_yaml", "updated_yaml"), + [ + pytest.param( + {"stages": {"train": {"cmd": "train.py"}}}, + { + "stages": {"train": {"cmd": "train.py"}}, + "metrics": ["dvclive/metrics.json"], + "plots": [ + {"dvclive/plots/metrics": {"x": "step"}}, + ], + }, + id="stages", + ), + pytest.param( + {"params": ["dvclive/params.yaml"]}, + { + "metrics": ["dvclive/metrics.json"], + "plots": [{"dvclive/plots/metrics": {"x": "step"}}], + }, + id="drop_extra_sections", + ), + pytest.param( + {"plots": ["dvclive/plots/images"]}, + { + "metrics": ["dvclive/metrics.json"], + "plots": [{"dvclive/plots/metrics": {"x": "step"}}], + }, + id="drop_unlogged_plots", + ), + pytest.param( + {"plots": [{"dvclive/plots/metrics": {"x": "step", "y": "foo"}}]}, + { + "metrics": ["dvclive/metrics.json"], + "plots": [{"dvclive/plots/metrics": {"x": "step"}}], + }, + id="plot_props", + ), + pytest.param( + { + "plots": [ + { + "custom": { + "x": "step", + "y": {"dvclive/plots/metrics": "foo"}, + "title": "custom", + } + }, + ], + }, + { + "metrics": ["dvclive/metrics.json"], + "plots": [ + { + "custom": { + "x": "step", + "y": {"dvclive/plots/metrics": "foo"}, + "title": "custom", + } + }, + {"dvclive/plots/metrics": {"x": "step"}}, + ], + }, + id="keep_custom_plots", + ), + ], +) +def test_make_dvcyaml_update(tmp_dir, orig_yaml, updated_yaml): + dump_yaml(orig_yaml, "dvc.yaml") + + live = Live(dvcyaml="dvc.yaml") + live.log_metric("foo", 2) + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == updated_yaml + + +@pytest.mark.parametrize( + ("orig_yaml", "updated_yaml"), + [ + pytest.param( + { + "artifacts": { + "model": { + "path": "model.pth", + "type": "model", + "desc": "best model", + }, + }, + }, + { + "artifacts": { + "model": {"path": "dvclive/artifacts/model.pth", "type": "model"}, + }, + }, + id="props", + ), + pytest.param( + { + "artifacts": { + "duplicate": {"path": "dvclive/artifacts/model.pth"}, + }, + }, + { + "artifacts": { + "model": {"path": "dvclive/artifacts/model.pth", "type": "model"}, + }, + }, + id="duplicate", + ), + pytest.param( + { + "artifacts": { + "data": {"path": "data.csv", "desc": "source data"}, + }, + }, + { + "artifacts": { + "model": {"path": "dvclive/artifacts/model.pth", "type": "model"}, + "data": {"path": "data.csv", "desc": "source data"}, + }, + }, + id="keep_extra", + ), + ], +) +def test_make_dvcyaml_update_artifact( + tmp_dir, mocked_dvc_repo, orig_yaml, updated_yaml +): + dump_yaml(orig_yaml, "dvc.yaml") + (tmp_dir / "model.pth").touch() + + live = Live(dvcyaml="dvc.yaml") + live.log_artifact("model.pth", type="model", copy=True) + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == updated_yaml + + +def test_make_dvcyaml_update_all(tmp_dir, mocked_dvc_repo): + orig_yaml = { + "stages": {"train": {"cmd": "train.py"}}, + "metrics": [ + "dvclive/metrics.json", + "dvclive/metrics.yaml", + "other/metrics.json", + ], + "params": ["dvclive/params.yaml"], + "plots": [ + {"dvclive/plots/metrics": {"x": "step", "y": "foo"}}, + "dvclive/plots/images", + "other/plots", + { + "custom": { + "x": "step", + "y": {"dvclive/plots/metrics": "foo"}, + "title": "custom", + } + }, + { + "dvclive/plots/sklearn/confusion_matrix.json": { + "template": "confusion", + "x": "actual", + "y": "predicted", + "title": "Confusion Matrix", + "x_label": "True Label", + "y_label": "Predicted Label", + }, + }, + ], + "artifacts": { + "model": {"path": "dvclive/artifacts/model.pth", "type": "model"}, + "duplicate": {"path": "dvclive/artifacts/model.pth"}, + "data": {"path": "data.csv", "desc": "source data"}, + "other": {"path": "other.pth"}, + }, + } + + updated_yaml = { + "stages": {"train": {"cmd": "train.py"}}, + "metrics": ["other/metrics.json", "dvclive/metrics.json"], + "plots": [ + "other/plots", + { + "custom": { + "x": "step", + "y": {"dvclive/plots/metrics": "foo"}, + "title": "custom", + } + }, + {"dvclive/plots/metrics": {"x": "step"}}, + "dvclive/plots/images", + ], + "artifacts": { + "model": {"path": "dvclive/artifacts/model.pth", "type": "model"}, + "data": {"path": "data.csv", "desc": "source data"}, + "other": {"path": "other.pth"}, + }, + } + + dump_yaml(orig_yaml, "dvc.yaml") + (tmp_dir / "model.pth").touch() + (tmp_dir / "data.csv").touch() + + live = Live(dvcyaml="dvc.yaml") + live.log_metric("foo", 2) + live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250))) + live.log_artifact("model.pth", type="model", copy=True) + live.log_artifact("data.csv", desc="source data") + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == updated_yaml + + +def test_make_dvcyaml_update_multiple(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + live = Live("train", dvcyaml="dvc.yaml") + live.log_metric("foo", 2) + live.log_artifact("model.pth", type="model", copy=True) + make_dvcyaml(live) + + live = Live("eval", dvcyaml="dvc.yaml") + live.log_metric("bar", 3) + make_dvcyaml(live) + + assert load_yaml(live.dvc_file) == { + "metrics": ["train/metrics.json", "eval/metrics.json"], + "plots": [ + {"train/plots/metrics": {"x": "step"}}, + {"eval/plots/metrics": {"x": "step"}}, + ], + "artifacts": { + "model": {"path": "train/artifacts/model.pth", "type": "model"}, + }, + } + + +@pytest.mark.parametrize("dvcyaml", [True, False]) +def test_dvcyaml_on_next_step(tmp_dir, dvcyaml, mocked_dvc_repo): + live = Live(dvcyaml=dvcyaml) + live.next_step() + if dvcyaml: + assert (tmp_dir / live.dvc_file).exists() + else: + assert not (tmp_dir / live.dvc_file).exists() + + +@pytest.mark.parametrize("dvcyaml", [True, False]) +def test_dvcyaml_on_end(tmp_dir, dvcyaml, mocked_dvc_repo): + live = Live(dvcyaml=dvcyaml) + live.end() + if dvcyaml: + assert (tmp_dir / live.dvc_file).exists() + else: + assert not (tmp_dir / live.dvc_file).exists() + + +def test_make_dvcyaml_idempotent(tmp_dir, mocked_dvc_repo): + (tmp_dir / "model.pth").touch() + + with Live() as live: + live.log_artifact("model.pth", type="model") + + live.make_dvcyaml() + + assert load_yaml(live.dvc_file) == { + "artifacts": { + "model": {"path": "model.pth", "type": "model"}, + } + } + + +@pytest.mark.parametrize("dvcyaml", [True, False, "dvclive/dvc.yaml"]) +def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyaml): + logger = mocker.patch("dvclive.live.logger") + dvc_stage = mocker.MagicMock() + dvc_stage.addressing = "train" + dvc_out = mocker.MagicMock() + dvc_out.fs_path = tmp_dir / "dvclive" + dvc_stage.outs = [dvc_out] + mocked_dvc_repo.index.stages = [dvc_stage] + live = Live(dvcyaml=dvcyaml) + + if dvcyaml == "dvclive/dvc.yaml": + msg = f"'{live.dvc_file}' is in outputs of stage 'train'.\n" + msg += "Remove it from outputs to make DVCLive work as expected." + logger.warning.assert_called_with(msg) + else: + logger.warning.assert_not_called() diff --git a/tests/test_studio.py b/tests/test_studio.py index 24a9dbdc..78df948f 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -36,7 +36,7 @@ def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): live = Live() live.log_param("fooparam", 1) - dvc_path = Path(live.dvc_file).as_posix() + dvc_path = Path(live.dvc_file).relative_to(mocked_dvc_repo.root_dir).as_posix() foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() mocked_post, _ = mocked_studio_post @@ -89,7 +89,7 @@ def test_post_to_studio_failed_data_request( live = Live() - dvc_path = Path(live.dvc_file).as_posix() + dvc_path = Path(live.dvc_file).relative_to(mocked_dvc_repo.root_dir).as_posix() foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() error_response = mocker.MagicMock() @@ -205,7 +205,7 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po live.log_metric("eval/loss", 1) live.next_step() - dvc_path = Path(live.dvc_file).as_posix() + dvc_path = Path(live.dvc_file).relative_to(mocked_dvc_repo.root_dir).as_posix() plots_path = Path(live.plots_dir) loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix() @@ -238,7 +238,7 @@ def test_post_to_studio_inside_dvc_exp( @pytest.mark.studio() def test_post_to_studio_inside_subdir( - tmp_dir, dvc_repo, mocker, monkeypatch, mocked_studio_post + tmp_dir, dvc_repo, mocker, monkeypatch, mocked_studio_post, mocked_dvc_repo ): mocked_post, _ = mocked_studio_post subdir = tmp_dir / "subdir" @@ -249,7 +249,7 @@ def test_post_to_studio_inside_subdir( live.log_metric("foo", 1) live.next_step() - dvc_path = Path(live.dvc_file).as_posix() + dvc_path = Path(live.dvc_file).relative_to(mocked_dvc_repo.root_dir).as_posix() foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() mocked_post.assert_called_with( @@ -260,9 +260,7 @@ def test_post_to_studio_inside_subdir( exp_name=live._exp_name, step=0, plots={ - f"subdir/{dvc_path}::subdir/{foo_path}": { - "data": [{"step": 0, "foo": 1.0}] - } + f"{dvc_path}::subdir/{foo_path}": {"data": [{"step": 0, "foo": 1.0}]} }, ), ) @@ -270,7 +268,7 @@ def test_post_to_studio_inside_subdir( @pytest.mark.studio() def test_post_to_studio_inside_subdir_dvc_exp( - tmp_dir, dvc_repo, monkeypatch, mocked_studio_post + tmp_dir, dvc_repo, monkeypatch, mocked_studio_post, mocked_dvc_repo ): mocked_post, _ = mocked_studio_post subdir = tmp_dir / "subdir" @@ -284,7 +282,7 @@ def test_post_to_studio_inside_subdir_dvc_exp( live.log_metric("foo", 1) live.next_step() - dvc_path = Path(live.dvc_file).as_posix() + dvc_path = Path(live.dvc_file).relative_to(mocked_dvc_repo.root_dir).as_posix() foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() mocked_post.assert_called_with( @@ -295,9 +293,7 @@ def test_post_to_studio_inside_subdir_dvc_exp( exp_name=live._exp_name, step=0, plots={ - f"subdir/{dvc_path}::subdir/{foo_path}": { - "data": [{"step": 0, "foo": 1.0}] - } + f"{dvc_path}::subdir/{foo_path}": {"data": [{"step": 0, "foo": 1.0}]} }, ), ) @@ -343,7 +339,7 @@ def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0))) live.next_step() - dvc_path = Path(live.dvc_file).as_posix() + dvc_path = Path(live.dvc_file).relative_to(mocked_dvc_repo.root_dir).as_posix() foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix() mocked_post.assert_called_with( diff --git a/tests/test_vscode.py b/tests/test_vscode.py index 70f40580..7f56ab09 100644 --- a/tests/test_vscode.py +++ b/tests/test_vscode.py @@ -91,23 +91,3 @@ def test_vscode_dvclive_only_signal_file(tmp_dir, dvc_root, mocker): dvclive.end() assert not os.path.exists(signal_file) - - -@pytest.mark.vscode() -@pytest.mark.parametrize("dvcyaml", [True, False]) -def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyaml): - logger = mocker.patch("dvclive.live.logger") - dvc_stage = mocker.MagicMock() - dvc_stage.addressing = "train" - dvc_out = mocker.MagicMock() - dvc_out.fs_path = tmp_dir / "dvclive" - dvc_stage.outs = [dvc_out] - mocked_dvc_repo.index.stages = [dvc_stage] - live = Live(dvcyaml=dvcyaml) - - if dvcyaml: - msg = f"'{live.dvc_file}' is in outputs of stage 'train'.\n" - msg += "Remove it from outputs to make DVCLive work as expected." - logger.warning.assert_called_with(msg) - else: - logger.warning.assert_not_called() From 5dcc43be9a90ce10f73227316859b98ffd1225ad Mon Sep 17 00:00:00 2001 From: daavoo Date: Thu, 7 Sep 2023 09:35:30 +0200 Subject: [PATCH 12/33] report: Drop "auto" logic. Fallback to `None` when conditions are not met for other types. --- src/dvclive/lightning.py | 2 +- src/dvclive/live.py | 23 ++++++++++++++--------- tests/test_report.py | 36 ++++++++++-------------------------- 3 files changed, 25 insertions(+), 36 deletions(-) diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index 100fbb6c..9f56c5b1 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -58,7 +58,7 @@ def __init__( # noqa: PLR0913 experiment=None, dir: Optional[str] = None, # noqa: A002 resume: bool = False, - report: Optional[str] = "auto", + report: Optional[str] = None, save_dvc_exp: bool = False, dvcyaml: bool = True, cache_images: bool = False, diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 679b5601..d4ac6029 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -247,12 +247,9 @@ def _init_studio(self): self._studio_events_to_skip.add("done") def _init_report(self): - if self._report_mode == "auto": - if env2bool("CI") and matplotlib_installed(): - self._report_mode = "md" - else: - self._report_mode = "html" - elif self._report_mode == "notebook": + if self._report_mode not in {None, "html", "notebook", "md"}: + raise InvalidReportModeError(self._report_mode) + if self._report_mode == "notebook": if inside_notebook(): from IPython.display import Markdown, display @@ -261,9 +258,17 @@ def _init_report(self): Markdown(BLANK_NOTEBOOK_REPORT), display_id=True ) else: - self._report_mode = "html" - elif self._report_mode not in {None, "html", "notebook", "md"}: - raise InvalidReportModeError(self._report_mode) + logger.warning( + "Report mode 'notebook' requires to be" + " inside a notebook. Disabling report." + ) + self._report_mode = None + if self._report_mode != "html" and not matplotlib_installed(): + logger.warning( + f"Report mode '{self._report_mode}' requires 'matplotlib'" + " to be installed. Disabling report." + ) + self._report_mode = None logger.debug(f"{self._report_mode=}") @property diff --git a/tests/test_report.py b/tests/test_report.py index 161e03d0..4ccbc136 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -1,6 +1,5 @@ import numpy as np import pytest -from IPython import display from PIL import Image from dvclive import Live @@ -71,22 +70,21 @@ def test_get_renderers(tmp_dir, mocker): def test_report_init(monkeypatch, mocker): - monkeypatch.setenv("CI", "false") - live = Live(report="auto") - assert live._report_mode == "html" + mocker.patch("dvclive.live.inside_notebook", return_value=False) + live = Live(report="notebook") + assert live._report_mode is None + + mocker.patch("dvclive.live.matplotlib_installed", return_value=False) + live = Live(report="md") + assert live._report_mode is None - monkeypatch.setenv("CI", "true") - live = Live(report="auto") + mocker.patch("dvclive.live.matplotlib_installed", return_value=True) + live = Live(report="md") assert live._report_mode == "md" - mocker.patch("dvclive.live.matplotlib_installed", return_value=False) - live = Live(report="auto") + live = Live(report="html") assert live._report_mode == "html" - for report in (None, "html", "md"): - live = Live(report=report) - assert live._report_mode == report - with pytest.raises(InvalidReportModeError, match="Got foo instead."): Live(report="foo") @@ -204,20 +202,6 @@ def test_get_plot_renderers_custom(tmp_dir): assert plot_renderer.properties == live._plots[name].plot_config -def test_report_auto_doesnt_set_notebook(tmp_dir, mocker): - mocker.patch("dvclive.live.inside_notebook", return_value=True) - live = Live() - assert live._report_mode != "notebook" - - -def test_report_notebook_fallsback_to_html(tmp_dir, mocker): - mocker.patch("dvclive.live.inside_notebook", return_value=False) - spy = mocker.spy(display, "display") - live = Live(report="notebook") - assert live._report_mode == "html" - assert not spy.called - - def test_report_notebook(tmp_dir, mocker): mocker.patch("dvclive.live.inside_notebook", return_value=True) mocked_display = mocker.MagicMock() From 7d2528e74421fb837d2d55982bedcb8bb8e3988a Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Thu, 7 Sep 2023 11:38:16 +0200 Subject: [PATCH 13/33] studio: Extract `post_to_studio` and decoulple from `make_report` (#705) --- src/dvclive/live.py | 73 +++++++------------------------------------ src/dvclive/studio.py | 44 +++++++++++++++++++++++++- tests/test_studio.py | 7 +++-- 3 files changed, 60 insertions(+), 64 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index d4ac6029..a8a60e6a 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional, Set, Union from dvc.exceptions import DvcException -from dvc_studio_client.post_live_metrics import post_live_metrics from funcy import set_in from ruamel.yaml.representer import RepresenterError @@ -30,7 +29,7 @@ from .plots import PLOT_TYPES, SKLEARN_PLOTS, CustomPlot, Image, Metric, NumpyEncoder from .report import BLANK_NOTEBOOK_REPORT, make_report from .serialize import dump_json, dump_yaml, load_yaml -from .studio import get_dvc_studio_config, get_studio_updates +from .studio import get_dvc_studio_config, post_to_studio from .utils import ( StrPath, catch_and_warn, @@ -229,22 +228,7 @@ def _init_studio(self): self._studio_events_to_skip.add("data") self._studio_events_to_skip.add("done") else: - response = post_live_metrics( - "start", - self._baseline_rev, - self._exp_name, - "dvclive", - dvc_studio_config=self._dvc_studio_config, - message=self._exp_message, - ) - if not response: - logger.debug( - "`studio` report `start` event failed. " - "`studio` report cancelled." - ) - self._studio_events_to_skip.add("start") - self._studio_events_to_skip.add("data") - self._studio_events_to_skip.add("done") + self.post_to_studio("start") def _init_report(self): if self._report_mode not in {None, "html", "notebook", "md"}: @@ -321,6 +305,9 @@ def next_step(self): self.make_dvcyaml() self.make_report() + + self.post_to_studio("data") + mark_dvclive_step_completed(self.step) self.step += 1 @@ -530,29 +517,6 @@ def make_summary(self, update_step: bool = True): dump_json(self.summary, self.metrics_file, cls=NumpyEncoder) def make_report(self): - if "data" not in self._studio_events_to_skip: - response = False - if post_live_metrics is not None: - metrics, params, plots = get_studio_updates(self) - response = post_live_metrics( - "data", - self._baseline_rev, - self._exp_name, - "dvclive", - step=self.step, - metrics=metrics, - params=params, - plots=plots, - dvc_studio_config=self._dvc_studio_config, - ) - if not response: - logger.warning( - "`post_to_studio` `data` event failed." - " Data will be resent on next call." - ) - else: - self._latest_studio_step = self.step - if self._report_mode is not None: make_report(self) if self._report_mode == "html" and env2bool(env.DVCLIVE_OPEN): @@ -563,6 +527,10 @@ def make_dvcyaml(self): if self.dvc_file: make_dvcyaml(self) + @catch_and_warn(DvcException, logger) + def post_to_studio(self, event): + post_to_studio(self, event) + def end(self): if self._inside_with: # Prevent `live.end` calls inside context manager @@ -581,28 +549,11 @@ def end(self): self.dir, self._dvc_repo ) + self.make_report() + self.save_dvc_exp() - if "done" not in self._studio_events_to_skip: - response = False - if post_live_metrics is not None: - kwargs = {} - if self._experiment_rev: - kwargs["experiment_rev"] = self._experiment_rev - response = post_live_metrics( - "done", - self._baseline_rev, - self._exp_name, - "dvclive", - dvc_studio_config=self._dvc_studio_config, - **kwargs, - ) - if not response: - logger.warning("`post_to_studio` `done` event failed.") - self._studio_events_to_skip.add("done") - self._studio_events_to_skip.add("data") - else: - self.make_report() + self.post_to_studio("done") cleanup_dvclive_step_completed() diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index fd6ab66a..29bb22fc 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -1,13 +1,17 @@ # ruff: noqa: SLF001 import base64 +import logging import math import os -from dvc_studio_client.post_live_metrics import get_studio_config +from dvc_studio_client.config import get_studio_config +from dvc_studio_client.post_live_metrics import post_live_metrics from dvclive.serialize import load_yaml from dvclive.utils import parse_metrics, rel_path +logger = logging.getLogger("dvclive") + def _get_unsent_datapoints(plot, latest_step): return [x for x in plot if int(x["step"]) > latest_step] @@ -92,3 +96,41 @@ def get_dvc_studio_config(live): if live._dvc_repo: config = live._dvc_repo.config.get("studio") return get_studio_config(dvc_studio_config=config) + + +def post_to_studio(live, event): + if event in live._studio_events_to_skip: + return + + kwargs = {} + if event == "start" and live._exp_message: + kwargs["message"] = live._exp_message + elif event == "data": + metrics, params, plots = get_studio_updates(live) + kwargs["step"] = live.step + kwargs["metrics"] = metrics + kwargs["params"] = params + kwargs["plots"] = plots + elif event == "done" and live._experiment_rev: + kwargs["experiment_rev"] = live._experiment_rev + + response = post_live_metrics( + event, + live._baseline_rev, + live._exp_name, + "dvclive", + dvc_studio_config=live._dvc_studio_config, + **kwargs, + ) + if not response: + logger.warning(f"`post_to_studio` `{event}` failed.") + if event == "start": + live._studio_events_to_skip.add("start") + live._studio_events_to_skip.add("data") + live._studio_events_to_skip.add("done") + elif event == "data": + live._latest_studio_step = live.step + + if event == "done": + live._studio_events_to_skip.add("done") + live._studio_events_to_skip.add("data") diff --git a/tests/test_studio.py b/tests/test_studio.py index 78df948f..747a9b42 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -146,7 +146,7 @@ def test_post_to_studio_end_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_po @pytest.mark.studio() -def test_post_to_studio_skip_on_env_var( +def test_post_to_studio_skip_start_and_done_on_env_var( tmp_dir, mocked_dvc_repo, mocked_studio_post, monkeypatch ): mocked_post, _ = mocked_studio_post @@ -156,6 +156,7 @@ def test_post_to_studio_skip_on_env_var( with Live() as live: live.log_metric("foo", 1) + live.next_step() assert mocked_post.call_count == 2 @@ -173,6 +174,7 @@ def test_post_to_studio_dvc_studio_config( with Live() as live: live.log_metric("foo", 1) + live.next_step() assert mocked_post.call_count == 2 @@ -184,7 +186,7 @@ def test_post_to_studio_skip_if_no_token( monkeypatch, mocked_dvc_repo, ): - mocked_post = mocker.patch("dvclive.live.post_live_metrics", return_value=None) + mocked_post = mocker.patch("dvclive.studio.post_live_metrics", return_value=None) monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) monkeypatch.setenv(DVC_EXP_NAME, "bar") @@ -232,6 +234,7 @@ def test_post_to_studio_inside_dvc_exp( with Live() as live: live.log_metric("foo", 1) + live.next_step() assert mocked_post.call_count == 2 From a02d16099b072fa86e0947cdc5cfe412864e5fbe Mon Sep 17 00:00:00 2001 From: daavoo Date: Thu, 7 Sep 2023 11:59:10 +0200 Subject: [PATCH 14/33] refactor(tests): Split `test_main` into separate files. Rename test_frameworks to frameworks. --- .github/workflows/tests.yml | 2 +- src/dvclive/utils.py | 18 + .../test_catalyst.py | 0 .../test_fastai.py | 0 .../test_huggingface.py | 0 .../test_keras.py | 0 .../test_lgbm.py | 0 .../test_lightning.py | 0 .../test_optuna.py | 0 .../test_xgboost.py | 0 tests/plots/test_image.py | 9 + tests/test_cleanup.py | 41 ++ tests/test_context_manager.py | 26 + tests/test_log_metric.py | 70 +++ tests/test_log_param.py | 88 ++++ tests/test_logging.py | 23 + tests/test_main.py | 458 ------------------ tests/test_make_dvcyaml.py | 39 ++ tests/{test_report.py => test_make_report.py} | 0 tests/test_make_summary.py | 44 ++ ...{test_studio.py => test_post_to_studio.py} | 0 tests/test_resume.py | 42 ++ tests/test_step.py | 88 ++++ 23 files changed, 489 insertions(+), 459 deletions(-) rename tests/{test_frameworks => frameworks}/test_catalyst.py (100%) rename tests/{test_frameworks => frameworks}/test_fastai.py (100%) rename tests/{test_frameworks => frameworks}/test_huggingface.py (100%) rename tests/{test_frameworks => frameworks}/test_keras.py (100%) rename tests/{test_frameworks => frameworks}/test_lgbm.py (100%) rename tests/{test_frameworks => frameworks}/test_lightning.py (100%) rename tests/{test_frameworks => frameworks}/test_optuna.py (100%) rename tests/{test_frameworks => frameworks}/test_xgboost.py (100%) create mode 100644 tests/test_cleanup.py create mode 100644 tests/test_context_manager.py create mode 100644 tests/test_log_param.py create mode 100644 tests/test_logging.py delete mode 100644 tests/test_main.py rename tests/{test_report.py => test_make_report.py} (100%) create mode 100644 tests/test_make_summary.py rename tests/{test_studio.py => test_post_to_studio.py} (100%) create mode 100644 tests/test_resume.py create mode 100644 tests/test_step.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a71c7b44..ef9faa35 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -92,4 +92,4 @@ jobs: pip install -e '.[tests]' - name: Run tests - run: pytest -v tests --ignore=tests/test_frameworks + run: pytest -v tests --ignore=tests/frameworks diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index a5cfd319..65f0266c 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -176,3 +176,21 @@ def wrapper(*args, **kwargs): def rel_path(path, dvc_root_path): absolute_path = Path(path).absolute() return str(Path(os.path.relpath(absolute_path, dvc_root_path)).as_posix()) + + +def read_history(live, metric): + from dvclive.plots.metric import Metric + + history, _ = parse_metrics(live) + steps = [] + values = [] + name = os.path.join(live.plots_dir, Metric.subfolder, f"{metric}.tsv") + for e in history[name]: + steps.append(int(e["step"])) + values.append(float(e[metric])) + return steps, values + + +def read_latest(live, metric_name): + _, latest = parse_metrics(live) + return latest["step"], latest[metric_name] diff --git a/tests/test_frameworks/test_catalyst.py b/tests/frameworks/test_catalyst.py similarity index 100% rename from tests/test_frameworks/test_catalyst.py rename to tests/frameworks/test_catalyst.py diff --git a/tests/test_frameworks/test_fastai.py b/tests/frameworks/test_fastai.py similarity index 100% rename from tests/test_frameworks/test_fastai.py rename to tests/frameworks/test_fastai.py diff --git a/tests/test_frameworks/test_huggingface.py b/tests/frameworks/test_huggingface.py similarity index 100% rename from tests/test_frameworks/test_huggingface.py rename to tests/frameworks/test_huggingface.py diff --git a/tests/test_frameworks/test_keras.py b/tests/frameworks/test_keras.py similarity index 100% rename from tests/test_frameworks/test_keras.py rename to tests/frameworks/test_keras.py diff --git a/tests/test_frameworks/test_lgbm.py b/tests/frameworks/test_lgbm.py similarity index 100% rename from tests/test_frameworks/test_lgbm.py rename to tests/frameworks/test_lgbm.py diff --git a/tests/test_frameworks/test_lightning.py b/tests/frameworks/test_lightning.py similarity index 100% rename from tests/test_frameworks/test_lightning.py rename to tests/frameworks/test_lightning.py diff --git a/tests/test_frameworks/test_optuna.py b/tests/frameworks/test_optuna.py similarity index 100% rename from tests/test_frameworks/test_optuna.py rename to tests/frameworks/test_optuna.py diff --git a/tests/test_frameworks/test_xgboost.py b/tests/frameworks/test_xgboost.py similarity index 100% rename from tests/test_frameworks/test_xgboost.py rename to tests/frameworks/test_xgboost.py diff --git a/tests/plots/test_image.py b/tests/plots/test_image.py index ed52efbc..f5621b67 100644 --- a/tests/plots/test_image.py +++ b/tests/plots/test_image.py @@ -115,3 +115,12 @@ def test_matplotlib(tmp_dir): assert not plt.fignum_exists(fig.number) assert (tmp_dir / live.plots_dir / LiveImage.subfolder / "image.png").exists() + + +@pytest.mark.parametrize("cache", [False, True]) +def test_cache_images(tmp_dir, dvc_repo, cache): + live = Live(save_dvc_exp=False, cache_images=cache) + img = Image.new("RGB", (10, 10), (250, 250, 250)) + live.log_image("image.png", img) + live.end() + assert (tmp_dir / "dvclive" / "plots" / "images.dvc").exists() == cache diff --git a/tests/test_cleanup.py b/tests/test_cleanup.py new file mode 100644 index 00000000..2730e9a1 --- /dev/null +++ b/tests/test_cleanup.py @@ -0,0 +1,41 @@ +import pytest + +from dvclive import Live +from dvclive.plots import Metric + + +@pytest.mark.parametrize( + "html", + [True, False], +) +@pytest.mark.parametrize( + "dvcyaml", + ["dvc.yaml", "logs/dvc.yaml"], +) +def test_cleanup(tmp_dir, html, dvcyaml): + dvclive = Live("logs", report="html" if html else None, dvcyaml=dvcyaml) + dvclive.log_metric("m1", 1) + dvclive.next_step() + + html_path = tmp_dir / dvclive.dir / "report.html" + if html: + html_path.touch() + + (tmp_dir / "logs" / "some_user_file.txt").touch() + (tmp_dir / "dvc.yaml").touch() + + assert (tmp_dir / dvclive.plots_dir / Metric.subfolder / "m1.tsv").is_file() + assert (tmp_dir / dvclive.metrics_file).is_file() + assert (tmp_dir / dvclive.dvc_file).is_file() + assert html_path.is_file() == html + + dvclive = Live("logs") + + assert (tmp_dir / "logs" / "some_user_file.txt").is_file() + assert not (tmp_dir / dvclive.plots_dir / Metric.subfolder).exists() + assert not (tmp_dir / dvclive.metrics_file).is_file() + if dvcyaml == "dvc.yaml": + assert (tmp_dir / dvcyaml).is_file() + if dvcyaml == "logs/dvc.yaml": + assert not (tmp_dir / dvcyaml).is_file() + assert not (html_path).is_file() diff --git a/tests/test_context_manager.py b/tests/test_context_manager.py new file mode 100644 index 00000000..23b3ed19 --- /dev/null +++ b/tests/test_context_manager.py @@ -0,0 +1,26 @@ +import json + +from dvclive import Live +from dvclive.plots import Metric + + +def test_context_manager(tmp_dir): + with Live(report="html") as live: + live.summary["foo"] = 1.0 + + assert json.loads((tmp_dir / live.metrics_file).read_text()) == { + # no `step` + "foo": 1.0 + } + log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv" + assert not log_file.exists() + report_file = tmp_dir / live.report_file + assert report_file.exists() + + +def test_context_manager_skips_end_calls(tmp_dir): + with Live() as live: + live.summary["foo"] = 1.0 + live.end() + assert not (tmp_dir / live.metrics_file).exists() + assert (tmp_dir / live.metrics_file).exists() diff --git a/tests/test_log_metric.py b/tests/test_log_metric.py index 3a7ff3ab..8d1b17fd 100644 --- a/tests/test_log_metric.py +++ b/tests/test_log_metric.py @@ -1,10 +1,80 @@ import math +import os import numpy as np import pytest from dvclive import Live from dvclive.error import InvalidDataTypeError +from dvclive.plots import Metric +from dvclive.serialize import load_yaml +from dvclive.utils import parse_metrics, parse_tsv + + +def test_logging_no_step(tmp_dir): + dvclive = Live("logs") + + dvclive.log_metric("m1", 1, plot=False) + dvclive.make_summary() + + assert not (tmp_dir / "logs" / "plots" / "metrics" / "m1.tsv").is_file() + assert (tmp_dir / dvclive.metrics_file).is_file() + + s = load_yaml(dvclive.metrics_file) + assert s["m1"] == 1 + assert "step" not in s + + +@pytest.mark.parametrize("path", ["logs", os.path.join("subdir", "logs")]) +def test_logging_step(tmp_dir, path): + dvclive = Live(path) + dvclive.log_metric("m1", 1) + dvclive.next_step() + assert (tmp_dir / dvclive.dir).is_dir() + assert (tmp_dir / dvclive.plots_dir / Metric.subfolder / "m1.tsv").is_file() + assert (tmp_dir / dvclive.metrics_file).is_file() + + s = load_yaml(dvclive.metrics_file) + assert s["m1"] == 1 + assert s["step"] == 0 + + +def test_nested_logging(tmp_dir): + dvclive = Live("logs") + + out = tmp_dir / dvclive.plots_dir / Metric.subfolder + + dvclive.log_metric("train/m1", 1) + dvclive.log_metric("val/val_1/m1", 1) + dvclive.log_metric("val/val_1/m2", 1) + + dvclive.next_step() + + assert (out / "val" / "val_1").is_dir() + assert (out / "train" / "m1.tsv").is_file() + assert (out / "val" / "val_1" / "m1.tsv").is_file() + assert (out / "val" / "val_1" / "m2.tsv").is_file() + + assert "m1" in parse_tsv(out / "train" / "m1.tsv")[0] + assert "m1" in parse_tsv(out / "val" / "val_1" / "m1.tsv")[0] + assert "m2" in parse_tsv(out / "val" / "val_1" / "m2.tsv")[0] + + summary = load_yaml(dvclive.metrics_file) + + assert summary["train"]["m1"] == 1 + assert summary["val"]["val_1"]["m1"] == 1 + assert summary["val"]["val_1"]["m2"] == 1 + + +@pytest.mark.parametrize("timestamp", [True, False]) +def test_log_metric_timestamp(tmp_dir, timestamp): + live = Live() + live.log_metric("foo", 1.0, timestamp=timestamp) + live.next_step() + + history, _ = parse_metrics(live) + logged = next(iter(history.values())) + assert ("timestamp" in logged[0]) == timestamp @pytest.mark.parametrize("invalid_type", [{0: 1}, [0, 1], (0, 1)]) diff --git a/tests/test_log_param.py b/tests/test_log_param.py new file mode 100644 index 00000000..bb008146 --- /dev/null +++ b/tests/test_log_param.py @@ -0,0 +1,88 @@ +import os + +import pytest + +from dvclive import Live +from dvclive.error import InvalidParameterTypeError +from dvclive.serialize import load_yaml + + +def test_cleanup_params(tmp_dir): + dvclive = Live("logs") + dvclive.log_param("param", 42) + + assert os.path.isfile(dvclive.params_file) + + dvclive = Live("logs") + assert not os.path.exists(dvclive.params_file) + + +@pytest.mark.parametrize( + ("param_name", "param_value"), + [ + ("param_string", "value"), + ("param_int", 42), + ("param_float", 42.0), + ("param_bool_true", True), + ("param_bool_false", False), + ("param_list", [1, 2, 3]), + ( + "param_dict_simple", + {"str": "value", "int": 42, "bool": True, "list": [1, 2, 3]}, + ), + ( + "param_dict_nested", + { + "str": "value", + "int": 42, + "bool": True, + "list": [1, 2, 3], + "dict": {"nested-str": "value", "nested-int": 42}, + }, + ), + ], +) +def test_log_param(tmp_dir, param_name, param_value): + dvclive = Live() + + dvclive.log_param(param_name, param_value) + + s = load_yaml(dvclive.params_file) + assert s[param_name] == param_value + + +def test_log_params(tmp_dir): + dvclive = Live() + params = { + "param_string": "value", + "param_int": 42, + "param_float": 42.0, + "param_bool_true": True, + "param_bool_false": False, + } + + dvclive.log_params(params) + + s = load_yaml(dvclive.params_file) + assert s == params + + +@pytest.mark.parametrize("resume", [False, True]) +def test_log_params_resume(tmp_dir, resume): + dvclive = Live(resume=resume) + dvclive.log_param("param", 42) + + dvclive = Live(resume=resume) + assert ("param" in dvclive._params) == resume + + +def test_log_param_custom_obj(tmp_dir): + dvclive = Live("logs") + + class Dummy: + val = 42 + + param_value = Dummy() + + with pytest.raises(InvalidParameterTypeError): + dvclive.log_param("param_complex", param_value) diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 00000000..b410c3dc --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,23 @@ +import logging + +from dvclive import Live + + +def test_logger(tmp_dir, mocker): + logger = mocker.patch("dvclive.live.logger") + + live = Live() + live.log_metric("foo", 0) + logger.debug.assert_called_with("Logged foo: 0") + live.next_step() + logger.debug.assert_called_with("Step: 1") + live.log_metric("foo", 1) + live.next_step() + + live = Live(resume=True) + logger.info.assert_called_with("Resuming from step 1") + + +def test_suppress_dvc_logs(tmp_dir, mocked_dvc_repo): + Live() + assert logging.getLogger("dvc").level == 30 diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index d076a5e7..00000000 --- a/tests/test_main.py +++ /dev/null @@ -1,458 +0,0 @@ -import json -import logging -import os - -import pytest -from PIL import Image - -from dvclive import Live, env -from dvclive.error import InvalidDvcyamlError, InvalidParameterTypeError -from dvclive.plots import Metric -from dvclive.serialize import load_yaml -from dvclive.utils import parse_metrics, parse_tsv - - -def read_history(live, metric): - history, _ = parse_metrics(live) - steps = [] - values = [] - name = os.path.join(live.plots_dir, Metric.subfolder, f"{metric}.tsv") - for e in history[name]: - steps.append(int(e["step"])) - values.append(float(e[metric])) - return steps, values - - -def read_latest(live, metric_name): - _, latest = parse_metrics(live) - return latest["step"], latest[metric_name] - - -def test_logging_no_step(tmp_dir): - dvclive = Live("logs") - - dvclive.log_metric("m1", 1, plot=False) - dvclive.make_summary() - - assert not (tmp_dir / "logs" / "plots" / "metrics" / "m1.tsv").is_file() - assert (tmp_dir / dvclive.metrics_file).is_file() - - s = load_yaml(dvclive.metrics_file) - assert s["m1"] == 1 - assert "step" not in s - - -@pytest.mark.parametrize( - ("param_name", "param_value"), - [ - ("param_string", "value"), - ("param_int", 42), - ("param_float", 42.0), - ("param_bool_true", True), - ("param_bool_false", False), - ("param_list", [1, 2, 3]), - ( - "param_dict_simple", - {"str": "value", "int": 42, "bool": True, "list": [1, 2, 3]}, - ), - ( - "param_dict_nested", - { - "str": "value", - "int": 42, - "bool": True, - "list": [1, 2, 3], - "dict": {"nested-str": "value", "nested-int": 42}, - }, - ), - ], -) -def test_log_param(tmp_dir, param_name, param_value): - dvclive = Live() - - dvclive.log_param(param_name, param_value) - - s = load_yaml(dvclive.params_file) - assert s[param_name] == param_value - - -def test_log_params(tmp_dir): - dvclive = Live() - params = { - "param_string": "value", - "param_int": 42, - "param_float": 42.0, - "param_bool_true": True, - "param_bool_false": False, - } - - dvclive.log_params(params) - - s = load_yaml(dvclive.params_file) - assert s == params - - -@pytest.mark.parametrize("resume", [False, True]) -def test_log_params_resume(tmp_dir, resume): - dvclive = Live(resume=resume) - dvclive.log_param("param", 42) - - dvclive = Live(resume=resume) - assert ("param" in dvclive._params) == resume - - -def test_log_param_custom_obj(tmp_dir): - dvclive = Live("logs") - - class Dummy: - val = 42 - - param_value = Dummy() - - with pytest.raises(InvalidParameterTypeError): - dvclive.log_param("param_complex", param_value) - - -@pytest.mark.parametrize("path", ["logs", os.path.join("subdir", "logs")]) -def test_logging_step(tmp_dir, path): - dvclive = Live(path) - dvclive.log_metric("m1", 1) - dvclive.next_step() - assert (tmp_dir / dvclive.dir).is_dir() - assert (tmp_dir / dvclive.plots_dir / Metric.subfolder / "m1.tsv").is_file() - assert (tmp_dir / dvclive.metrics_file).is_file() - - s = load_yaml(dvclive.metrics_file) - assert s["m1"] == 1 - assert s["step"] == 0 - - -def test_nested_logging(tmp_dir): - dvclive = Live("logs") - - out = tmp_dir / dvclive.plots_dir / Metric.subfolder - - dvclive.log_metric("train/m1", 1) - dvclive.log_metric("val/val_1/m1", 1) - dvclive.log_metric("val/val_1/m2", 1) - - dvclive.next_step() - - assert (out / "val" / "val_1").is_dir() - assert (out / "train" / "m1.tsv").is_file() - assert (out / "val" / "val_1" / "m1.tsv").is_file() - assert (out / "val" / "val_1" / "m2.tsv").is_file() - - assert "m1" in parse_tsv(out / "train" / "m1.tsv")[0] - assert "m1" in parse_tsv(out / "val" / "val_1" / "m1.tsv")[0] - assert "m2" in parse_tsv(out / "val" / "val_1" / "m2.tsv")[0] - - summary = load_yaml(dvclive.metrics_file) - - assert summary["train"]["m1"] == 1 - assert summary["val"]["val_1"]["m1"] == 1 - assert summary["val"]["val_1"]["m2"] == 1 - - -@pytest.mark.parametrize( - "html", - [True, False], -) -@pytest.mark.parametrize( - "dvcyaml", - ["dvc.yaml", "logs/dvc.yaml"], -) -def test_cleanup(tmp_dir, html, dvcyaml): - dvclive = Live("logs", report="html" if html else None, dvcyaml=dvcyaml) - dvclive.log_metric("m1", 1) - dvclive.next_step() - - html_path = tmp_dir / dvclive.dir / "report.html" - if html: - html_path.touch() - - (tmp_dir / "logs" / "some_user_file.txt").touch() - (tmp_dir / "dvc.yaml").touch() - - assert (tmp_dir / dvclive.plots_dir / Metric.subfolder / "m1.tsv").is_file() - assert (tmp_dir / dvclive.metrics_file).is_file() - assert (tmp_dir / dvclive.dvc_file).is_file() - assert html_path.is_file() == html - - dvclive = Live("logs") - - assert (tmp_dir / "logs" / "some_user_file.txt").is_file() - assert not (tmp_dir / dvclive.plots_dir / Metric.subfolder).exists() - assert not (tmp_dir / dvclive.metrics_file).is_file() - if dvcyaml == "dvc.yaml": - assert (tmp_dir / dvcyaml).is_file() - if dvcyaml == "logs/dvc.yaml": - assert not (tmp_dir / dvcyaml).is_file() - assert not (html_path).is_file() - - -def test_cleanup_params(tmp_dir): - dvclive = Live("logs") - dvclive.log_param("param", 42) - - assert os.path.isfile(dvclive.params_file) - - dvclive = Live("logs") - assert not os.path.exists(dvclive.params_file) - - -@pytest.mark.parametrize( - ("resume", "steps", "metrics"), - [(True, [0, 1, 2, 3], [0.9, 0.8, 0.7, 0.6]), (False, [0, 1], [0.7, 0.6])], -) -def test_continue(tmp_dir, resume, steps, metrics): - dvclive = Live("logs") - - for metric in [0.9, 0.8]: - dvclive.log_metric("metric", metric) - dvclive.next_step() - - assert read_history(dvclive, "metric") == ([0, 1], [0.9, 0.8]) - assert read_latest(dvclive, "metric") == (1, 0.8) - - dvclive = Live("logs", resume=resume) - - for new_metric in [0.7, 0.6]: - dvclive.log_metric("metric", new_metric) - dvclive.next_step() - - assert read_history(dvclive, "metric") == (steps, metrics) - assert read_latest(dvclive, "metric") == (steps[-1], metrics[-1]) - - -def test_resume_on_first_init(tmp_dir): - dvclive = Live(resume=True) - - assert dvclive._step == 0 - - -def test_resume_env_var(tmp_dir, monkeypatch): - assert not Live()._resume - - monkeypatch.setenv(env.DVCLIVE_RESUME, "true") - assert Live()._resume - - -@pytest.mark.parametrize("metric", ["m1", os.path.join("train", "m1")]) -def test_allow_step_override(tmp_dir, metric): - dvclive = Live("logs") - - dvclive.log_metric(metric, 1.0) - dvclive.log_metric(metric, 2.0) - - -def test_custom_steps(tmp_dir): - dvclive = Live("logs") - - steps = [0, 62, 1000] - metrics = [0.9, 0.8, 0.7] - - for step, metric in zip(steps, metrics): - dvclive.step = step - dvclive.log_metric("m", metric) - dvclive.make_summary() - - assert read_history(dvclive, "m") == (steps, metrics) - assert read_latest(dvclive, "m") == (steps[-1], metrics[-1]) - - -def test_log_reset_with_set_step(tmp_dir): - dvclive = Live() - - for i in range(3): - dvclive.step = i - dvclive.log_metric("train_m", 1) - dvclive.make_summary() - - for i in range(3): - dvclive.step = i - dvclive.log_metric("val_m", 1) - dvclive.make_summary() - - assert read_history(dvclive, "train_m") == ([0, 1, 2], [1, 1, 1]) - assert read_history(dvclive, "val_m") == ([0, 1, 2], [1, 1, 1]) - assert read_latest(dvclive, "train_m") == (2, 1) - assert read_latest(dvclive, "val_m") == (2, 1) - - -def test_get_step_resume(tmp_dir): - dvclive = Live() - - for metric in [0.9, 0.8]: - dvclive.log_metric("metric", metric) - dvclive.next_step() - - assert dvclive.step == 2 - - dvclive = Live(resume=True) - assert dvclive.step == 2 - - dvclive = Live(resume=False) - assert dvclive.step == 0 - - -def test_get_step_custom_steps(tmp_dir): - dvclive = Live() - - steps = [0, 62, 1000] - metrics = [0.9, 0.8, 0.7] - - for step, metric in zip(steps, metrics): - dvclive.step = step - dvclive.log_metric("x", metric) - assert dvclive.step == step - - -def test_get_step_control_flow(tmp_dir): - dvclive = Live() - - while dvclive.step < 10: - dvclive.log_metric("i", dvclive.step) - dvclive.next_step() - - steps, values = read_history(dvclive, "i") - assert steps == list(range(10)) - assert values == [float(x) for x in range(10)] - - -def test_logger(tmp_dir, mocker, monkeypatch): - logger = mocker.patch("dvclive.live.logger") - - live = Live() - live.log_metric("foo", 0) - logger.debug.assert_called_with("Logged foo: 0") - live.next_step() - logger.debug.assert_called_with("Step: 1") - live.log_metric("foo", 1) - live.next_step() - - live = Live(resume=True) - logger.info.assert_called_with("Resuming from step 1") - - -def test_make_summary_without_calling_log(tmp_dir): - dvclive = Live() - - dvclive.summary["foo"] = 1.0 - dvclive.make_summary() - - assert json.loads((tmp_dir / dvclive.metrics_file).read_text()) == { - # no `step` - "foo": 1.0 - } - log_file = tmp_dir / dvclive.plots_dir / Metric.subfolder / "foo.tsv" - assert not log_file.exists() - - -@pytest.mark.parametrize("timestamp", [True, False]) -def test_log_metric_timestamp(tmp_dir, timestamp): - live = Live() - live.log_metric("foo", 1.0, timestamp=timestamp) - live.next_step() - - history, _ = parse_metrics(live) - logged = next(iter(history.values())) - assert ("timestamp" in logged[0]) == timestamp - - -def test_make_summary_is_called_on_end(tmp_dir): - live = Live() - - live.summary["foo"] = 1.0 - live.end() - - assert json.loads((tmp_dir / live.metrics_file).read_text()) == { - # no `step` - "foo": 1.0 - } - log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv" - assert not log_file.exists() - - -def test_make_summary_on_end_dont_increment_step(tmp_dir): - with Live() as live: - for i in range(2): - live.log_metric("foo", i) - live.next_step() - - assert json.loads((tmp_dir / live.metrics_file).read_text()) == { - "foo": 1.0, - "step": 1, - } - - -def test_context_manager(tmp_dir): - with Live(report="html") as live: - live.summary["foo"] = 1.0 - - assert json.loads((tmp_dir / live.metrics_file).read_text()) == { - # no `step` - "foo": 1.0 - } - log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv" - assert not log_file.exists() - report_file = tmp_dir / live.report_file - assert report_file.exists() - - -def test_context_manager_skips_end_calls(tmp_dir): - with Live() as live: - live.summary["foo"] = 1.0 - live.end() - assert not (tmp_dir / live.metrics_file).exists() - assert (tmp_dir / live.metrics_file).exists() - - -@pytest.mark.parametrize( - "dvcyaml", - [True, False, "dvc.yaml"], -) -def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml): - dvclive = Live("logs", dvcyaml=dvcyaml) - dvclive.log_metric("m1", 1) - dvclive.next_step() - - if dvcyaml: - assert "metrics" in load_yaml(dvclive.dvc_file) - else: - assert not os.path.exists(dvclive.dvc_file) - - dvclive.make_dvcyaml() - assert "metrics" in load_yaml(dvclive.dvc_file) - - -def test_make_dvcyaml_no_repo(tmp_dir, mocker): - logger = mocker.patch("dvclive.live.logger") - dvclive = Live("logs") - dvclive.make_dvcyaml() - - assert not os.path.exists("dvc.yaml") - assert not dvclive.dvc_file - logger.warning.assert_any_call( - "Can't infer dvcyaml path without a DVC repo. " - "`dvc.yaml` file will not be written." - ) - - -def test_make_dvcyaml_invalid(tmp_dir, mocker): - with pytest.raises(InvalidDvcyamlError): - Live("logs", dvcyaml="invalid") - - -def test_suppress_dvc_logs(tmp_dir, mocked_dvc_repo): - Live() - assert logging.getLogger("dvc").level == 30 - - -@pytest.mark.parametrize("cache", [False, True]) -def test_cache_images(tmp_dir, dvc_repo, cache): - live = Live(save_dvc_exp=False, cache_images=cache) - img = Image.new("RGB", (10, 10), (250, 250, 250)) - live.log_image("image.png", img) - live.end() - assert (tmp_dir / "dvclive" / "plots" / "images.dvc").exists() == cache diff --git a/tests/test_make_dvcyaml.py b/tests/test_make_dvcyaml.py index 465e0a24..428fd820 100644 --- a/tests/test_make_dvcyaml.py +++ b/tests/test_make_dvcyaml.py @@ -1,8 +1,11 @@ +import os + import pytest from PIL import Image from dvclive import Live from dvclive.dvc import make_dvcyaml +from dvclive.error import InvalidDvcyamlError from dvclive.serialize import dump_yaml, load_yaml @@ -416,3 +419,39 @@ def test_warn_on_dvcyaml_output_overlap(tmp_dir, mocker, mocked_dvc_repo, dvcyam logger.warning.assert_called_with(msg) else: logger.warning.assert_not_called() + + +@pytest.mark.parametrize( + "dvcyaml", + [True, False, "dvc.yaml"], +) +def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml): + dvclive = Live("logs", dvcyaml=dvcyaml) + dvclive.log_metric("m1", 1) + dvclive.next_step() + + if dvcyaml: + assert "metrics" in load_yaml(dvclive.dvc_file) + else: + assert not os.path.exists(dvclive.dvc_file) + + dvclive.make_dvcyaml() + assert "metrics" in load_yaml(dvclive.dvc_file) + + +def test_make_dvcyaml_no_repo(tmp_dir, mocker): + logger = mocker.patch("dvclive.live.logger") + dvclive = Live("logs") + dvclive.make_dvcyaml() + + assert not os.path.exists("dvc.yaml") + assert not dvclive.dvc_file + logger.warning.assert_any_call( + "Can't infer dvcyaml path without a DVC repo. " + "`dvc.yaml` file will not be written." + ) + + +def test_make_dvcyaml_invalid(tmp_dir, mocker): + with pytest.raises(InvalidDvcyamlError): + Live("logs", dvcyaml="invalid") diff --git a/tests/test_report.py b/tests/test_make_report.py similarity index 100% rename from tests/test_report.py rename to tests/test_make_report.py diff --git a/tests/test_make_summary.py b/tests/test_make_summary.py new file mode 100644 index 00000000..6e671b1e --- /dev/null +++ b/tests/test_make_summary.py @@ -0,0 +1,44 @@ +import json + +from dvclive import Live +from dvclive.plots import Metric + + +def test_make_summary_without_calling_log(tmp_dir): + dvclive = Live() + + dvclive.summary["foo"] = 1.0 + dvclive.make_summary() + + assert json.loads((tmp_dir / dvclive.metrics_file).read_text()) == { + # no `step` + "foo": 1.0 + } + log_file = tmp_dir / dvclive.plots_dir / Metric.subfolder / "foo.tsv" + assert not log_file.exists() + + +def test_make_summary_is_called_on_end(tmp_dir): + live = Live() + + live.summary["foo"] = 1.0 + live.end() + + assert json.loads((tmp_dir / live.metrics_file).read_text()) == { + # no `step` + "foo": 1.0 + } + log_file = tmp_dir / live.plots_dir / Metric.subfolder / "foo.tsv" + assert not log_file.exists() + + +def test_make_summary_on_end_dont_increment_step(tmp_dir): + with Live() as live: + for i in range(2): + live.log_metric("foo", i) + live.next_step() + + assert json.loads((tmp_dir / live.metrics_file).read_text()) == { + "foo": 1.0, + "step": 1, + } diff --git a/tests/test_studio.py b/tests/test_post_to_studio.py similarity index 100% rename from tests/test_studio.py rename to tests/test_post_to_studio.py diff --git a/tests/test_resume.py b/tests/test_resume.py new file mode 100644 index 00000000..9ffdea1d --- /dev/null +++ b/tests/test_resume.py @@ -0,0 +1,42 @@ +import pytest + +from dvclive import Live +from dvclive.env import DVCLIVE_RESUME +from dvclive.utils import read_history, read_latest + + +@pytest.mark.parametrize( + ("resume", "steps", "metrics"), + [(True, [0, 1, 2, 3], [0.9, 0.8, 0.7, 0.6]), (False, [0, 1], [0.7, 0.6])], +) +def test_resume(tmp_dir, resume, steps, metrics): + dvclive = Live("logs") + + for metric in [0.9, 0.8]: + dvclive.log_metric("metric", metric) + dvclive.next_step() + + assert read_history(dvclive, "metric") == ([0, 1], [0.9, 0.8]) + assert read_latest(dvclive, "metric") == (1, 0.8) + + dvclive = Live("logs", resume=resume) + + for new_metric in [0.7, 0.6]: + dvclive.log_metric("metric", new_metric) + dvclive.next_step() + + assert read_history(dvclive, "metric") == (steps, metrics) + assert read_latest(dvclive, "metric") == (steps[-1], metrics[-1]) + + +def test_resume_on_first_init(tmp_dir): + dvclive = Live(resume=True) + + assert dvclive._step == 0 + + +def test_resume_env_var(tmp_dir, monkeypatch): + assert not Live()._resume + + monkeypatch.setenv(DVCLIVE_RESUME, "true") + assert Live()._resume diff --git a/tests/test_step.py b/tests/test_step.py new file mode 100644 index 00000000..15497758 --- /dev/null +++ b/tests/test_step.py @@ -0,0 +1,88 @@ +import os + +import pytest + +from dvclive import Live +from dvclive.utils import read_history, read_latest + + +@pytest.mark.parametrize("metric", ["m1", os.path.join("train", "m1")]) +def test_allow_step_override(tmp_dir, metric): + dvclive = Live("logs") + + dvclive.log_metric(metric, 1.0) + dvclive.log_metric(metric, 2.0) + + +def test_custom_steps(tmp_dir): + dvclive = Live("logs") + + steps = [0, 62, 1000] + metrics = [0.9, 0.8, 0.7] + + for step, metric in zip(steps, metrics): + dvclive.step = step + dvclive.log_metric("m", metric) + dvclive.make_summary() + + assert read_history(dvclive, "m") == (steps, metrics) + assert read_latest(dvclive, "m") == (steps[-1], metrics[-1]) + + +def test_log_reset_with_set_step(tmp_dir): + dvclive = Live() + + for i in range(3): + dvclive.step = i + dvclive.log_metric("train_m", 1) + dvclive.make_summary() + + for i in range(3): + dvclive.step = i + dvclive.log_metric("val_m", 1) + dvclive.make_summary() + + assert read_history(dvclive, "train_m") == ([0, 1, 2], [1, 1, 1]) + assert read_history(dvclive, "val_m") == ([0, 1, 2], [1, 1, 1]) + assert read_latest(dvclive, "train_m") == (2, 1) + assert read_latest(dvclive, "val_m") == (2, 1) + + +def test_get_step_resume(tmp_dir): + dvclive = Live() + + for metric in [0.9, 0.8]: + dvclive.log_metric("metric", metric) + dvclive.next_step() + + assert dvclive.step == 2 + + dvclive = Live(resume=True) + assert dvclive.step == 2 + + dvclive = Live(resume=False) + assert dvclive.step == 0 + + +def test_get_step_custom_steps(tmp_dir): + dvclive = Live() + + steps = [0, 62, 1000] + metrics = [0.9, 0.8, 0.7] + + for step, metric in zip(steps, metrics): + dvclive.step = step + dvclive.log_metric("x", metric) + assert dvclive.step == step + + +def test_get_step_control_flow(tmp_dir): + dvclive = Live() + + while dvclive.step < 10: + dvclive.log_metric("i", dvclive.step) + dvclive.next_step() + + steps, values = read_history(dvclive, "i") + assert steps == list(range(10)) + assert values == [float(x) for x in range(10)] From 6ccc95978188902085c2d29fd905429f67f68837 Mon Sep 17 00:00:00 2001 From: daavoo Date: Thu, 7 Sep 2023 19:58:02 +0200 Subject: [PATCH 15/33] fix matplotlib warning --- src/dvclive/live.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index a8a60e6a..656499a5 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -247,7 +247,7 @@ def _init_report(self): " inside a notebook. Disabling report." ) self._report_mode = None - if self._report_mode != "html" and not matplotlib_installed(): + if self._report_mode in ("notebook", "md") and not matplotlib_installed(): logger.warning( f"Report mode '{self._report_mode}' requires 'matplotlib'" " to be installed. Disabling report." From f3ebcd01199ce350815b3e747c1a4e7db5587b0a Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 11 Sep 2023 16:40:40 -0400 Subject: [PATCH 16/33] fix studio tests --- tests/test_post_to_studio.py | 58 +++- tests/test_studio.py | 587 ----------------------------------- 2 files changed, 51 insertions(+), 594 deletions(-) delete mode 100644 tests/test_studio.py diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 747a9b42..dc5c60bd 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -224,10 +224,9 @@ def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_po @pytest.mark.studio() def test_post_to_studio_inside_dvc_exp( - tmp_dir, mocker, monkeypatch, mocked_studio_post + tmp_dir, mocker, monkeypatch, mocked_studio_post, mocked_dvc_repo ): mocked_post, _ = mocked_studio_post - mocker.patch("dvclive.live.get_dvc_repo", return_value=None) monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) monkeypatch.setenv(DVC_EXP_NAME, "bar") @@ -302,11 +301,6 @@ def test_post_to_studio_inside_subdir_dvc_exp( ) -def test_post_to_studio_requires_exp(tmp_dir, mocked_dvc_repo, mocked_studio_post): - assert Live(save_dvc_exp=False)._studio_events_to_skip == {"start", "data", "done"} - assert not Live()._studio_events_to_skip - - def test_get_dvc_studio_config_none(mocker): mocker.patch("dvclive.live.get_dvc_repo", return_value=None) live = Live() @@ -366,3 +360,53 @@ def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): "https://0.0.0.0/api/live", **get_studio_call("start", exp_name=live._exp_name, message="Custom message"), ) + + +def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): + monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN") + monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL") + monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) + monkeypatch.setenv(DVC_EXP_NAME, "bar") + + live = Live(save_dvc_exp=True) + live.log_param("fooparam", 1) + + foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() + + mocked_post, _ = mocked_studio_post + + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", **get_studio_call("start", exp_name=live._exp_name) + ) + + live.log_metric("foo", 1) + + live.next_step() + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + **get_studio_call( + "data", + exp_name=live._exp_name, + step=0, + plots={f"{foo_path}": {"data": [{"step": 0, "foo": 1.0}]}}, + ), + ) + + live.log_metric("foo", 2) + + live.next_step() + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + **get_studio_call( + "data", + exp_name=live._exp_name, + step=1, + plots={f"{foo_path}": {"data": [{"step": 1, "foo": 2.0}]}}, + ), + ) + + live.end() + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + **get_studio_call("done", exp_name=live._exp_name), + ) diff --git a/tests/test_studio.py b/tests/test_studio.py deleted file mode 100644 index d263a0ec..00000000 --- a/tests/test_studio.py +++ /dev/null @@ -1,587 +0,0 @@ -from pathlib import Path - -import pytest -from dvc_studio_client.env import DVC_STUDIO_REPO_URL, DVC_STUDIO_TOKEN -from dvc_studio_client.post_live_metrics import STUDIO_URL -from PIL import Image as ImagePIL - -from dvclive import Live -from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME -from dvclive.plots import Image, Metric -from dvclive.studio import _adapt_image, get_dvc_studio_config - - -def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): - live = Live(save_dvc_exp=True) - live.log_param("fooparam", 1) - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - params_path = Path(live.params_file).as_posix() - foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() - - mocked_post, _ = mocked_studio_post - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "start", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "name": live._exp_name, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - live.log_metric("foo", 1) - - live.next_step() - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "name": live._exp_name, - "step": 0, - "metrics": {metrics_path: {"data": {"step": 0, "foo": 1}}}, - "params": {params_path: {"fooparam": 1}}, - "plots": {f"{dvc_path}::{foo_path}": {"data": [{"step": 0, "foo": 1.0}]}}, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - live.log_metric("foo", 2) - - live.next_step() - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "name": live._exp_name, - "step": 1, - "metrics": {metrics_path: {"data": {"step": 1, "foo": 2}}}, - "params": {params_path: {"fooparam": 1}}, - "plots": {f"{dvc_path}::{foo_path}": {"data": [{"step": 1, "foo": 2.0}]}}, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - live.end() - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "done", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "experiment_rev": live._experiment_rev, - "name": live._exp_name, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -def test_post_to_studio_failed_data_request( - tmp_dir, mocker, mocked_dvc_repo, mocked_studio_post -): - mocked_post, valid_response = mocked_studio_post - - live = Live(save_dvc_exp=True) - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() - - error_response = mocker.MagicMock() - error_response.status_code = 400 - mocker.patch("requests.post", return_value=error_response) - live.log_metric("foo", 1) - live.next_step() - - mocked_post = mocker.patch("requests.post", return_value=valid_response) - live.log_metric("foo", 2) - live.next_step() - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "name": live._exp_name, - "step": 1, - "metrics": {metrics_path: {"data": {"step": 1, "foo": 2}}}, - "plots": { - f"{dvc_path}::{foo_path}": { - "data": [ - {"step": 0, "foo": 1.0}, - {"step": 1, "foo": 2.0}, - ] - } - }, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -def test_post_to_studio_failed_start_request( - tmp_dir, mocker, mocked_dvc_repo, mocked_studio_post -): - mocked_response = mocker.MagicMock() - mocked_response.status_code = 400 - mocked_post = mocker.patch("requests.post", return_value=mocked_response) - - live = Live(save_dvc_exp=True) - - live.log_metric("foo", 1) - live.next_step() - - live.log_metric("foo", 2) - live.next_step() - - assert mocked_post.call_count == 1 - - -def test_post_to_studio_end_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post): - mocked_post, _ = mocked_studio_post - with Live(save_dvc_exp=True) as live: - live.log_metric("foo", 1) - live.next_step() - - assert mocked_post.call_count == 3 - live.end() - assert mocked_post.call_count == 3 - - -@pytest.mark.studio() -def test_post_to_studio_skip_on_env_var( - tmp_dir, mocked_dvc_repo, mocked_studio_post, monkeypatch -): - mocked_post, _ = mocked_studio_post - - monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") - - with Live() as live: - live.log_metric("foo", 1) - - assert mocked_post.call_count == 1 - - -@pytest.mark.studio() -def test_post_to_studio_dvc_studio_config( - tmp_dir, mocker, mocked_dvc_repo, mocked_studio_post, monkeypatch -): - mocked_post, _ = mocked_studio_post - - monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") - - mocked_dvc_repo.config = {"studio": {"token": "token"}} - - with Live() as live: - live.log_metric("foo", 1) - - assert mocked_post.call_count == 1 - - -@pytest.mark.studio() -def test_post_to_studio_skip_if_no_token( - tmp_dir, - mocker, - monkeypatch, - mocked_dvc_repo, -): - mocked_post = mocker.patch("dvclive.live.post_live_metrics", return_value=None) - - monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") - - mocked_dvc_repo.config = {} - - with Live() as live: - live.log_metric("foo", 1) - live.next_step() - - assert mocked_post.call_count == 0 - - -def test_post_to_studio_include_prefix_if_needed( - tmp_dir, mocked_dvc_repo, mocked_studio_post -): - mocked_post, _ = mocked_studio_post - # Create dvclive/dvc.yaml - live = Live("custom_dir", save_dvc_exp=True) - live.log_metric("foo", 1) - live.next_step() - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "name": live._exp_name, - "step": 0, - "metrics": {metrics_path: {"data": {"step": 0, "foo": 1}}}, - "plots": {f"{dvc_path}::{foo_path}": {"data": [{"step": 0, "foo": 1.0}]}}, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -def test_post_to_studio_shorten_names(tmp_dir, mocked_dvc_repo, mocked_studio_post): - mocked_post, _ = mocked_studio_post - - live = Live(save_dvc_exp=True) - live.log_metric("eval/loss", 1) - live.next_step() - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - plots_path = Path(live.plots_dir) - loss_path = (plots_path / Metric.subfolder / "eval/loss.tsv").as_posix() - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "name": live._exp_name, - "step": 0, - "metrics": {metrics_path: {"data": {"step": 0, "eval": {"loss": 1}}}}, - "plots": {f"{dvc_path}::{loss_path}": {"data": [{"step": 0, "loss": 1.0}]}}, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -@pytest.mark.studio() -def test_post_to_studio_inside_dvc_exp( - tmp_dir, mocker, monkeypatch, mocked_studio_post, mocked_dvc_repo -): - mocked_post, _ = mocked_studio_post - - monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") - - with Live() as live: - live.log_metric("foo", 1) - - assert mocked_post.call_count == 1 - - -@pytest.mark.studio() -def test_post_to_studio_inside_subdir( - tmp_dir, dvc_repo, mocker, monkeypatch, mocked_studio_post -): - mocked_post, _ = mocked_studio_post - subdir = tmp_dir / "subdir" - subdir.mkdir() - monkeypatch.chdir(subdir) - - live = Live(save_dvc_exp=True) - live.log_metric("foo", 1) - live.next_step() - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": live._baseline_rev, - "name": live._exp_name, - "step": 0, - "metrics": {f"subdir/{metrics_path}": {"data": {"step": 0, "foo": 1}}}, - "plots": { - f"subdir/{dvc_path}::subdir/{foo_path}": { - "data": [{"step": 0, "foo": 1.0}] - } - }, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -@pytest.mark.studio() -def test_post_to_studio_inside_subdir_dvc_exp( - tmp_dir, dvc_repo, monkeypatch, mocked_studio_post -): - mocked_post, _ = mocked_studio_post - subdir = tmp_dir / "subdir" - subdir.mkdir() - monkeypatch.chdir(subdir) - - monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") - - live = Live() - live.log_metric("foo", 1) - live.next_step() - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": live._baseline_rev, - "name": live._exp_name, - "step": 0, - "metrics": {f"subdir/{metrics_path}": {"data": {"step": 0, "foo": 1}}}, - "plots": { - f"subdir/{dvc_path}::subdir/{foo_path}": { - "data": [{"step": 0, "foo": 1.0}] - } - }, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -def test_get_dvc_studio_config_none(mocker): - mocker.patch("dvclive.live.get_dvc_repo", return_value=None) - live = Live() - assert get_dvc_studio_config(live) == {} - - -def test_get_dvc_studio_config_env_var(monkeypatch, mocker): - monkeypatch.setenv(DVC_STUDIO_TOKEN, "token") - monkeypatch.setenv(DVC_STUDIO_REPO_URL, "repo_url") - mocker.patch("dvclive.live.get_dvc_repo", return_value=None) - live = Live() - assert get_dvc_studio_config(live) == { - "token": "token", - "repo_url": "repo_url", - "url": STUDIO_URL, - } - - -def test_get_dvc_studio_config_dvc_repo(mocked_dvc_repo): - mocked_dvc_repo.config = {"studio": {"token": "token", "repo_url": "repo_url"}} - live = Live() - assert get_dvc_studio_config(live) == { - "token": "token", - "repo_url": "repo_url", - "url": STUDIO_URL, - } - - -def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): - mocked_post, _ = mocked_studio_post - - live = Live(save_dvc_exp=True) - live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0))) - live.next_step() - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix() - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": live._baseline_rev, - "name": live._exp_name, - "step": 0, - "metrics": {f"{metrics_path}": {"data": {"step": 0}}}, - "plots": { - f"{dvc_path}::{foo_path}": {"image": _adapt_image(foo_path)}, - }, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -def test_post_to_studio_message(tmp_dir, mocked_dvc_repo, mocked_studio_post): - live = Live(save_dvc_exp=True, exp_message="Custom message") - - mocked_post, _ = mocked_studio_post - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "start", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": "f" * 40, - "name": live._exp_name, - "client": "dvclive", - "message": "Custom message", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - -@pytest.mark.parametrize("exp_name", [True, False]) -@pytest.mark.parametrize("baseline_rev", [True, False]) -def test_post_to_studio_no_repo( - tmp_dir, monkeypatch, mocked_studio_post, exp_name, baseline_rev -): - monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN") - monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL") - baseline = None - if baseline_rev: - baseline = "f" * 40 - monkeypatch.setenv(DVC_EXP_BASELINE_REV, baseline) - if exp_name: - monkeypatch.setenv(DVC_EXP_NAME, "bar") - - live = Live(save_dvc_exp=True) - live.log_param("fooparam", 1) - - dvc_path = Path(live.dvc_file).as_posix() - metrics_path = Path(live.metrics_file).as_posix() - params_path = Path(live.params_file).as_posix() - foo_path = (Path(live.plots_dir) / Metric.subfolder / "foo.tsv").as_posix() - - mocked_post, _ = mocked_studio_post - - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "start", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": baseline, - "name": live._exp_name, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - live.log_metric("foo", 1) - - live.next_step() - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": baseline, - "name": live._exp_name, - "step": 0, - "metrics": {metrics_path: {"data": {"step": 0, "foo": 1}}}, - "params": {params_path: {"fooparam": 1}}, - "plots": {f"{dvc_path}::{foo_path}": {"data": [{"step": 0, "foo": 1.0}]}}, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - live.log_metric("foo", 2) - - live.next_step() - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "data", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": baseline, - "name": live._exp_name, - "step": 1, - "metrics": {metrics_path: {"data": {"step": 1, "foo": 2}}}, - "params": {params_path: {"fooparam": 1}}, - "plots": {f"{dvc_path}::{foo_path}": {"data": [{"step": 1, "foo": 2.0}]}}, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - live.end() - mocked_post.assert_called_with( - "https://0.0.0.0/api/live", - json={ - "type": "done", - "repo_url": "STUDIO_REPO_URL", - "baseline_sha": baseline, - "name": live._exp_name, - "client": "dvclive", - }, - headers={ - "Authorization": "token STUDIO_TOKEN", - "Content-type": "application/json", - }, - timeout=(30, 5), - ) - - assert live._exp_name is not None From 68b4f90b63146949369de607d26732f6988d22ea Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 11 Sep 2023 17:01:54 -0400 Subject: [PATCH 17/33] fix windows studio paths --- .dvc/.gitignore | 3 +++ .dvc/config | 0 .dvcignore | 3 +++ src/dvclive/studio.py | 16 +++++++++------- src/dvclive/utils.py | 6 ++++-- 5 files changed, 19 insertions(+), 9 deletions(-) create mode 100644 .dvc/.gitignore create mode 100644 .dvc/config create mode 100644 .dvcignore diff --git a/.dvc/.gitignore b/.dvc/.gitignore new file mode 100644 index 00000000..528f30c7 --- /dev/null +++ b/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/.dvc/config b/.dvc/config new file mode 100644 index 00000000..e69de29b diff --git a/.dvcignore b/.dvcignore new file mode 100644 index 00000000..51973055 --- /dev/null +++ b/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 29bb22fc..7945c77b 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -34,12 +34,13 @@ def _cast_to_numbers(datapoints): def _adapt_plot_name(live, name): + root = None if live._dvc_repo is not None: - name = rel_path(name, live._dvc_repo.root_dir) + root = live._dvc_repo.root_dir + name = rel_path(name, root) if os.path.isfile(live.dvc_file): dvc_file = live.dvc_file - if live._dvc_repo is not None: - dvc_file = rel_path(live.dvc_file, live._dvc_repo.root_dir) + dvc_file = rel_path(live.dvc_file, root) name = f"{dvc_file}::{name}" return name @@ -65,10 +66,12 @@ def _adapt_images(live): def get_studio_updates(live): + root = None + if live._dvc_repo is not None: + root = live._dvc_repo.root_dir if os.path.isfile(live.params_file): params_file = live.params_file - if live._dvc_repo is not None: - params_file = rel_path(params_file, live._dvc_repo.root_dir) + params_file = rel_path(params_file, root) params = {params_file: load_yaml(live.params_file)} else: params = {} @@ -76,8 +79,7 @@ def get_studio_updates(live): plots, metrics = parse_metrics(live) metrics_file = live.metrics_file - if live._dvc_repo is not None: - metrics_file = rel_path(metrics_file, live._dvc_repo.root_dir) + metrics_file = rel_path(metrics_file, root) metrics = {metrics_file: {"data": metrics}} plots = { diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 65f0266c..8e133e03 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -174,8 +174,10 @@ def wrapper(*args, **kwargs): def rel_path(path, dvc_root_path): - absolute_path = Path(path).absolute() - return str(Path(os.path.relpath(absolute_path, dvc_root_path)).as_posix()) + if dvc_root_path: + absolute_path = Path(path).absolute() + path = os.path.relpath(absolute_path, dvc_root_path) + return str(Path(path).as_posix()) def read_history(live, metric): From 8d2112f111b3e1de4c9b9e6921bd749e4113de3c Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 11 Sep 2023 17:24:09 -0400 Subject: [PATCH 18/33] fix windows studio paths for plots --- src/dvclive/studio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 7945c77b..71d1edc6 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -37,7 +37,7 @@ def _adapt_plot_name(live, name): root = None if live._dvc_repo is not None: root = live._dvc_repo.root_dir - name = rel_path(name, root) + name = rel_path(name, root) if os.path.isfile(live.dvc_file): dvc_file = live.dvc_file dvc_file = rel_path(live.dvc_file, root) From 937bc5ba43698803608aa153b9e26a7890ac74c2 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 7 Dec 2023 10:44:51 -0500 Subject: [PATCH 19/33] skip fabric tests if not installed --- tests/frameworks/test_fabric.py | 70 +++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/frameworks/test_fabric.py diff --git a/tests/frameworks/test_fabric.py b/tests/frameworks/test_fabric.py new file mode 100644 index 00000000..78820f81 --- /dev/null +++ b/tests/frameworks/test_fabric.py @@ -0,0 +1,70 @@ +from argparse import Namespace +from unittest.mock import Mock + +import numpy as np +import pytest +import torch + +try: + from dvclive.fabric import DVCLiveLogger +except ImportError: + pytest.skip("skipping lightning tests", allow_module_level=True) + + +class BoringModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2, bias=False) + + def forward(self, x): + x = self.layer(x) + return torch.nn.functional.mse_loss(x, torch.ones_like(x)) + + +@pytest.mark.parametrize("step_idx", [10, None]) +def test_dvclive_log_metrics(tmp_path, step_idx): + logger = DVCLiveLogger(dir=tmp_path) + metrics = { + "float": 0.3, + "int": 1, + "FloatTensor": torch.tensor(0.1), + "IntTensor": torch.tensor(1), + } + logger.log_metrics(metrics, step_idx) + + +def test_dvclive_log_hyperparams(tmp_path): + logger = DVCLiveLogger(dir=tmp_path) + hparams = { + "float": 0.3, + "int": 1, + "string": "abc", + "bool": True, + "dict": {"a": {"b": "c"}}, + "list": [1, 2, 3], + "namespace": Namespace(foo=Namespace(bar="buzz")), + "layer": torch.nn.BatchNorm1d, + "tensor": torch.empty(2, 2, 2), + "array": np.empty([2, 2, 2]), + } + logger.log_hyperparams(hparams) + + +def test_dvclive_finalize(monkeypatch, tmp_path): + """Test that the SummaryWriter closes in finalize.""" + import dvclive + + monkeypatch.setattr(dvclive, "Live", Mock()) + logger = DVCLiveLogger(dir=tmp_path) + assert logger._experiment is None + logger.finalize("any") + + # no log calls, no experiment created -> nothing to flush + logger.experiment.assert_not_called() + + logger = DVCLiveLogger(dir=tmp_path) + logger.log_hyperparams({"flush_me": 11.1}) # trigger creation of an experiment + logger.finalize("any") + + # finalize flushes to experiment directory + logger.experiment.end.assert_called() From 0c49bea50956d4312436bb0d1102d65193efbeb4 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 8 Dec 2023 16:11:42 -0500 Subject: [PATCH 20/33] drop dvc repo --- .dvc/.gitignore | 3 --- .dvc/config | 0 2 files changed, 3 deletions(-) delete mode 100644 .dvc/.gitignore delete mode 100644 .dvc/config diff --git a/.dvc/.gitignore b/.dvc/.gitignore deleted file mode 100644 index 528f30c7..00000000 --- a/.dvc/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -/config.local -/tmp -/cache diff --git a/.dvc/config b/.dvc/config deleted file mode 100644 index e69de29b..00000000 From 83bb14a7f470ced72d7e5ff27e5be6782a700ff7 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 8 Dec 2023 16:13:05 -0500 Subject: [PATCH 21/33] drop dvcignore --- .dvcignore | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .dvcignore diff --git a/.dvcignore b/.dvcignore deleted file mode 100644 index 51973055..00000000 --- a/.dvcignore +++ /dev/null @@ -1,3 +0,0 @@ -# Add patterns of files dvc should ignore, which could improve -# the performance. Learn more at -# https://dvc.org/doc/user-guide/dvcignore From 71698f0bfee273fa6139510878458b244ae64d93 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 8 Dec 2023 16:21:54 -0500 Subject: [PATCH 22/33] drop unrelated test_fabric.py file --- tests/frameworks/test_fabric.py | 70 --------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 tests/frameworks/test_fabric.py diff --git a/tests/frameworks/test_fabric.py b/tests/frameworks/test_fabric.py deleted file mode 100644 index 78820f81..00000000 --- a/tests/frameworks/test_fabric.py +++ /dev/null @@ -1,70 +0,0 @@ -from argparse import Namespace -from unittest.mock import Mock - -import numpy as np -import pytest -import torch - -try: - from dvclive.fabric import DVCLiveLogger -except ImportError: - pytest.skip("skipping lightning tests", allow_module_level=True) - - -class BoringModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.layer = torch.nn.Linear(32, 2, bias=False) - - def forward(self, x): - x = self.layer(x) - return torch.nn.functional.mse_loss(x, torch.ones_like(x)) - - -@pytest.mark.parametrize("step_idx", [10, None]) -def test_dvclive_log_metrics(tmp_path, step_idx): - logger = DVCLiveLogger(dir=tmp_path) - metrics = { - "float": 0.3, - "int": 1, - "FloatTensor": torch.tensor(0.1), - "IntTensor": torch.tensor(1), - } - logger.log_metrics(metrics, step_idx) - - -def test_dvclive_log_hyperparams(tmp_path): - logger = DVCLiveLogger(dir=tmp_path) - hparams = { - "float": 0.3, - "int": 1, - "string": "abc", - "bool": True, - "dict": {"a": {"b": "c"}}, - "list": [1, 2, 3], - "namespace": Namespace(foo=Namespace(bar="buzz")), - "layer": torch.nn.BatchNorm1d, - "tensor": torch.empty(2, 2, 2), - "array": np.empty([2, 2, 2]), - } - logger.log_hyperparams(hparams) - - -def test_dvclive_finalize(monkeypatch, tmp_path): - """Test that the SummaryWriter closes in finalize.""" - import dvclive - - monkeypatch.setattr(dvclive, "Live", Mock()) - logger = DVCLiveLogger(dir=tmp_path) - assert logger._experiment is None - logger.finalize("any") - - # no log calls, no experiment created -> nothing to flush - logger.experiment.assert_not_called() - - logger = DVCLiveLogger(dir=tmp_path) - logger.log_hyperparams({"flush_me": 11.1}) # trigger creation of an experiment - logger.finalize("any") - - # finalize flushes to experiment directory - logger.experiment.end.assert_called() From a9b028f6c72e751028848ddfe2a4828ebb1634ee Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 8 Dec 2023 17:13:34 -0500 Subject: [PATCH 23/33] fix windows paths --- src/dvclive/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 8e133e03..f8d82e42 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -4,7 +4,7 @@ import re import shutil import webbrowser -from pathlib import Path +from pathlib import Path, PureWindowsPath from platform import uname from typing import Union @@ -177,6 +177,8 @@ def rel_path(path, dvc_root_path): if dvc_root_path: absolute_path = Path(path).absolute() path = os.path.relpath(absolute_path, dvc_root_path) + if os.name == "nt": + path = PureWindowsPath(path) return str(Path(path).as_posix()) From bb2a3b51567eaab3f547e4b89407a548a80094a2 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 8 Dec 2023 17:29:27 -0500 Subject: [PATCH 24/33] fix windows paths --- src/dvclive/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index f8d82e42..a7eddfa5 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -178,7 +178,7 @@ def rel_path(path, dvc_root_path): absolute_path = Path(path).absolute() path = os.path.relpath(absolute_path, dvc_root_path) if os.name == "nt": - path = PureWindowsPath(path) + return str(PureWindowsPath(path).as_posix()) return str(Path(path).as_posix()) From 28cde806ef0f48134eb406aca94a342c105bd5e4 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 25 Jan 2024 10:34:51 -0500 Subject: [PATCH 25/33] adapt plot paths even if no dvc repo --- src/dvclive/studio.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index f74db7ab..e4c79a57 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -34,9 +34,8 @@ def _cast_to_numbers(datapoints): def _adapt_path(live, name): - if live._dvc_repo is not None: - name = rel_path(name, live._dvc_repo.root_dir) - return name + dvc_root_path = live._dvc_repo.root_dir if live._dvc_repo else None + return rel_path(name, dvc_root_path) def _adapt_plot_datapoints(live, plot): From 12755f2a586ff29714c3bcdf5b8115a35bbf5e81 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Thu, 25 Jan 2024 15:03:33 -0500 Subject: [PATCH 26/33] default baseline rev to all zeros --- src/dvclive/live.py | 5 ++++- tests/test_post_to_studio.py | 27 +++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index fd23c748..415dcfe2 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -95,7 +95,10 @@ def __init__( self._report_notebook = None self._init_report() - self._baseline_rev: Optional[str] = os.getenv(env.DVC_EXP_BASELINE_REV) + self._baseline_rev: str = os.getenv( + env.DVC_EXP_BASELINE_REV, + "0" * 40, # noqa: PLW1508 + ) self._exp_name: Optional[str] = exp_name or os.getenv(env.DVC_EXP_NAME) self._exp_message: Optional[str] = exp_message self._experiment_rev: Optional[str] = None diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 9b931967..edc02a25 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -396,11 +396,10 @@ def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_ assert "data" in call_types +@pytest.mark.studio() def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): monkeypatch.setenv(DVC_STUDIO_TOKEN, "STUDIO_TOKEN") monkeypatch.setenv(DVC_STUDIO_REPO_URL, "STUDIO_REPO_URL") - monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40) - monkeypatch.setenv(DVC_EXP_NAME, "bar") live = Live(save_dvc_exp=True) live.log_param("fooparam", 1) @@ -411,7 +410,8 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): mocked_post.assert_called() mocked_post.assert_called_with( - "https://0.0.0.0/api/live", **get_studio_call("start", exp_name=live._exp_name) + "https://0.0.0.0/api/live", + **get_studio_call("start", baseline_sha="0" * 40, exp_name=live._exp_name), ) live.log_metric("foo", 1) @@ -421,6 +421,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): "https://0.0.0.0/api/live", **get_studio_call( "data", + baseline_sha="0" * 40, exp_name=live._exp_name, step=0, plots={f"{foo_path}": {"data": [{"step": 0, "foo": 1.0}]}}, @@ -434,6 +435,7 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): "https://0.0.0.0/api/live", **get_studio_call( "data", + baseline_sha="0" * 40, exp_name=live._exp_name, step=1, plots={f"{foo_path}": {"data": [{"step": 1, "foo": 2.0}]}}, @@ -443,5 +445,22 @@ def test_post_to_studio_no_repo(tmp_dir, monkeypatch, mocked_studio_post): live.end() mocked_post.assert_called_with( "https://0.0.0.0/api/live", - **get_studio_call("done", exp_name=live._exp_name), + **get_studio_call("done", baseline_sha="0" * 40, exp_name=live._exp_name), ) + + +@pytest.mark.studio() +def test_post_to_studio_skip_if_no_repo_url( + tmp_dir, + mocker, + monkeypatch, +): + mocked_post = mocker.patch("dvclive.studio.post_live_metrics", return_value=None) + + monkeypatch.setenv(DVC_STUDIO_TOKEN, "token") + + with Live() as live: + live.log_metric("foo", 1) + live.next_step() + + assert mocked_post.call_count == 0 From ade1b3d9e7960372b7b39125e00f04ce136b80b6 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 7 Feb 2024 13:23:56 -0500 Subject: [PATCH 27/33] consolidate repro tests --- tests/test_dvc.py | 37 ++++++------------------------------- 1 file changed, 6 insertions(+), 31 deletions(-) diff --git a/tests/test_dvc.py b/tests/test_dvc.py index ebfa90bc..bf263fc4 100644 --- a/tests/test_dvc.py +++ b/tests/test_dvc.py @@ -29,9 +29,9 @@ def test_get_dvc_repo_subdir(tmp_dir): def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): live = Live(save_dvc_exp=save) live.end() + assert live._baseline_rev is not None + assert live._exp_name is not None if save: - assert live._baseline_rev is not None - assert live._exp_name is not None mocked_dvc_repo.experiments.save.assert_called_with( name=live._exp_name, include_untracked=[live.dir, "dvc.yaml"], @@ -39,8 +39,6 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo): message=None, ) else: - assert live._baseline_rev is not None - assert live._exp_name is not None mocked_dvc_repo.experiments.save.assert_not_called() @@ -59,31 +57,6 @@ def test_exp_save_skip_on_env_vars(tmp_dir, monkeypatch): assert live._inside_dvc_pipeline -def test_exp_save_run_on_dvc_repro(tmp_dir, mocker): - dvc_repo = mocker.MagicMock() - dvc_stage = mocker.MagicMock() - dvc_file = mocker.MagicMock() - dvc_repo.index.stages = [dvc_stage, dvc_file] - dvc_repo.scm.get_rev.return_value = "current_rev" - dvc_repo.scm.get_ref.return_value = None - dvc_repo.scm.no_commits = False - dvc_repo.config = {} - dvc_repo.root_dir = tmp_dir - mocker.patch("dvclive.live.get_dvc_repo", return_value=dvc_repo) - live = Live() - assert live._save_dvc_exp - assert live._baseline_rev is not None - assert live._exp_name is not None - live.end() - - dvc_repo.experiments.save.assert_called_with( - name=live._exp_name, - include_untracked=[live.dir, "dvc.yaml"], - force=True, - message=None, - ) - - def test_exp_save_with_dvc_files(tmp_dir, mocker): dvc_repo = mocker.MagicMock() dvc_file = mocker.MagicMock() @@ -203,10 +176,12 @@ def test_no_scm_repo(tmp_dir, mocker): assert live._save_dvc_exp is False -def test_dvc_repro(tmp_dir, monkeypatch, mocker): +def test_dvc_repro(tmp_dir, monkeypatch, mocked_dvc_repo, mocked_studio_post): monkeypatch.setenv(DVC_ROOT, "root") - mocker.patch("dvclive.live.get_dvc_repo", return_value=None) live = Live(save_dvc_exp=True) + assert live._baseline_rev is not None + assert live._exp_name is not None + assert not live._studio_events_to_skip assert not live._save_dvc_exp From 56acdaed5518975365d2e4707495e77a9b032d01 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Wed, 7 Feb 2024 13:31:30 -0500 Subject: [PATCH 28/33] set null sha as variable --- src/dvclive/live.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 415dcfe2..8f747e6a 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -63,6 +63,8 @@ ParamLike = Union[int, float, str, bool, List["ParamLike"], Dict[str, "ParamLike"]] +NULL_SHA: str = "0" * 40 + class Live: def __init__( @@ -95,10 +97,7 @@ def __init__( self._report_notebook = None self._init_report() - self._baseline_rev: str = os.getenv( - env.DVC_EXP_BASELINE_REV, - "0" * 40, # noqa: PLW1508 - ) + self._baseline_rev: str = os.getenv(env.DVC_EXP_BASELINE_REV, NULL_SHA) self._exp_name: Optional[str] = exp_name or os.getenv(env.DVC_EXP_NAME) self._exp_message: Optional[str] = exp_message self._experiment_rev: Optional[str] = None From 33288cd567765da841d25ddcd36e60ed16ab0d9d Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 13 Feb 2024 10:36:07 -0500 Subject: [PATCH 29/33] add type hints to studio --- dvc.yaml | 8 ++++++++ src/dvclive/studio.py | 24 ++++++++++++++---------- 2 files changed, 22 insertions(+), 10 deletions(-) create mode 100644 dvc.yaml diff --git a/dvc.yaml b/dvc.yaml new file mode 100644 index 00000000..69ddfc16 --- /dev/null +++ b/dvc.yaml @@ -0,0 +1,8 @@ +metrics: +- ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_10_0/metrics.json +- ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_None_0/metrics.json +plots: +- ? ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_10_0/plots/metrics + : x: step +- ? ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_None_0/plots/metrics + : x: step diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index e4c79a57..796d065e 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -1,23 +1,27 @@ # ruff: noqa: SLF001 +from __future__ import annotations import base64 import logging import math import os +from typing import TYPE_CHECKING from dvc_studio_client.config import get_studio_config from dvc_studio_client.post_live_metrics import post_live_metrics +if TYPE_CHECKING: + from dvclive.live import Live from dvclive.serialize import load_yaml -from dvclive.utils import parse_metrics, rel_path +from dvclive.utils import parse_metrics, rel_path, StrPath logger = logging.getLogger("dvclive") -def _get_unsent_datapoints(plot, latest_step): +def _get_unsent_datapoints(plot: dict, latest_step: int): return [x for x in plot if int(x["step"]) > latest_step] -def _cast_to_numbers(datapoints): +def _cast_to_numbers(datapoints: dict): for datapoint in datapoints: for k, v in datapoint.items(): if k == "step": @@ -33,22 +37,22 @@ def _cast_to_numbers(datapoints): return datapoints -def _adapt_path(live, name): +def _adapt_path(live: Live, name: StrPath): dvc_root_path = live._dvc_repo.root_dir if live._dvc_repo else None return rel_path(name, dvc_root_path) -def _adapt_plot_datapoints(live, plot): +def _adapt_plot_datapoints(live: Live, plot: dict): datapoints = _get_unsent_datapoints(plot, live._latest_studio_step) return _cast_to_numbers(datapoints) -def _adapt_image(image_path): +def _adapt_image(image_path: StrPath): with open(image_path, "rb") as fobj: return base64.b64encode(fobj.read()).decode("utf-8") -def _adapt_images(live): +def _adapt_images(live: Live): return { _adapt_path(live, image.output_path): {"image": _adapt_image(image.output_path)} for image in live._images.values() @@ -56,7 +60,7 @@ def _adapt_images(live): } -def get_studio_updates(live): +def get_studio_updates(live: Live): if os.path.isfile(live.params_file): params_file = live.params_file params_file = _adapt_path(live, params_file) @@ -81,14 +85,14 @@ def get_studio_updates(live): return metrics, params, plots -def get_dvc_studio_config(live): +def get_dvc_studio_config(live: Live): config = {} if live._dvc_repo: config = live._dvc_repo.config.get("studio") return get_studio_config(dvc_studio_config=config) -def post_to_studio(live, event): +def post_to_studio(live: Live, event: str): if event in live._studio_events_to_skip: return From 73357b26832e894cf20e1972ff80bb3e1027c1ca Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 13 Feb 2024 10:51:21 -0500 Subject: [PATCH 30/33] limit windows path handling to studio --- dvc.yaml | 8 -------- src/dvclive/studio.py | 8 ++++++-- src/dvclive/utils.py | 10 +++------- 3 files changed, 9 insertions(+), 17 deletions(-) delete mode 100644 dvc.yaml diff --git a/dvc.yaml b/dvc.yaml deleted file mode 100644 index 69ddfc16..00000000 --- a/dvc.yaml +++ /dev/null @@ -1,8 +0,0 @@ -metrics: -- ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_10_0/metrics.json -- ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_None_0/metrics.json -plots: -- ? ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_10_0/plots/metrics - : x: step -- ? ../../../../private/var/folders/24/99_tf1xj3vx8k1k_jkdmnhq00000gn/T/pytest-of-dave/pytest-255/test_dvclive_log_metrics_None_0/plots/metrics - : x: step diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 796d065e..7add1389 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -4,6 +4,7 @@ import logging import math import os +from pathlib import PureWindowsPath from typing import TYPE_CHECKING from dvc_studio_client.config import get_studio_config @@ -38,8 +39,11 @@ def _cast_to_numbers(datapoints: dict): def _adapt_path(live: Live, name: StrPath): - dvc_root_path = live._dvc_repo.root_dir if live._dvc_repo else None - return rel_path(name, dvc_root_path) + if live._dvc_repo is not None: + name = rel_path(name, live._dvc_repo.root_dir) + if os.name == "nt": + name = str(PureWindowsPath(name).as_posix()) + return name def _adapt_plot_datapoints(live: Live, plot: dict): diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index e128c512..a168de9e 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -4,7 +4,7 @@ import os import re import shutil -from pathlib import Path, PureWindowsPath +from pathlib import Path from platform import uname from typing import Union, List, Dict, TYPE_CHECKING import webbrowser @@ -192,12 +192,8 @@ def wrapper(*args, **kwargs): def rel_path(path, dvc_root_path): - if dvc_root_path: - absolute_path = Path(path).absolute() - path = os.path.relpath(absolute_path, dvc_root_path) - if os.name == "nt": - return str(PureWindowsPath(path).as_posix()) - return str(Path(path).as_posix()) + absolute_path = Path(path).absolute() + return str(Path(os.path.relpath(absolute_path, dvc_root_path)).as_posix()) def read_history(live, metric): From 53e77a61c759b1879b2b617aa859922553c9114c Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 13 Feb 2024 11:14:13 -0500 Subject: [PATCH 31/33] fix typing errors in studio module --- src/dvclive/live.py | 2 +- src/dvclive/studio.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 7d4cd145..ab873908 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -158,7 +158,7 @@ def __init__( else: self._init_cleanup() - self._latest_studio_step = self.step if resume else -1 + self._latest_studio_step: int = self.step if resume else -1 self._studio_events_to_skip: Set[str] = set() self._dvc_studio_config: Dict[str, Any] = {} self._init_studio() diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index 7add1389..ca63e4b0 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -5,7 +5,7 @@ import math import os from pathlib import PureWindowsPath -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, Mapping from dvc_studio_client.config import get_studio_config from dvc_studio_client.post_live_metrics import post_live_metrics @@ -18,11 +18,11 @@ logger = logging.getLogger("dvclive") -def _get_unsent_datapoints(plot: dict, latest_step: int): +def _get_unsent_datapoints(plot: Mapping, latest_step: int): return [x for x in plot if int(x["step"]) > latest_step] -def _cast_to_numbers(datapoints: dict): +def _cast_to_numbers(datapoints: Mapping): for datapoint in datapoints: for k, v in datapoint.items(): if k == "step": @@ -46,7 +46,7 @@ def _adapt_path(live: Live, name: StrPath): return name -def _adapt_plot_datapoints(live: Live, plot: dict): +def _adapt_plot_datapoints(live: Live, plot: Mapping): datapoints = _get_unsent_datapoints(plot, live._latest_studio_step) return _cast_to_numbers(datapoints) @@ -96,7 +96,7 @@ def get_dvc_studio_config(live: Live): return get_studio_config(dvc_studio_config=config) -def post_to_studio(live: Live, event: str): +def post_to_studio(live: Live, event: Literal["start", "data", "done"]): if event in live._studio_events_to_skip: return @@ -105,7 +105,7 @@ def post_to_studio(live: Live, event: str): kwargs["message"] = live._exp_message elif event == "data": metrics, params, plots = get_studio_updates(live) - kwargs["step"] = live.step + kwargs["step"] = live.step # type: ignore kwargs["metrics"] = metrics kwargs["params"] = params kwargs["plots"] = plots @@ -115,10 +115,10 @@ def post_to_studio(live: Live, event: str): response = post_live_metrics( event, live._baseline_rev, - live._exp_name, + live._exp_name, # type: ignore "dvclive", dvc_studio_config=live._dvc_studio_config, - **kwargs, + **kwargs, # type: ignore ) if not response: logger.warning(f"`post_to_studio` `{event}` failed.") From bea5a833a5f99d76bc32af47899352231faf5153 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 13 Feb 2024 11:24:58 -0500 Subject: [PATCH 32/33] fix mypy in live module --- src/dvclive/live.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index ab873908..33be50b3 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -828,7 +828,7 @@ def make_dvcyaml(self): make_dvcyaml(self) @catch_and_warn(DvcException, logger) - def post_to_studio(self, event: str): + def post_to_studio(self, event: Literal["start", "data", "done"]): post_to_studio(self, event) def end(self): From bd431da38171a9019c67b93975b9856699cc2c5a Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Tue, 13 Feb 2024 14:52:27 -0500 Subject: [PATCH 33/33] drop checking for dvc_file --- src/dvclive/live.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 33be50b3..b6f983eb 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -824,8 +824,7 @@ def make_dvcyaml(self): `Live.next_step()` and `Live.end()` will call `Live.make_dvcyaml()` internally, so you don't need to call both (unless `dvcyaml=None`). """ - if self.dvc_file: - make_dvcyaml(self) + make_dvcyaml(self) @catch_and_warn(DvcException, logger) def post_to_studio(self, event: Literal["start", "data", "done"]):