diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py
index 3919914c..621703b8 100644
--- a/tests/models/llama/test_inference_llama.py
+++ b/tests/models/llama/test_inference_llama.py
@@ -2,16 +2,28 @@
 import unittest
 
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 device = "cuda"
 
 
-
-
 class LLaMaInferenceTest(unittest.TestCase):
     def test_foo(self):
         ckpt = "meta-llama/Llama-2-7b-hf"
+
+        tokenizer = AutoTokenizer.from_pretrained(ckpt)
+
+        prompt = "Hey, are you conscious? Can you talk to me?"
+        inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
         model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, token=os.getenv("HF_HUB_READ_TOKEN", None))
         model.to(device)
+
+        # Generate
+        generate_ids = model.generate(inputs.input_ids, max_length=30)
+        output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+        expected_output = "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+
+        assert output == expected_output