Skip to content

Commit

Permalink
Add a simple script to eval TTFT and TPS of an inference API
Browse files Browse the repository at this point in the history
  • Loading branch information
utensil committed Feb 25, 2025
1 parent 7a87ad9 commit 793348e
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 18 deletions.
20 changes: 2 additions & 18 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -865,22 +865,6 @@ omni *PARAMS:
#!/usr/bin/env zsh
uvx mlx-omni-server {{PARAMS}}
# not working: run https://github.com/ray-project/llmperf using uvx
# also not easily working without a config python:
# https://opencompass.readthedocs.io/en/latest/user_guides/models.html#api-based-models

perf *PARAMS:
#!/usr/bin/env zsh
# Evaluate an API-served model: first via OpenCompass (downloading its core
# dataset bundle once into /tmp/opencompass), then via the local llm_perf.py
# TTFT/TPS benchmark. Extra PARAMS are forwarded to both tools.
mkdir -p /tmp/opencompass
cd /tmp/opencompass
# if the file does not exist
if [ ! -f OpenCompassData-core-20240207.zip ]; then
wget https://github.com/open-compass/opencompass/releases/download/0.2.2.rc1/OpenCompassData-core-20240207.zip
fi
# if dir data does not exist
if [ ! -d data ]; then
unzip OpenCompassData-core-20240207.zip
fi
ls /tmp/opencompass
uvx --python 3.10 --from 'opencompass[api]' opencompass {{PARAMS}} # --models lms:deepseek-r1-distill-qwen-32b --datasets "opencompass/humaneval"
# run llm_perf.py using uv with package requests installed
uv run --with requests llm_perf.py {{PARAMS}}
104 changes: 104 additions & 0 deletions llm_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import time
import requests
from datetime import datetime
import os
import json

# Endpoint configuration comes from the standard OpenAI-style environment
# variables; any of them may be unset (None) when the environment is bare.
ENDPOINT_URL = os.environ.get("OPENAI_API_BASE")
API_KEY = os.environ.get("OPENAI_API_KEY")
MODEL = os.environ.get("OPENAI_API_MODEL")

# A short prompt that still elicits some "thinking" output from the model.
# (An earlier variant asked for two sights near Paris.)
PROMPT = (
    "Give me only one pair of words that can be combined to form a new word."
    " Finish your thinking in 50 words or less."
)

# Number of benchmark requests to run.
COUNT = 2

print(f"Endpoint URL: {ENDPOINT_URL}")
print(f"Model: {MODEL}")


def _avg(values):
    """Return the arithmetic mean of *values*, or 0.0 for an empty list."""
    return sum(values) / len(values) if values else 0.0


def benchmark_endpoint():
    """Benchmark an OpenAI-compatible chat endpoint for TTFT and TPS.

    Sends COUNT identical streaming chat-completion requests to ENDPOINT_URL
    and prints, per run and in aggregate:

    - TTFT: time from request start to the first streamed "choices" chunk
    - TPS: "tokens" per second measured from the first chunk onward

    NOTE: token counts use str.split(), i.e. whitespace-delimited words, not
    real model tokens — the TPS figure is therefore approximate.
    """
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}

    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": PROMPT}],
        "stream": True,
        "max_tokens": 10000,
    }

    # Tolerate a base URL with or without a trailing slash (the original
    # f"{ENDPOINT_URL}chat/completions" required the trailing slash).
    url = f"{ENDPOINT_URL}".rstrip("/") + "/chat/completions"

    metrics = {"ttft": [], "tps": [], "total_tokens": [], "total_time": []}

    for i in range(COUNT):
        try:
            # Start the clock before the request is sent.
            start_time = time.perf_counter()
            first_token_time = None
            token_count = 0

            print(f"请求 {i + 1}/{COUNT} 开始: {datetime.now().isoformat()}")

            # Stream the response; the context manager closes the connection.
            with requests.post(
                url,
                headers=headers,
                json=payload,
                stream=True,
                timeout=30,
            ) as response:
                # Fail fast on HTTP errors instead of mis-parsing an error body.
                response.raise_for_status()

                # Process the SSE stream line by line.
                for chunk in response.iter_lines():
                    if not chunk:
                        continue

                    decoded = chunk.decode().strip()
                    # Strip the SSE "data:" prefix. The original used
                    # lstrip("data: "), which strips a *character set* and can
                    # eat leading characters of the payload itself.
                    if decoded.startswith("data:"):
                        decoded = decoded[len("data:"):].strip()

                    if decoded == "[DONE]":
                        break

                    # Parse one JSON chunk; skip anything malformed.
                    try:
                        data = json.loads(decoded)
                        if "choices" in data:
                            # Record time-to-first-token on the first chunk.
                            if first_token_time is None:
                                first_token_time = time.perf_counter()
                                ttft = first_token_time - start_time
                                metrics["ttft"].append(ttft)

                            # Count whitespace-delimited words as "tokens".
                            delta = data["choices"][0]["delta"].get("content", "")
                            token_count += len(delta.split())

                            # Echo the streamed text without starting a new line.
                            print(delta, end="")
                    except Exception as e:
                        print(f"解析错误: {e}")
                        continue

            print("\n")

            # Compute tokens-per-second from the first token onward.
            if first_token_time:
                total_time = time.perf_counter() - first_token_time
                tps = token_count / total_time if total_time > 0 else 0
                metrics["tps"].append(tps)
                metrics["total_tokens"].append(token_count)
                metrics["total_time"].append(total_time)

        except Exception as e:
            print(f"请求失败: {str(e)}")
            continue

    # Print aggregate results. _avg() guards the all-requests-failed case,
    # where the original raised ZeroDivisionError.
    print(f"\n基准测试结果({COUNT}次请求):")
    print(f"平均耗时: {_avg(metrics['total_time']):.3f}s")
    print(f"平均TTFT: {_avg(metrics['ttft']):.3f}s")
    print(f"平均TPS: {_avg(metrics['tps']):.1f} tokens/s")
    print(f"总处理token数: {sum(metrics['total_tokens'])} tokens")


# Run the benchmark only when executed as a script, so the module can be
# imported without triggering network requests.
if __name__ == "__main__":
    benchmark_endpoint()

0 comments on commit 793348e

Please sign in to comment.