Merge fixtures: available_port into llm_server (#758)
The available_port fixture appears to rely on the server booting quickly
enough that the port it picked is still free when the server tries to bind
it. Merging the fixture into llm_server and increasing the server startup
timeout fixes the remaining CI failures caused by updating IREE to 1220.
renxida authored Jan 6, 2025
1 parent f10cc3a commit 922b1c2
Showing 7 changed files with 53 additions and 48 deletions.
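
For context, the pattern the commit moves to looks roughly like the sketch
below: the port is chosen immediately before the server process is launched,
and it is returned to the caller together with the process handle instead of
being picked earlier by a separate fixture. This is a simplified illustration
only; the helper names and the /health probe match the diff, but the argument
list and the --port flag are placeholders, not the project's actual interface.

import socket
import subprocess
import time

import requests


def find_available_port() -> int:
    # Ask the OS for a free port by binding to port 0, then release it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]


def wait_for_server(url: str, timeout: int) -> None:
    # Poll the health endpoint until the server answers or the timeout expires.
    start = time.time()
    while time.time() - start < timeout:
        try:
            requests.get(f"{url}/health")
            return
        except requests.exceptions.ConnectionError:
            time.sleep(1)
    raise TimeoutError(f"Server did not start within {timeout} seconds")


def start_llm_server(server_args: list[str], timeout: int):
    # Picking the port here, right before Popen, keeps the window in which
    # another process could grab it as small as possible. The port is returned
    # so callers no longer need a separate available_port fixture.
    port = find_available_port()
    process = subprocess.Popen(server_args + [f"--port={port}"])
    wait_for_server(f"http://localhost:{port}", timeout)
    return process, port
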
2 changes: 1 addition & 1 deletion .github/workflows/ci-shark-ai.yml
@@ -69,4 +69,4 @@ jobs:
- name: Run LLM Integration Tests
run: |
source ${VENV_DIR}/bin/activate
pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO
pytest -v -s app_tests/integration_tests/llm/shortfin --log-cli-level=INFO

@@ -83,9 +83,7 @@ def test_shortfin_benchmark(
model_path = tmp_dir / model_param_file_name

# Start shortfin llm server
port = find_available_port()
server_process = start_llm_server(
port,
server_process, port = start_llm_server(
tokenizer_path,
config_path,
vmfb_path,

17 changes: 5 additions & 12 deletions app_tests/integration_tests/llm/sglang/conftest.py
@@ -27,11 +27,10 @@
logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def register_shortfin_backend(available_port):
def register_shortfin_backend(port):
backend = sgl.Shortfin(
chat_template=get_chat_template("llama-3-instruct"),
base_url=f"http://localhost:{available_port}",
base_url=f"http://localhost:{port}",
)
sgl.set_default_backend(backend)

@@ -68,12 +67,7 @@ def pre_process_model(request, tmp_path_factory):


@pytest.fixture(scope="module")
def available_port():
return find_available_port()


@pytest.fixture(scope="module")
def start_server(request, pre_process_model, available_port):
def start_server(request, pre_process_model):
os.environ["ROCR_VISIBLE_DEVICES"] = "1"
device_settings = request.param["device_settings"]

@@ -85,8 +79,7 @@ def start_server(request, pre_process_model, available_port):
config_path = export_dir / "config.json"

logger.info("Starting server...")
server_process = start_llm_server(
available_port,
server_process, port = start_llm_server(
tokenizer_path,
config_path,
vmfb_path,
@@ -96,7 +89,7 @@ def start_server(request, pre_process_model, available_port):
)
logger.info("Server started")

yield server_process
yield server_process, port

server_process.terminate()
server_process.wait()

30 changes: 22 additions & 8 deletions app_tests/integration_tests/llm/sglang/sglang_frontend_test.py
@@ -32,6 +32,14 @@
ACCEPTED_THRESHOLD = 0.7


def register_shortfin_backend(port):
backend = sgl.Shortfin(
chat_template=get_chat_template("llama-3-instruct"),
base_url=f"http://localhost:{port}",
)
sgl.set_default_backend(backend)


def compute_similarity(model: SentenceTransformer, sentence_1: str, sentence_2: str):
embeddings = model.encode([sentence_1, sentence_2])
return util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
@@ -72,7 +80,10 @@ def tip_suggestion(s):
],
indirect=True,
)
def test_multi_turn_qa(load_comparison_model, start_server, register_shortfin_backend):
def test_multi_turn_qa(load_comparison_model, start_server):
server, port = start_server
register_shortfin_backend(port)

model = load_comparison_model

question_1 = "Name the capital city of the USA."
@@ -130,9 +141,10 @@ def test_multi_turn_qa(load_comparison_model, start_server, register_shortfin_ba
],
indirect=True,
)
def test_stream_multi_turn_qa(
load_comparison_model, start_server, register_shortfin_backend
):
def test_stream_multi_turn_qa(load_comparison_model, start_server):
server, port = start_server
register_shortfin_backend(port)

def clean_message(message: str):
"""Remove chat tags from message before comparison.
Expand Down Expand Up @@ -183,9 +195,9 @@ def clean_message(message: str):
],
indirect=True,
)
def test_batch_multi_turn_qa(
load_comparison_model, start_server, register_shortfin_backend
):
def test_batch_multi_turn_qa(load_comparison_model, start_server):
server, port = start_server
register_shortfin_backend(port)
model = load_comparison_model

question_1_1 = "Name the capital city of the USA."
Expand Down Expand Up @@ -287,7 +299,9 @@ def test_batch_multi_turn_qa(
],
indirect=True,
)
def test_fork(load_comparison_model, start_server, register_shortfin_backend):
def test_fork(load_comparison_model, start_server):
server, port = start_server
register_shortfin_backend(port)
model = load_comparison_model

logger.info("Testing fork...")

14 changes: 4 additions & 10 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -120,20 +120,14 @@ def write_config(request, model_test_dir):


@pytest.fixture(scope="module")
def available_port():
return find_available_port()


@pytest.fixture(scope="module")
def llm_server(request, model_test_dir, write_config, available_port):
def llm_server(request, model_test_dir, write_config):
"""Start the LLM server.
Args:
request (FixtureRequest): The following params are accepted:
- model_file (str): The model file to download.
- settings (dict): The settings for starting the server.
model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
available_port (int): The available port to start the server on.
Yields:
subprocess.Popen: The server process that was started.
@@ -150,16 +144,16 @@ def llm_server(request, model_test_dir, write_config, available_port):
parameters_path = tmp_dir / model_file

# Start llm server
server_process = start_llm_server(
available_port,
server_process, port = start_llm_server(
tokenizer_path,
config_path,
vmfb_path,
parameters_path,
settings,
timeout=60,
)
logger.info("LLM server started!" + end_log_group())
yield server_process
yield server_process, port
# Teardown: kill the server
server_process.terminate()
server_process.wait()

15 changes: 8 additions & 7 deletions app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
@@ -109,17 +109,18 @@ def do_generate(prompt, port, concurrent_requests=1):
],
indirect=True,
)
def test_llm_server(llm_server, available_port):
def test_llm_server(llm_server):
# Here you would typically make requests to your server
# and assert on the responses
assert llm_server.poll() is None
server, port = llm_server
assert server.poll() is None
PROMPT = "1 2 3 4 5 "
expected_output_prefix = "6 7 8"
logger.info(
"Sending HTTP Generation Request"
+ start_log_group("Sending HTTP Generation Request")
)
output = do_generate(PROMPT, available_port)[0]
output = do_generate(PROMPT, port)[0]
# log to GITHUB_STEP_SUMMARY if we are in a GitHub Action
if "GITHUB_ACTION" in os.environ:
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
@@ -170,17 +171,17 @@ def test_llm_server(llm_server, available_port):
raises=AccuracyValidationException,
reason="Concurreny issues in Shortfin batch processing",
)
def test_llm_server_concurrent(llm_server, available_port, concurrent_requests):
def test_llm_server_concurrent(llm_server, concurrent_requests):
logger.info("Testing concurrent invocations")

assert llm_server.poll() is None
server, port = llm_server
assert server.poll() is None
PROMPT = "1 2 3 4 5 "
expected_output_prefix = "6 7 8"
logger.info(
"Sending HTTP Generation Request"
+ start_log_group("Sending HTTP Generation Request")
)
outputs = do_generate(PROMPT, available_port, concurrent_requests)
outputs = do_generate(PROMPT, port, concurrent_requests)

for output in outputs:
# log to GITHUB_STEP_SUMMARY if we are in a GitHub Action

19 changes: 12 additions & 7 deletions app_tests/integration_tests/llm/utils.py
@@ -129,17 +129,22 @@ def find_available_port():
return port


def wait_for_server(url, timeout=10):
def wait_for_server(url, timeout):
logger.info(f"Waiting for server to start at {url}...")
start = time.time()
while time.time() - start < timeout:
elapsed = 0
while elapsed <= timeout:
try:
requests.get(f"{url}/health")
logger.info("Server successfully started")
return
except requests.exceptions.ConnectionError:
logger.info(
f"Server has not started yet; waited {elapsed} seconds; timeout: {timeout} seconds."
)
time.sleep(1)
raise TimeoutError(f"Server did not start within {timeout} seconds")
elapsed = time.time() - start
raise TimeoutError(f"Server did not start within {timeout} seconds at {url}")


def _start_llm_server_args(
@@ -164,16 +169,16 @@ def _start_llm_server_args(


def start_llm_server(
port,
tokenizer_path,
model_config_path,
vmfb_path,
parameters_path,
settings,
timeout=10,
timeout,
multi=False,
):
logger.info("Starting LLM server...")
port = find_available_port()
logger.info(f"Starting LLM server on port {port}...")
if multi:
server_process = multiprocessing.Process(
target=subprocess.Popen(
@@ -204,7 +209,7 @@ def start_llm_server(
logger.info("Process started... waiting for server")
# Wait for server to start
wait_for_server(f"http://localhost:{port}", timeout)
return server_process
return server_process, port


def start_log_group(headline):
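
With the updated utils.py, callers no longer pass a port and must pass a
timeout explicitly, and they receive (process, port) back. A minimal sketch of
a consuming fixture and test under the new contract is shown below; the import
path, the file paths, and the empty settings dict are placeholders, while the
real fixtures in the diff derive these values from other fixtures.

import pytest

# Placeholder import path for illustration; the real tests import these
# helpers from app_tests/integration_tests/llm/utils.py.
from utils import start_llm_server


@pytest.fixture(scope="module")
def llm_server():
    # Paths and settings below are placeholders; the diff's fixture gets them
    # from model_test_dir and write_config.
    server_process, port = start_llm_server(
        "tokenizer.json",
        "config.json",
        "model.vmfb",
        "model.irpa",
        {},          # settings
        timeout=60,  # raised startup timeout, matching the commit's change
    )
    yield server_process, port
    # Teardown: kill the server.
    server_process.terminate()
    server_process.wait()


def test_server_is_up(llm_server):
    server, port = llm_server
    assert server.poll() is None  # server process is still running
    assert isinstance(port, int)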