Skip to content

Commit

Permalink
docker
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Feb 5, 2024
1 parent e5ad8e3 commit 371d815
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tests/models/llama/test_inference_llama.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import unittest
import torch

from transformers import AutoModelForCausalLM

Expand All @@ -15,5 +16,5 @@ def test_foo(self):
assert os.getenv("HF_HUB_READ_TOKEN", None) != None
assert type(os.getenv("HF_HUB_READ_TOKEN", None)) == str
assert len(os.getenv("HF_HUB_READ_TOKEN", None)) > 3
model = AutoModelForCausalLM.from_pretrained(ckpt, token=os.getenv("HF_HUB_READ_TOKEN", None))
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None))
model.to(device)
3 changes: 2 additions & 1 deletion tests/models/llama/test_train_llama.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import unittest
import torch

from transformers import AutoModelForCausalLM

Expand All @@ -13,6 +14,6 @@ def test_foo(self):
assert os.getenv("HF_HUB_READ_TOKEN", None) != None
assert type(os.getenv("HF_HUB_READ_TOKEN", None)) == str
assert len(os.getenv("HF_HUB_READ_TOKEN", None)) > 3
model = AutoModelForCausalLM.from_pretrained(ckpt, token=os.getenv("HF_HUB_READ_TOKEN", None))
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None))
model.train()
model.to(device)

0 comments on commit 371d815

Please sign in to comment.