From a186c4bcb0d3a9a644aeb979c576268ff8c7f459 Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Wed, 18 Dec 2024 10:32:16 -0800 Subject: [PATCH] fix(hf): pass fake eval dataset since it is required --- src/dvclive/huggingface.py | 5 ++++- tests/frameworks/test_huggingface.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 883c1490..ed91fef4 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -73,7 +73,10 @@ def on_train_end( ): if self._log_model is True and state.is_world_process_zero: fake_trainer = Trainer( - args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer") + args=args, + model=kwargs.get("model"), + tokenizer=kwargs.get("tokenizer"), + eval_dataset=["fake"], ) name = "best" if args.load_best_model_at_end else "last" output_dir = os.path.join(args.output_dir, name) diff --git a/tests/frameworks/test_huggingface.py b/tests/frameworks/test_huggingface.py index 18b39901..42db6d61 100644 --- a/tests/frameworks/test_huggingface.py +++ b/tests/frameworks/test_huggingface.py @@ -162,6 +162,7 @@ def test_huggingface_log_model( live_callback = callback(live=live, log_model=log_model) args.load_best_model_at_end = best + args.metric_for_best_model = "loss" trainer = Trainer( model,