Skip to content

Commit

Permalink
dvcyaml: write to cwd instead of git root (#729)
Browse files Browse the repository at this point in the history
* dvcyaml: write to cwd instead of git root

* revert changes to lightning tests

* refactor dvcyaml
  • Loading branch information
Dave Berenbaum authored Oct 31, 2023
1 parent 0df9dc5 commit 3403046
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 23 deletions.
13 changes: 3 additions & 10 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
resume: bool = False,
report: Optional[str] = None,
save_dvc_exp: bool = True,
dvcyaml: Union[str, bool] = True,
dvcyaml: Union[str, bool] = "dvc.yaml",
cache_images: bool = False,
exp_name: Optional[str] = None,
exp_message: Optional[str] = None,
Expand Down Expand Up @@ -191,13 +191,7 @@ def _init_dvc_file(self) -> str:
if os.path.basename(self._dvcyaml) == "dvc.yaml":
return self._dvcyaml
raise InvalidDvcyamlError
if self._dvc_repo is not None:
return os.path.join(self._dvc_repo.root_dir, "dvc.yaml")
logger.warning(
"Can't infer dvcyaml path without a DVC repo. "
"`dvc.yaml` file will not be written."
)
return ""
return "dvc.yaml"

def _init_dvc_pipeline(self):
if os.getenv(env.DVC_EXP_BASELINE_REV, None):
Expand Down Expand Up @@ -543,8 +537,7 @@ def make_report(self):

@catch_and_warn(DvcException, logger)
def make_dvcyaml(self):
if self.dvc_file:
make_dvcyaml(self)
make_dvcyaml(self)

@catch_and_warn(DvcException, logger)
def post_to_studio(self, event):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_exp_save_on_end(tmp_dir, save, mocked_dvc_repo):
assert live._exp_name is not None
mocked_dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")],
include_untracked=[live.dir, "dvc.yaml"],
force=True,
message=None,
)
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_exp_save_run_on_dvc_repro(tmp_dir, mocker):

dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")],
include_untracked=[live.dir, "dvc.yaml"],
force=True,
message=None,
)
Expand All @@ -102,7 +102,7 @@ def test_exp_save_with_dvc_files(tmp_dir, mocker):

dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")],
include_untracked=[live.dir, "dvc.yaml"],
force=True,
message=None,
)
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_exp_save_message(tmp_dir, mocked_dvc_repo):
live.end()
mocked_dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")],
include_untracked=[live.dir, "dvc.yaml"],
force=True,
message="Custom message",
)
Expand All @@ -186,7 +186,7 @@ def test_exp_save_name(tmp_dir, mocked_dvc_repo):
live.end()
mocked_dvc_repo.experiments.save.assert_called_with(
name="custom-name",
include_untracked=[live.dir, str(tmp_dir / "dvc.yaml")],
include_untracked=[live.dir, "dvc.yaml"],
force=True,
message=None,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_log_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_log_artifact_with_save_dvc_exp(tmp_dir, mocker, mocked_dvc_repo):
live.log_artifact("data")
mocked_dvc_repo.experiments.save.assert_called_with(
name=live._exp_name,
include_untracked=[live.dir, "data", ".gitignore", str(tmp_dir / "dvc.yaml")],
include_untracked=[live.dir, "data", ".gitignore", "dvc.yaml"],
force=True,
message=None,
)
Expand Down
29 changes: 22 additions & 7 deletions tests/test_make_dvcyaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,18 +440,33 @@ def test_make_dvcyaml(tmp_dir, mocked_dvc_repo, dvcyaml):


def test_make_dvcyaml_no_repo(tmp_dir, mocker):
logger = mocker.patch("dvclive.live.logger")
dvclive = Live("logs")
dvclive.make_dvcyaml()

assert not os.path.exists("dvc.yaml")
assert not dvclive.dvc_file
logger.warning.assert_any_call(
"Can't infer dvcyaml path without a DVC repo. "
"`dvc.yaml` file will not be written."
)
assert os.path.exists("dvc.yaml")


def test_make_dvcyaml_invalid(tmp_dir, mocker):
with pytest.raises(InvalidDvcyamlError):
Live("logs", dvcyaml="invalid")


def test_make_dvcyaml_on_end(tmp_dir, mocker):
dvclive = Live("logs")
dvclive.end()

assert os.path.exists("dvc.yaml")


def test_make_dvcyaml_false(tmp_dir, mocker):
dvclive = Live("logs", dvcyaml=False)
dvclive.end()

assert not os.path.exists("dvc.yaml")


def test_make_dvcyaml_none(tmp_dir, mocker):
dvclive = Live("logs", dvcyaml=None)
dvclive.end()

assert not os.path.exists("dvc.yaml")

0 comments on commit 3403046

Please sign in to comment.