diff --git a/tests/e2e/vLLM/test_vllm.py b/tests/e2e/vLLM/test_vllm.py
index 70a6a35e4..d233f5ee1 100644
--- a/tests/e2e/vLLM/test_vllm.py
+++ b/tests/e2e/vLLM/test_vllm.py
@@ -2,12 +2,12 @@
 import re
 import shutil
 from pathlib import Path
-from typing import Callable
 
 import pytest
 import yaml
 from huggingface_hub import HfApi
 from loguru import logger
+from parameterized import parameterized_class
 
 from llmcompressor.core import active_session
 from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
@@ -34,15 +34,10 @@
 ]
 
 
-@pytest.fixture
-def record_config_file(record_testsuite_property: Callable[[str, object], None]):
-    test_data_file_name = TEST_DATA_FILE.split("configs/")[-1]
-    record_testsuite_property("TEST_DATA_FILE_NAME", test_data_file_name)
-
-
 # Will run each test case in its own process through run_tests.sh
 # emulating vLLM CI testing
 @requires_gpu_count(1)
+@parameterized_class("test_data_file", [(TEST_DATA_FILE,)])
 @pytest.mark.skipif(not vllm_installed, reason="vLLM is not installed, skipping test")
 class TestvLLM:
     """
@@ -62,7 +57,9 @@ class TestvLLM:
     """  # noqa: E501
 
     def set_up(self):
-        eval_config = yaml.safe_load(Path(TEST_DATA_FILE).read_text(encoding="utf-8"))
+        eval_config = yaml.safe_load(
+            Path(self.test_data_file).read_text(encoding="utf-8")
+        )
 
         if os.environ.get("CADENCE", "commit") != eval_config.get("cadence"):
             pytest.skip("Skipping test; cadence mismatch")
@@ -90,7 +87,6 @@ def set_up(self):
         ]
         self.api = HfApi()
 
-    @pytest.mark.usefixtures("record_config_file")
     def test_vllm(self):
         # Run vLLM with saved model
         import torch
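
For reference, `parameterized_class` (from the third-party `parameterized` package) duplicates the decorated class once per parameter tuple and sets each listed name as a class attribute, which is why `set_up` above can read `self.test_data_file`. A minimal, self-contained sketch of that mechanism, with illustrative class and file names that are not taken from this PR:

import unittest

from parameterized import parameterized_class


# Illustrative config paths; not part of the diff above.
@parameterized_class(
    ("test_data_file",),
    [
        ("configs/example_a.yaml",),
        ("configs/example_b.yaml",),
    ],
)
class TestSketch(unittest.TestCase):
    def test_attribute_is_injected(self):
        # parameterized_class generates one copy of the class per
        # parameter tuple, with test_data_file set as a class
        # attribute, so each instance can read it via self.
        self.assertTrue(self.test_data_file.endswith(".yaml"))


if __name__ == "__main__":
    unittest.main()

Each generated class is collected as its own test class, so the config file travels with the class itself rather than being recorded through a session-scoped fixture, which appears to be the motivation for dropping `record_config_file` here.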