|
10 | 10 | try:
|
11 | 11 | import numpy as np
|
12 | 12 | import torch
|
| 13 | + from accelerate import Accelerator |
13 | 14 | from torch import nn
|
14 | 15 | from transformers import (
|
15 | 16 | PretrainedConfig,
|
|
18 | 19 | TrainingArguments,
|
19 | 20 | )
|
20 | 21 |
|
21 |
| - from dvclive.huggingface import DVCLiveCallback |
| 22 | + from dvclive.huggingface import DVCLiveCallback, DVCLiveTracker |
22 | 23 | except ImportError:
|
23 | 24 | pytest.skip("skipping huggingface tests", allow_module_level=True)
|
24 | 25 |
|
@@ -182,3 +183,34 @@ def test_huggingface_pass_logger():
|
182 | 183 |
|
183 | 184 | assert DVCLiveCallback().live is not logger
|
184 | 185 | assert DVCLiveCallback(live=logger).live is logger
|
| 186 | + |
| 187 | + |
| 188 | +def test_accelerate_tracker(): |
| 189 | + tracker = DVCLiveTracker(dir="test_dir") |
| 190 | + accelerator = Accelerator(log_with=tracker) |
| 191 | + config = { |
| 192 | + "num_iterations": 12, |
| 193 | + "learning_rate": 1e-2, |
| 194 | + "some_boolean": False, |
| 195 | + "some_string": "some_value", |
| 196 | + } |
| 197 | + accelerator.init_trackers("test_project", config=config) |
| 198 | + values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"} |
| 199 | + accelerator.log(values, step=0) |
| 200 | + accelerator.end_training() |
| 201 | + |
| 202 | + # Check kwargs passed |
| 203 | + live = accelerator.trackers[0].live |
| 204 | + assert live.dir == "test_dir" |
| 205 | + |
| 206 | + # Check params logged |
| 207 | + params = load_yaml(live.params_file) |
| 208 | + assert params == config |
| 209 | + |
| 210 | + # Check metrics logged |
| 211 | + logs, latest = parse_metrics(live) |
| 212 | + assert latest == values |
| 213 | + scalars = os.path.join(live.plots_dir, Metric.subfolder) |
| 214 | + assert os.path.join(scalars, "total_loss.tsv") in logs |
| 215 | + assert os.path.join(scalars, "iteration.tsv") in logs |
| 216 | + assert os.path.join(scalars, "my_text.tsv") in logs |
0 commit comments