SGLang Integration + Accuracy Tests, Restructure app_tests/integration_tests #570

Merged
merged 4 commits into from Nov 19, 2024
Changes from 1 commit
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Implement sglang integration tests,
Restructure app_tests/integration_tests,
Add copyright headers to files in integration_tests that were missing them
stbaione authored and renxida committed Nov 19, 2024
commit 808511ec8cf6232b7363c0ffdc6be53e4807088c
87 changes: 87 additions & 0 deletions .github/workflows/ci-sglang-integration-tests.yml
@@ -0,0 +1,87 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: SGLang Llama Integration Tests

on:
  workflow_dispatch:
  # TODO: Remove after validating action
  pull_request:
  schedule:
    # Run periodically, every 4 hours. This is run with the
    # intent of catching regressions early, and allowing for those
    # regressions to be easily triaged to a small subset of commits.
    - cron: '0 */4 * * *'

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit). The workflow name is prepended to avoid conflicts between
  # different workflows.
  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  sglang_bench_serve:
    name: "SGLang Integration Tests"
    strategy:
      matrix:
        version: [3.11]
      fail-fast: false
    runs-on: llama-mi300x-3
    defaults:
      run:
        shell: bash
    env:
      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
    steps:
      - name: Get Current Date
        id: date
        run: echo "date=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT"

      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}

      - name: "Checkout Code"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Cache Pip Packages
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        id: cache-pip
        with:
          path: ${{ env.PIP_CACHE_DIR }}
          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}

      - name: Install pip deps
        run: |
          python -m pip install --no-compile --upgrade pip
          # Note: We install in three steps in order to satisfy requirements
          # from non-default locations first. Installing the PyTorch CPU
          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
          pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
          # Try with the latest nightly releases, not what iree-turbine pins.
          # We could also pin to a known working or stable version.
          # This should eventually stabilize. Do the best we can for now.
          pip install -f https://iree.dev/pip-release-links.html --upgrade \
            iree-base-compiler==3.0.0rc20241115 \
            iree-base-runtime==3.0.0rc20241115 \
            "numpy<2.0"

      - name: Install SGLang
        run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

      - name: Install sentence_transformers
        run: pip install sentence_transformers

      - name: Run Integration Tests
        run: pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO
2 changes: 1 addition & 1 deletion .github/workflows/ci-shark-ai.yml
@@ -72,4 +72,4 @@ jobs:
           iree-base-runtime
       - name: Run LLM Integration Tests
-        run: pytest -v app_tests/integration_tests/llm --log-cli-level=INFO
+        run: pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO
5 changes: 5 additions & 0 deletions app_tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/benchmark_tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/integration_tests/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/integration_tests/llm/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 changes: 5 additions & 0 deletions app_tests/integration_tests/llm/sglang/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
123 changes: 123 additions & 0 deletions app_tests/integration_tests/llm/sglang/conftest.py
@@ -0,0 +1,123 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import os
import pytest

from ..utils import (
    find_available_port,
    start_llm_server,
    download_with_hf_datasets,
    export_paged_llm_v1,
    compile_model,
)

pytest.importorskip("sglang")
import sglang as sgl
from sglang.lang.chat_template import get_chat_template

pytest.importorskip("sentence_transformers")
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def register_shortfin_backend(available_port):
    backend = sgl.Shortfin(
        chat_template=get_chat_template("llama-3-instruct"),
        base_url=f"http://localhost:{available_port}",
    )
    sgl.set_default_backend(backend)


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
    device_settings = request.param["device_settings"]
    tmp_dir = tmp_path_factory.mktemp("sglang_integration_tests")

    # Download model
    model_params_path = tmp_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
    download_with_hf_datasets(tmp_dir, "llama3_8B_fp16")

    # Export to mlir
    mlir_path = tmp_dir / "model.mlir"
    config_path = tmp_dir / "config.json"
    batch_sizes = [1, 4]
    export_paged_llm_v1(
        mlir_path,
        config_path,
        model_params_path,
        batch_sizes,
    )

    # Compile Model
    vmfb_path = tmp_dir / "model.vmfb"
    compile_model(
        mlir_path,
        vmfb_path,
        device_settings,
    )

    config = {
        "module_name": "module",
        "module_abi_version": 1,
        "max_seq_len": 131072,
        "attn_head_count": 8,
        "attn_head_dim": 128,
        "prefill_batch_sizes": [1, 4],
        "decode_batch_sizes": [1, 4],
        "transformer_block_count": 32,
        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
    }
    config_path = tmp_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f)

    return tmp_dir


@pytest.fixture(scope="module")
def available_port():
    return find_available_port()


@pytest.fixture(scope="module")
def start_server(request, pre_process_model, available_port):
    os.environ["ROCR_VISIBLE_DEVICES"] = "1"
    device_settings = request.param["device_settings"]

    export_dir = pre_process_model

    tokenizer_path = export_dir / "tokenizer.json"
    model_params_path = export_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
    vmfb_path = export_dir / "model.vmfb"
    config_path = export_dir / "config.json"

    logger.info("Starting server...")
    server_process = start_llm_server(
        available_port,
        tokenizer_path,
        config_path,
        vmfb_path,
        model_params_path,
        device_settings,
        timeout=30,
    )
    logger.info("Server started")

    yield server_process

    server_process.terminate()
    server_process.wait()


@pytest.fixture(scope="module")
def load_comparison_model():
    model = SentenceTransformer("all-MiniLM-L6-v2")
    return model
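
For reference, a minimal sketch (not part of this commit) of how a test might consume these fixtures: it runs one generation through SGLang's frontend API against the registered shortfin backend, then scores the answer with the comparison model. The question text, expected answer, device settings shape, and the 0.8 threshold are illustrative assumptions; the only grounded detail is that pre_process_model and start_server read request.param["device_settings"], so they are parametrized indirectly.

# Hypothetical example test; the device settings dict shape below is assumed.
import pytest
import sglang as sgl
from sentence_transformers import util

DEVICE_SETTINGS = {"device_settings": {"device": "hip"}}  # assumed shape


@sgl.function
def answer_question(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))


@pytest.mark.parametrize("pre_process_model", [DEVICE_SETTINGS], indirect=True)
@pytest.mark.parametrize("start_server", [DEVICE_SETTINGS], indirect=True)
def test_basic_accuracy(register_shortfin_backend, start_server, load_comparison_model):
    # Run a single-turn generation against the shortfin backend registered above.
    state = answer_question.run(question="What is the capital of France?")
    expected = "The capital of France is Paris."
    # Score the generated answer against the expected text with cosine similarity.
    embeddings = load_comparison_model.encode([state["answer"], expected])
    score = util.cos_sim(embeddings[0], embeddings[1]).item()
    assert score > 0.8  # illustrative threshold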