run both CPU and GPU integration tests
renxida committed Jan 30, 2025
Parent: 450c5b1 · Commit: 9dc59f6
Showing 5 changed files with 49 additions and 84 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/pkgci_shark_ai.yml
@@ -25,10 +25,9 @@ jobs:
strategy:
matrix:
version: [3.11]
device: ["cpu", "gfx942"]
fail-fast: false
runs-on: azure-cpubuilder-linux-scale
# runs-on: ubuntu-latest # everything else works but this throws an "out of resources" during model loading
# TODO: make a copy of this that runs on standard runners with tiny llama instead of a 8b model
runs-on: ${{ matrix.device == 'cpu' && 'azure-cpubuilder-linux-scale' || 'mi300x-3' }}
defaults:
run:
shell: bash
@@ -77,4 +76,4 @@ jobs:
- name: Run LLM Integration Tests
run: |
source ${VENV_DIR}/bin/activate
pytest -v -s app_tests/integration_tests/llm/shortfin --log-cli-level=INFO
SHORTFIN_INTEGRATION_TEST_DEVICE=${{ matrix.device }} pytest -v -s app_tests/integration_tests/llm/shortfin --log-cli-level=INFO
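
The new matrix device entry does double duty here: the &&/|| expression in runs-on selects the runner label, and the same value is forwarded to pytest through SHORTFIN_INTEGRATION_TEST_DEVICE. As a quick illustration of that mapping, a hypothetical Python helper (runner labels are taken from the workflow above; the function itself does not exist in the repository):

# Illustrative sketch only: mirrors the runs-on expression in the workflow above.
# No such helper exists in the repository.
def runner_for_device(device: str) -> str:
    """Map a matrix 'device' entry to the runner label the workflow selects."""
    return "azure-cpubuilder-linux-scale" if device == "cpu" else "mi300x-3"

assert runner_for_device("cpu") == "azure-cpubuilder-linux-scale"
assert runner_for_device("gfx942") == "mi300x-3"
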
30 changes: 30 additions & 0 deletions app_tests/integration_tests/llm/device_settings.py
@@ -15,3 +15,33 @@ class DeviceSettings:
),
server_flags=("--device=local-task",),
)

GFX942 = DeviceSettings(
compile_flags=(
"--iree-hal-target-backends=rocm",
"--iree-hip-target=gfx942",
),
server_flags=("--device=hip",),
)


def get_device_based_on_env_variable():
import os

device_name = os.environ.get("SHORTFIN_INTEGRATION_TEST_DEVICE", "cpu").lower()

table = {
"gpu": GFX942,
"amdgpu": GFX942,
"gfx942": GFX942,
"host": CPU,
"hostcpu": CPU,
"local-task": CPU,
"cpu": CPU,
}
if device_name in table:
return table[device_name]

raise ValueError(
f"os.environ['SHORTFIN_INTEGRATION_TEST_DEVICE']=={device_name} but is not recognized. Supported device names: {list(table.keys())}"
)
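
A minimal sketch of how the new lookup could be exercised with pytest's monkeypatch fixture, assuming the module is importable as app_tests.integration_tests.llm.device_settings (this test is not part of the commit):

import pytest

from app_tests.integration_tests.llm import device_settings


def test_device_lookup_from_env(monkeypatch):
    # Recognized names resolve to the matching settings object.
    monkeypatch.setenv("SHORTFIN_INTEGRATION_TEST_DEVICE", "gfx942")
    assert device_settings.get_device_based_on_env_variable() is device_settings.GFX942

    # An unset variable falls back to the "cpu" default.
    monkeypatch.delenv("SHORTFIN_INTEGRATION_TEST_DEVICE", raising=False)
    assert device_settings.get_device_based_on_env_variable() is device_settings.CPU

    # Unrecognized names raise ValueError.
    monkeypatch.setenv("SHORTFIN_INTEGRATION_TEST_DEVICE", "tpu")
    with pytest.raises(ValueError):
        device_settings.get_device_based_on_env_variable()
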
10 changes: 6 additions & 4 deletions app_tests/integration_tests/llm/model_management.py
@@ -8,7 +8,7 @@

from sharktank.utils.hf_datasets import Dataset, RemoteFile, get_dataset

from .. import device_settings
from . import device_settings

logger = logging.getLogger(__name__)

@@ -70,6 +70,7 @@ class ModelArtifacts:
mlir_path: Path
vmfb_path: Path
config_path: Path
model_config: ModelConfig # config that was originally used to generate these artifacts


class ModelStageManager:
@@ -277,6 +278,7 @@ def process_model(self, config: ModelConfig) -> ModelArtifacts:
mlir_path=mlir_path,
vmfb_path=vmfb_path,
config_path=config_path,
model_config=config,
)


@@ -287,15 +289,15 @@ def process_model(self, config: ModelConfig) -> ModelArtifacts:
model_file="open-llama-3b-v2-f16.gguf",
tokenizer_id="openlm-research/open_llama_3b_v2",
batch_sizes=(1, 4),
device_settings=device_settings.CPU,
device_settings=device_settings.get_device_based_on_env_variable(),
),
"llama3.1_8b": ModelConfig(
source=ModelSource.HUGGINGFACE,
repo_id="SanctumAI/Meta-Llama-3.1-8B-Instruct-GGUF",
model_file="meta-llama-3.1-8b-instruct.f16.gguf",
tokenizer_id="NousResearch/Meta-Llama-3.1-8B",
batch_sizes=(1, 4),
device_settings=device_settings.CPU,
device_settings=device_settings.get_device_based_on_env_variable(),
),
"azure_llama": ModelConfig(
source=ModelSource.AZURE,
@@ -307,6 +309,6 @@ def process_model(self, config: ModelConfig) -> ModelArtifacts:
model_file="azure-llama.irpa",
tokenizer_id="openlm-research/open_llama_3b_v2",
batch_sizes=(1, 4),
device_settings=device_settings.CPU,
device_settings=device_settings.get_device_based_on_env_variable(),
),
}
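
Since device_settings.get_device_based_on_env_variable() now runs when this module is imported, the environment variable must be set before the import happens. A minimal sketch, assuming the mapping above is the TEST_MODELS dict referenced in conftest.py, that ModelConfig exposes device_settings as an attribute, and that the first entry's key is "open_llama_3b" as the test file below suggests:

import os

# Must be set before importing the module, because the configs capture
# their device settings at import time.
os.environ["SHORTFIN_INTEGRATION_TEST_DEVICE"] = "gfx942"

from app_tests.integration_tests.llm.model_management import TEST_MODELS

config = TEST_MODELS["open_llama_3b"]
print(config.device_settings.compile_flags)
# Expected (per device_settings.GFX942):
# ('--iree-hal-target-backends=rocm', '--iree-hip-target=gfx942')
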
3 changes: 1 addition & 2 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -40,8 +40,7 @@ def model_artifacts(tmp_path_factory, request):
@pytest.fixture(scope="module")
def server(model_artifacts, request):
"""Starts and manages the test server."""
model_id = request.param["model"]
model_config = TEST_MODELS[model_id]
model_config = model_artifacts.model_config

server_config = ServerConfig(
artifacts=model_artifacts,
83 changes: 9 additions & 74 deletions app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
@@ -29,37 +29,17 @@ def __init__(
super().__init__(self.message)


@pytest.mark.parameterize("model_artifacts", ["llama3.1_8b", "open_llama_3b"])
@pytest.mark.parameterize(
"server",
[
{"prefix_sharing": "none"},
{"prefix_sharing": "trie"},
],
)
class TestLLMServer:
"""Test suite for LLM server functionality."""

@pytest.mark.parametrize(
"model_artifacts,server",
[
("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "none"}),
("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "trie"}),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "none"},
marks=pytest.mark.skip(
"Skipping meta llama because we haven't set up the model artifact downloading yet."
),
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "trie"},
marks=pytest.mark.skip(
"Skipping meta llama because we haven't set up the model artifact downloading yet."
),
),
],
ids=[
"open_llama_3b_none",
"open_llama_3b_trie",
"llama31_8b_none",
"llama31_8b_trie",
],
indirect=True,
)
def test_basic_generation(self, server: tuple[Any, int]) -> None:
"""Tests basic text generation capabilities.
@@ -78,44 +58,7 @@ def test_basic_generation(self, server: tuple[Any, int]) -> None:
message=f"Generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

@pytest.mark.parametrize(
"model_artifacts,server,encoded_prompt",
[
(
"open_llama_3b",
{"model": "open_llama_3b", "prefix_sharing": "none"},
"0 1 2 3 4 5 ",
),
(
"open_llama_3b",
{"model": "open_llama_3b", "prefix_sharing": "trie"},
"0 1 2 3 4 5 ",
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "none"},
"0 1 2 3 4 5 ",
marks=pytest.mark.skip(
"Skipping meta llama because we haven't set up the model artifact downloading yet."
),
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "trie"},
"0 1 2 3 4 5 ",
marks=pytest.mark.skip(
"Skipping meta llama because we haven't set up the model artifact downloading yet."
),
),
],
ids=[
"open_llama_3b_none_input_ids",
"open_llama_3b_trie_input_ids",
"llama31_8b_none_input_ids",
"llama31_8b_trie_input_ids",
],
indirect=True,
)
@pytest.mark.parametrize("encoded_prompt", ["0 1 2 3 4 5 "])
def test_basic_generation_input_ids(
self, server: tuple[Any, int], encoded_prompt
) -> None:
@@ -136,14 +79,6 @@ def test_basic_generation_input_ids(
message=f"Generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

@pytest.mark.parametrize(
"model_artifacts,server",
[
("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "none"}),
("open_llama_3b", {"model": "open_llama_3b", "prefix_sharing": "trie"}),
],
indirect=True,
)
@pytest.mark.parametrize("concurrent_requests", [2, 4, 8])
def test_concurrent_generation(
self, server: tuple[Any, int], concurrent_requests: int
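
The class-level decorators above rely on pytest's indirect parametrization: parameter values are routed to the model_artifacts and server fixtures via request.param instead of being passed to the test methods directly. A standalone illustration of that pattern (not code from this repository):

import pytest


@pytest.fixture(scope="module")
def server(request):
    # In the real suite this would start a shortfin server; here we just
    # echo the requested prefix-sharing mode.
    return f"server(prefix_sharing={request.param['prefix_sharing']})"


@pytest.mark.parametrize(
    "server",
    [{"prefix_sharing": "none"}, {"prefix_sharing": "trie"}],
    indirect=True,
)
def test_server_mode(server):
    assert server.startswith("server(prefix_sharing=")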