diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py
index 626c52c8..1fc8f1d3 100644
--- a/tests/models/llama/test_inference_llama.py
+++ b/tests/models/llama/test_inference_llama.py
@@ -4,6 +4,10 @@
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+torch.backends.cudnn.deterministic = True
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
 device = "cuda"
 
 