From 371d8158457abbd471d72fd22e16fecebfab7cd8 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Mon, 5 Feb 2024 13:18:46 +0100 Subject: [PATCH] docker --- tests/models/llama/test_inference_llama.py | 3 ++- tests/models/llama/test_train_llama.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py index 54da9c34..1fc469c7 100644 --- a/tests/models/llama/test_inference_llama.py +++ b/tests/models/llama/test_inference_llama.py @@ -1,5 +1,6 @@ import os import unittest +import torch from transformers import AutoModelForCausalLM @@ -15,5 +16,5 @@ def test_foo(self): assert os.getenv("HF_HUB_READ_TOKEN", None) != None assert type(os.getenv("HF_HUB_READ_TOKEN", None)) == str assert len(os.getenv("HF_HUB_READ_TOKEN", None)) > 3 - model = AutoModelForCausalLM.from_pretrained(ckpt, token=os.getenv("HF_HUB_READ_TOKEN", None)) + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None)) model.to(device) diff --git a/tests/models/llama/test_train_llama.py b/tests/models/llama/test_train_llama.py index c5c78bd8..ccbb0ae0 100644 --- a/tests/models/llama/test_train_llama.py +++ b/tests/models/llama/test_train_llama.py @@ -1,5 +1,6 @@ import os import unittest +import torch from transformers import AutoModelForCausalLM @@ -13,6 +14,6 @@ def test_foo(self): assert os.getenv("HF_HUB_READ_TOKEN", None) != None assert type(os.getenv("HF_HUB_READ_TOKEN", None)) == str assert len(os.getenv("HF_HUB_READ_TOKEN", None)) > 3 - model = AutoModelForCausalLM.from_pretrained(ckpt, token=os.getenv("HF_HUB_READ_TOKEN", None)) + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None)) model.train() model.to(device) \ No newline at end of file