Skip to content

Commit c5caedf

Browse files
author
dberenbaum
committed
add tests to hf accelerate dvclivetracker
1 parent eb89794 commit c5caedf

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

tests/frameworks/test_huggingface.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
try:
1111
import numpy as np
1212
import torch
13+
from accelerate import Accelerator
1314
from torch import nn
1415
from transformers import (
1516
PretrainedConfig,
@@ -18,7 +19,7 @@
1819
TrainingArguments,
1920
)
2021

21-
from dvclive.huggingface import DVCLiveCallback
22+
from dvclive.huggingface import DVCLiveCallback, DVCLiveTracker
2223
except ImportError:
2324
pytest.skip("skipping huggingface tests", allow_module_level=True)
2425

@@ -182,3 +183,34 @@ def test_huggingface_pass_logger():
182183

183184
assert DVCLiveCallback().live is not logger
184185
assert DVCLiveCallback(live=logger).live is logger
186+
187+
188+
def test_accelerate_tracker():
189+
tracker = DVCLiveTracker(dir="test_dir")
190+
accelerator = Accelerator(log_with=tracker)
191+
config = {
192+
"num_iterations": 12,
193+
"learning_rate": 1e-2,
194+
"some_boolean": False,
195+
"some_string": "some_value",
196+
}
197+
accelerator.init_trackers("test_project", config=config)
198+
values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"}
199+
accelerator.log(values, step=0)
200+
accelerator.end_training()
201+
202+
# Check kwargs passed
203+
live = accelerator.trackers[0].live
204+
assert live.dir == "test_dir"
205+
206+
# Check params logged
207+
params = load_yaml(live.params_file)
208+
assert params == config
209+
210+
# Check metrics logged
211+
logs, latest = parse_metrics(live)
212+
assert latest == values
213+
scalars = os.path.join(live.plots_dir, Metric.subfolder)
214+
assert os.path.join(scalars, "total_loss.tsv") in logs
215+
assert os.path.join(scalars, "iteration.tsv") in logs
216+
assert os.path.join(scalars, "my_text.tsv") in logs

0 commit comments

Comments
 (0)