rework find available port to no longer be a fixture
renxida committed Jan 6, 2025
1 parent 0c6aba6 commit 164dddb
Showing 5 changed files with 33 additions and 23 deletions.
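In summary, the rework moves port selection out of the pytest fixtures and into start_llm_server itself, which now picks a free port and returns a (server_process, port) tuple. The find_available_port helper is called inside start_llm_server but its body is not part of this diff; a minimal sketch of the usual pattern such a helper implements (bind to port 0 and let the OS assign a free port) might look like this — the real code in app_tests/integration_tests/llm/utils.py may differ:

import socket

def find_available_port() -> int:
    # Sketch only: the actual helper in app_tests/integration_tests/llm/utils.py
    # is not shown in this diff and may differ.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))   # port 0 asks the kernel for any free port
        return sock.getsockname()[1]  # report the port the kernel assigned

A helper like this has an inherent race (the port is released when the socket closes and could be claimed by another process before the server binds it), which is one argument for picking the port right next to the server launch, as this commit does, rather than in a module-scoped fixture.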
@@ -84,9 +84,7 @@ def test_shortfin_benchmark(
     model_path = tmp_dir / model_param_file_name
 
     # Start shortfin llm server
-    port = find_available_port()
-    server_process = start_llm_server(
-        port,
+    server_process, port = start_llm_server(
         tokenizer_path,
         config_path,
         vmfb_path,
12 changes: 5 additions & 7 deletions app_tests/integration_tests/llm/sglang/conftest.py

@@ -27,11 +27,10 @@
 logger = logging.getLogger(__name__)
 
 
-@pytest.fixture(scope="module")
-def register_shortfin_backend(available_port):
+def register_shortfin_backend(port):
     backend = sgl.Shortfin(
         chat_template=get_chat_template("llama-3-instruct"),
-        base_url=f"http://localhost:{available_port}",
+        base_url=f"http://localhost:{port}",
     )
     sgl.set_default_backend(backend)
 
@@ -73,7 +72,7 @@ def available_port():
 
 
 @pytest.fixture(scope="module")
-def start_server(request, pre_process_model, available_port):
+def start_server(request, pre_process_model):
     os.environ["ROCR_VISIBLE_DEVICES"] = "1"
     device_settings = request.param["device_settings"]
 
@@ -85,8 +84,7 @@ def start_server(request, pre_process_model, available_port):
     config_path = export_dir / "config.json"
 
     logger.info("Starting server...")
-    server_process = start_llm_server(
-        available_port,
+    server_process, port = start_llm_server(
         tokenizer_path,
         config_path,
         vmfb_path,
@@ -96,7 +94,7 @@
     )
     logger.info("Server started")
 
-    yield server_process
+    yield server_process, port
 
     server_process.terminate()
     server_process.wait()
30 changes: 22 additions & 8 deletions app_tests/integration_tests/llm/sglang/sglang_frontend_test.py

@@ -32,6 +32,14 @@
 ACCEPTED_THRESHOLD = 0.7
 
 
+def register_shortfin_backend(port):
+    backend = sgl.Shortfin(
+        chat_template=get_chat_template("llama-3-instruct"),
+        base_url=f"http://localhost:{port}",
+    )
+    sgl.set_default_backend(backend)
+
+
 def compute_similarity(model: SentenceTransformer, sentence_1: str, sentence_2: str):
     embeddings = model.encode([sentence_1, sentence_2])
     return util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
@@ -72,7 +80,10 @@ def tip_suggestion(s):
     ],
     indirect=True,
 )
-def test_multi_turn_qa(load_comparison_model, start_server, register_shortfin_backend):
+def test_multi_turn_qa(load_comparison_model, start_server):
+    server, port = start_server
+    register_shortfin_backend(port)
+
     model = load_comparison_model
 
     question_1 = "Name the capital city of the USA."
@@ -130,9 +141,10 @@ def test_multi_turn_qa(load_comparison_model, start_server, register_shortfin_backend):
     ],
     indirect=True,
 )
-def test_stream_multi_turn_qa(
-    load_comparison_model, start_server, register_shortfin_backend
-):
+def test_stream_multi_turn_qa(load_comparison_model, start_server):
+    server, port = start_server
+    register_shortfin_backend(port)
+
     def clean_message(message: str):
         """Remove chat tags from message before comparison.
@@ -183,9 +195,9 @@ def clean_message(message: str):
     ],
     indirect=True,
 )
-def test_batch_multi_turn_qa(
-    load_comparison_model, start_server, register_shortfin_backend
-):
+def test_batch_multi_turn_qa(load_comparison_model, start_server):
+    server, port = start_server
+    register_shortfin_backend(port)
     model = load_comparison_model
 
     question_1_1 = "Name the capital city of the USA."
@@ -287,7 +299,9 @@ def test_batch_multi_turn_qa(
     ],
     indirect=True,
 )
-def test_fork(load_comparison_model, start_server, register_shortfin_backend):
+def test_fork(load_comparison_model, start_server):
+    server, port = start_server
+    register_shortfin_backend(port)
     model = load_comparison_model
 
     logger.info("Testing fork...")
5 changes: 2 additions & 3 deletions app_tests/integration_tests/llm/shortfin/conftest.py

@@ -150,16 +150,15 @@ def llm_server(request, model_test_dir, write_config, available_port):
     parameters_path = tmp_dir / model_file
 
     # Start llm server
-    server_process = start_llm_server(
-        available_port,
+    server_process, port = start_llm_server(
         tokenizer_path,
         config_path,
         vmfb_path,
         parameters_path,
         settings,
     )
     logger.info("LLM server started!" + end_log_group())
-    yield server_process
+    yield server_process, port
     # Teardown: kill the server
    server_process.terminate()
    server_process.wait()
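Tests that consume the llm_server fixture are outside this commit, but since the fixture now yields a tuple instead of a bare process, a consumer would unpack both values. A hypothetical sketch — the test name and the probed route are illustrative assumptions, not taken from the repository:

import requests

def test_server_is_up(llm_server):  # hypothetical consumer of the reworked fixture
    server_process, port = llm_server  # fixture now yields (process, port)
    # The route below is an assumption for illustration; the real tests may
    # target /generate or another endpoint.
    response = requests.get(f"http://localhost:{port}/health")
    assert response.status_code == 200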
5 changes: 3 additions & 2 deletions app_tests/integration_tests/llm/utils.py

@@ -173,7 +173,8 @@ def start_llm_server(
     timeout=10,
     multi=False,
 ):
-    logger.info("Starting LLM server...")
+    port = find_available_port()
+    logger.info(f"Starting LLM server on port {port}...")
     if multi:
         server_process = multiprocessing.Process(
             target=subprocess.Popen(
@@ -204,7 +205,7 @@
     logger.info("Process started... waiting for server")
     # Wait for server to start
     wait_for_server(f"http://localhost:{port}", timeout)
-    return server_process
+    return server_process, port
 
 
 def start_log_group(headline):
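wait_for_server is used above but not modified by this commit. For orientation, a polling helper of that shape might look like the sketch below; the actual implementation in utils.py is not shown here, so treat every detail as an assumption:

import time
import requests

def wait_for_server(url: str, timeout: int = 10) -> None:
    # Sketch only: poll until the server accepts connections or time runs out.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            requests.get(url, timeout=1)  # any response means the socket is live
            return
        except requests.exceptions.ConnectionError:
            time.sleep(0.25)
    raise TimeoutError(f"Server at {url} did not start within {timeout} seconds")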
