SGLang Integration + Accuracy Tests, Restructure app_tests/integration_tests (#570)

# Description

This PR implements integration tests for the Shortfin LLM Server with the
SGLang integration. It uses llama3-8b-instruct on GPU, which is
downloaded using sharktank's `hf_datasets` script.

The tests serve two purposes:
1. Test that the SGLang integration works properly at a functional
   level.
2. Test that the accuracy of the responses from the shortfin LLM server
   is consistent.
    - We have a batch of candidate questions with expected answers.
    - We have temperature set to `1.0`, so the responses should be
      deterministic.
 
This test is intended to run every 4 hours, which allows us to
detect degradations in shortfin LLM output accuracy. If we do get a
failure due to an accuracy degradation, only a small set of
shark-ai/iree commits could be responsible.
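
For illustration, the comparison step can be sketched as below. This is a
minimal sketch, assuming a cosine-similarity threshold; the helper name and
threshold are hypothetical, while `all-MiniLM-L6-v2` is the comparison model
the test fixtures actually load.

```python
# Hypothetical sketch of the accuracy comparison; the helper name and
# threshold are illustrative, not the exact test code.
from sentence_transformers import SentenceTransformer, util

# Comparison model loaded by the test fixtures.
model = SentenceTransformer("all-MiniLM-L6-v2")


def answers_match(actual: str, expected: str, threshold: float = 0.85) -> bool:
    # Embed both answers and compare by cosine similarity, so small
    # phrasing differences do not fail the test.
    actual_emb, expected_emb = model.encode([actual, expected])
    return util.cos_sim(actual_emb, expected_emb).item() >= threshold
```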
stbaione authored Nov 19, 2024
1 parent a7feae8 commit ac17f86
Showing 13 changed files with 575 additions and 3 deletions.
84 changes: 84 additions & 0 deletions .github/workflows/ci-sglang-integration-tests.yml
@@ -0,0 +1,84 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: SGLang Llama Integration Tests

on:
  workflow_dispatch:
  schedule:
    # Run periodically, every 4 hours, with the intent of catching
    # regressions early and allowing them to be easily triaged to a
    # small subset of commits.
    - cron: '0 */4 * * *'

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit). The workflow name is prepended to avoid conflicts between
  # different workflows.
  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  sglang_bench_serve:
    name: "SGLang Integration Tests"
    strategy:
      matrix:
        version: [3.11]
      fail-fast: false
    runs-on: llama-mi300x-3
    defaults:
      run:
        shell: bash
    env:
      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
    steps:
      - name: Get Current Date
        id: date
        run: echo "::set-output name=date::$(date +'%Y-%m-%d')"

      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}

      - name: "Checkout Code"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Cache Pip Packages
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        id: cache-pip
        with:
          path: ${{ env.PIP_CACHE_DIR }}
          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}

      - name: Install pip deps
        run: |
          python -m pip install --no-compile --upgrade pip
          # Note: We install in three steps in order to satisfy requirements
          # from non-default locations first. Installing the PyTorch CPU
          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
          pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
          # Use the newest possible releases to be able to track commits that
          # may cause errors.
          pip install -f https://iree.dev/pip-release-links.html --upgrade \
            iree-base-compiler \
            iree-base-runtime \
            "numpy<2.0"

      - name: Install SGLang
        run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

      - name: Install sentence_transformers
        run: pip install sentence_transformers

      - name: Run Integration Tests
        run: pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO
2 changes: 1 addition & 1 deletion .github/workflows/ci-shark-ai.yml
@@ -72,4 +72,4 @@ jobs:
           iree-base-runtime
       - name: Run LLM Integration Tests
-        run: pytest -v app_tests/integration_tests/llm --log-cli-level=INFO
+        run: pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO
5 changes: 5 additions & 0 deletions app_tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/benchmark_tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/integration_tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/integration_tests/llm/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/integration_tests/llm/sglang/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
123 changes: 123 additions & 0 deletions app_tests/integration_tests/llm/sglang/conftest.py
@@ -0,0 +1,123 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import os
import pytest

from ..utils import (
    find_available_port,
    start_llm_server,
    download_with_hf_datasets,
    export_paged_llm_v1,
    compile_model,
)

pytest.importorskip("sglang")
import sglang as sgl
from sglang.lang.chat_template import get_chat_template

pytest.importorskip("sentence_transformers")
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def register_shortfin_backend(available_port):
    backend = sgl.Shortfin(
        chat_template=get_chat_template("llama-3-instruct"),
        base_url=f"http://localhost:{available_port}",
    )
    sgl.set_default_backend(backend)


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
    device_settings = request.param["device_settings"]
    tmp_dir = tmp_path_factory.mktemp("sglang_integration_tests")

    # Download model
    model_params_path = tmp_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
    download_with_hf_datasets(tmp_dir, "llama3_8B_fp16")

    # Export to mlir
    mlir_path = tmp_dir / "model.mlir"
    config_path = tmp_dir / "config.json"
    batch_sizes = [1, 4]
    export_paged_llm_v1(
        mlir_path,
        config_path,
        model_params_path,
        batch_sizes,
    )

    # Compile Model
    vmfb_path = tmp_dir / "model.vmfb"
    compile_model(
        mlir_path,
        vmfb_path,
        device_settings,
    )

    # Write the server config, overwriting the one generated by export
    config = {
        "module_name": "module",
        "module_abi_version": 1,
        "max_seq_len": 131072,
        "attn_head_count": 8,
        "attn_head_dim": 128,
        "prefill_batch_sizes": [1, 4],
        "decode_batch_sizes": [1, 4],
        "transformer_block_count": 32,
        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
    }
    config_path = tmp_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f)

    return tmp_dir


@pytest.fixture(scope="module")
def available_port():
    return find_available_port()


@pytest.fixture(scope="module")
def start_server(request, pre_process_model, available_port):
    os.environ["ROCR_VISIBLE_DEVICES"] = "1"
    device_settings = request.param["device_settings"]

    export_dir = pre_process_model

    tokenizer_path = export_dir / "tokenizer.json"
    model_params_path = export_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
    vmfb_path = export_dir / "model.vmfb"
    config_path = export_dir / "config.json"

    logger.info("Starting server...")
    server_process = start_llm_server(
        available_port,
        tokenizer_path,
        config_path,
        vmfb_path,
        model_params_path,
        device_settings,
        timeout=30,
    )
    logger.info("Server started")

    yield server_process

    server_process.terminate()
    server_process.wait()


@pytest.fixture(scope="module")
def load_comparison_model():
    model = SentenceTransformer("all-MiniLM-L6-v2")
    return model
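
As a usage sketch (hypothetical test, not part of this diff): `pre_process_model`
and `start_server` read `request.param`, so a test module parametrizes them
indirectly with device settings, registers the shortfin backend, and drives the
server through the SGLang frontend. The device settings, question, and assertion
below are illustrative only.

```python
# Hypothetical test showing how the fixtures above compose.
import pytest
import sglang as sgl


@sgl.function
def answer_question(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64, temperature=1.0))


# Placeholder device settings; the real values come from the test suite.
DEVICE_SETTINGS = {"device_settings": {"device_flags": [], "device": "local-task"}}


@pytest.mark.parametrize("pre_process_model", [DEVICE_SETTINGS], indirect=True)
@pytest.mark.parametrize("start_server", [DEVICE_SETTINGS], indirect=True)
def test_question_accuracy(pre_process_model, start_server, register_shortfin_backend):
    state = answer_question.run(question="What is the capital of France?")
    actual = state["answer"]
    # Simple containment check here; the real suite compares answers
    # semantically with the sentence-transformers comparison model.
    assert "Paris" in actual
```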