-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SGLang Integration + Accuracy Tests, Restructure app_tests/integratio…
…n_tests (#570) # Description This PR implements integration tests for the Shortfin LLM Server w/ the SGLang integration. It uses llama3-8b-instruct on GPU, which is downloaded using sharktank's `hf_datasets` script. The tests server two purposes: 1. Test that the SGLang integration works properly at a functional level. 2. Test that the accuracy of the responses from the shortfin LLM server are consistent. - We have a batch of candidate questions, with expected answers - We have temperature set to `1.0`, so the responses should be deterministic. This test is intended to run every 4 hours, which allows for us to detect degradations in shortfin LLM output accuracy. If we do get a failure due to an accuracy degradation, there will only be a small set of shark-ai/iree commits that could be responsible.
- Loading branch information
Showing
13 changed files
with
575 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
name: SGLang Llama Integration Tests | ||
|
||
on: | ||
workflow_dispatch: | ||
schedule: | ||
# Run periodically, every 4 hours. This is ran periodically with the | ||
# intent of catching regressions early, and allowing for those | ||
# regressions to be easily triaged to a small subset of commits. | ||
- cron: '0 */4 * * *' | ||
|
||
concurrency: | ||
# A PR number if a pull request and otherwise the commit hash. This cancels | ||
# queued and in-progress runs for the same PR (presubmit) or commit | ||
# (postsubmit). The workflow name is prepended to avoid conflicts between | ||
# different workflows. | ||
group: ${{ github.workflow }}-${{ github.event.number || github.sha }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
sglang_bench_serve: | ||
name: "SGLang Integration Tests" | ||
strategy: | ||
matrix: | ||
version: [3.11] | ||
fail-fast: false | ||
runs-on: llama-mi300x-3 | ||
defaults: | ||
run: | ||
shell: bash | ||
env: | ||
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" | ||
steps: | ||
- name: Get Current Date | ||
id: date | ||
run: echo "::set-output name=date::$(date +'%Y-%m-%d')" | ||
|
||
- name: "Setting up Python" | ||
id: setup_python | ||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 | ||
with: | ||
python-version: ${{matrix.version}} | ||
|
||
- name: "Checkout Code" | ||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||
|
||
- name: Cache Pip Packages | ||
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 | ||
id: cache-pip | ||
with: | ||
path: ${{ env.PIP_CACHE_DIR }} | ||
key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} | ||
|
||
- name: Install pip deps | ||
run: | | ||
python -m pip install --no-compile --upgrade pip | ||
# Note: We install in three steps in order to satisfy requirements | ||
# from non default locations first. Installing the PyTorch CPU | ||
# wheels saves multiple minutes and a lot of bandwidth on runner setup. | ||
pip install --no-compile -r pytorch-cpu-requirements.txt | ||
pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ | ||
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" | ||
pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ | ||
# Use newest possible releases to be able to track commits that may | ||
# cause errors. | ||
pip install -f https://iree.dev/pip-release-links.html --upgrade \ | ||
iree-base-compiler \ | ||
iree-base-runtime \ | ||
"numpy<2.0" | ||
- name: Install SGLang | ||
run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" | ||
|
||
- name: Install sentence_transformers | ||
run: pip install sentence_transformers | ||
|
||
- name: Run Integration Tests | ||
run: pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import json | ||
import logging | ||
import os | ||
import pytest | ||
|
||
from ..utils import ( | ||
find_available_port, | ||
start_llm_server, | ||
download_with_hf_datasets, | ||
export_paged_llm_v1, | ||
compile_model, | ||
) | ||
|
||
pytest.importorskip("sglang") | ||
import sglang as sgl | ||
from sglang.lang.chat_template import get_chat_template | ||
|
||
pytest.importorskip("sentence_transformers") | ||
from sentence_transformers import SentenceTransformer | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def register_shortfin_backend(available_port): | ||
backend = sgl.Shortfin( | ||
chat_template=get_chat_template("llama-3-instruct"), | ||
base_url=f"http://localhost:{available_port}", | ||
) | ||
sgl.set_default_backend(backend) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def pre_process_model(request, tmp_path_factory): | ||
device_settings = request.param["device_settings"] | ||
tmp_dir = tmp_path_factory.mktemp("sglang_integration_tests") | ||
|
||
# Download model | ||
model_params_path = tmp_dir / "meta-llama-3.1-8b-instruct.f16.gguf" | ||
download_with_hf_datasets(tmp_dir, "llama3_8B_fp16") | ||
|
||
# Export to mlir | ||
mlir_path = tmp_dir / "model.mlir" | ||
config_path = tmp_dir / "config.json" | ||
batch_sizes = [1, 4] | ||
export_paged_llm_v1( | ||
mlir_path, | ||
config_path, | ||
model_params_path, | ||
batch_sizes, | ||
) | ||
|
||
# Compile Model | ||
vmfb_path = tmp_dir / "model.vmfb" | ||
compile_model( | ||
mlir_path, | ||
vmfb_path, | ||
device_settings, | ||
) | ||
|
||
config = { | ||
"module_name": "module", | ||
"module_abi_version": 1, | ||
"max_seq_len": 131072, | ||
"attn_head_count": 8, | ||
"attn_head_dim": 128, | ||
"prefill_batch_sizes": [1, 4], | ||
"decode_batch_sizes": [1, 4], | ||
"transformer_block_count": 32, | ||
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, | ||
} | ||
config_path = tmp_dir / "config.json" | ||
with open(config_path, "w") as f: | ||
json.dump(config, f) | ||
|
||
return tmp_dir | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def available_port(): | ||
return find_available_port() | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def start_server(request, pre_process_model, available_port): | ||
os.environ["ROCR_VISIBLE_DEVICES"] = "1" | ||
device_settings = request.param["device_settings"] | ||
|
||
export_dir = pre_process_model | ||
|
||
tokenizer_path = export_dir / "tokenizer.json" | ||
model_params_path = export_dir / "meta-llama-3.1-8b-instruct.f16.gguf" | ||
vmfb_path = export_dir / "model.vmfb" | ||
config_path = export_dir / "config.json" | ||
|
||
logger.info("Starting server...") | ||
server_process = start_llm_server( | ||
available_port, | ||
tokenizer_path, | ||
config_path, | ||
vmfb_path, | ||
model_params_path, | ||
device_settings, | ||
timeout=30, | ||
) | ||
logger.info("Server started") | ||
|
||
yield server_process | ||
|
||
server_process.terminate() | ||
server_process.wait() | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def load_comparison_model(): | ||
model = SentenceTransformer("all-MiniLM-L6-v2") | ||
return model |
Oops, something went wrong.