From c5caedf28df1e63039a34aa1cd7599ecc01e9db0 Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Fri, 3 Nov 2023 20:06:53 -0400 Subject: [PATCH] add tests to hf accelerate dvclivetracker --- tests/frameworks/test_huggingface.py | 34 +++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/frameworks/test_huggingface.py b/tests/frameworks/test_huggingface.py index 65f1e151..078ea00d 100644 --- a/tests/frameworks/test_huggingface.py +++ b/tests/frameworks/test_huggingface.py @@ -10,6 +10,7 @@ try: import numpy as np import torch + from accelerate import Accelerator from torch import nn from transformers import ( PretrainedConfig, @@ -18,7 +19,7 @@ TrainingArguments, ) - from dvclive.huggingface import DVCLiveCallback + from dvclive.huggingface import DVCLiveCallback, DVCLiveTracker except ImportError: pytest.skip("skipping huggingface tests", allow_module_level=True) @@ -182,3 +183,34 @@ def test_huggingface_pass_logger(): assert DVCLiveCallback().live is not logger assert DVCLiveCallback(live=logger).live is logger + + +def test_accelerate_tracker(): + tracker = DVCLiveTracker(dir="test_dir") + accelerator = Accelerator(log_with=tracker) + config = { + "num_iterations": 12, + "learning_rate": 1e-2, + "some_boolean": False, + "some_string": "some_value", + } + accelerator.init_trackers("test_project", config=config) + values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"} + accelerator.log(values, step=0) + accelerator.end_training() + + # Check kwargs passed + live = accelerator.trackers[0].live + assert live.dir == "test_dir" + + # Check params logged + params = load_yaml(live.params_file) + assert params == config + + # Check metrics logged + logs, latest = parse_metrics(live) + assert latest == values + scalars = os.path.join(live.plots_dir, Metric.subfolder) + assert os.path.join(scalars, "total_loss.tsv") in logs + assert os.path.join(scalars, "iteration.tsv") in logs + assert os.path.join(scalars, "my_text.tsv") in logs