Skip to content

Commit

Permalink
add tests to hf accelerate dvclivetracker
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Nov 4, 2023
1 parent eb89794 commit c5caedf
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion tests/frameworks/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
try:
import numpy as np
import torch
from accelerate import Accelerator
from torch import nn
from transformers import (
PretrainedConfig,
Expand All @@ -18,7 +19,7 @@
TrainingArguments,
)

from dvclive.huggingface import DVCLiveCallback
from dvclive.huggingface import DVCLiveCallback, DVCLiveTracker
except ImportError:
pytest.skip("skipping huggingface tests", allow_module_level=True)

Expand Down Expand Up @@ -182,3 +183,34 @@ def test_huggingface_pass_logger():

assert DVCLiveCallback().live is not logger
assert DVCLiveCallback(live=logger).live is logger


def test_accelerate_tracker():
tracker = DVCLiveTracker(dir="test_dir")
accelerator = Accelerator(log_with=tracker)
config = {
"num_iterations": 12,
"learning_rate": 1e-2,
"some_boolean": False,
"some_string": "some_value",
}
accelerator.init_trackers("test_project", config=config)
values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"}
accelerator.log(values, step=0)
accelerator.end_training()

# Check kwargs passed
live = accelerator.trackers[0].live
assert live.dir == "test_dir"

# Check params logged
params = load_yaml(live.params_file)
assert params == config

# Check metrics logged
logs, latest = parse_metrics(live)
assert latest == values
scalars = os.path.join(live.plots_dir, Metric.subfolder)
assert os.path.join(scalars, "total_loss.tsv") in logs
assert os.path.join(scalars, "iteration.tsv") in logs
assert os.path.join(scalars, "my_text.tsv") in logs

0 comments on commit c5caedf

Please sign in to comment.