[Frontend] Support embeddings in the run_batch API (vllm-project#7132)
Co-authored-by: Simon Mo <[email protected]>
pooyadavoodi and simon-mo authored Aug 9, 2024
1 parent 74af2bb commit 249b882
Showing 5 changed files with 138 additions and 24 deletions.
41 changes: 37 additions & 4 deletions examples/offline_inference_openai.md
@@ -10,7 +10,7 @@

Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.

**NOTE:** We currently only support to `/v1/chat/completions` endpoint (embeddings and completions coming soon).
**NOTE:** We currently only support `/v1/chat/completions` and `/v1/embeddings` endpoints (completions coming soon).

## Pre-requisites

@@ -21,7 +21,7 @@
- Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions.


## Example: Running with a local file
## Example 1: Running with a local file

### Step 1: Create your batch file

@@ -54,7 +54,7 @@ python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o res
You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl`

```
$ cat ../results.jsonl
$ cat results.jsonl
{"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null}
{"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null}
```
@@ -107,7 +107,7 @@ aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl

### Step 2: Generate your presigned urls

Presigned put urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.
Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.

(The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py)
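
In rough outline, a minimal sketch of what such a script does is shown below (assuming `boto3` is installed and your AWS credentials are configured; the bucket and key names are placeholders, and the full script in this file may differ in details):

```
import boto3

s3_client = boto3.client("s3")

# Presigned GET URL so the batch runner can download the input file.
input_url = s3_client.generate_presigned_url(
    "get_object",
    Params={"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"},
    ExpiresIn=3600,
)

# Presigned PUT URL so the batch runner can upload the output file.
output_url = s3_client.generate_presigned_url(
    "put_object",
    Params={"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"},
    ExpiresIn=3600,
)

print(f"{input_url=}")
print(f"{output_url=}")
```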

@@ -170,3 +170,36 @@ Your results are now on S3. You can view them in your terminal by running
```
aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl -
```

## Example 4: Using the embeddings endpoint

### Additional prerequisites

* Ensure you are using `vllm >= 0.5.5`.
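
If you need to upgrade, something along these lines should work (assuming a pip-based install):

```
pip install -U "vllm>=0.5.5"
```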

### Step 1: Create your batch file

Add embedding requests to your batch file. The following is an example:

```
{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
```

You can even mix chat completion and embedding requests in the batch file, as long as the model you use supports both chat completion and embeddings (note that all requests must use the same model); see the sketch below.
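
As a purely illustrative sketch (with `MY_DUAL_MODEL` standing in for a model that can serve both chat completions and embeddings), a mixed batch file could look like:

```
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "MY_DUAL_MODEL", "messages": [{"role": "user", "content": "Hello world!"}], "max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "MY_DUAL_MODEL", "input": "You are a helpful assistant."}}
```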


### Step 2: Run the batch

You can run the batch using the same command as in earlier examples.
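
For instance, reusing the file names from the earlier examples and the embedding model from Step 1, the following should work:

```
python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o results.jsonl --model intfloat/e5-mistral-7b-instruct
```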


### Step 3: Check your results

You can check your results by running `cat results.jsonl`

```
$ cat results.jsonl
{"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null}
...
```
52 changes: 50 additions & 2 deletions tests/entrypoints/openai/test_run_batch.py
@@ -7,13 +7,39 @@
# ruff: noqa: E501
INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NonExistModel", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""

INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}"""

INPUT_EMBEDDING_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world!"}}
{"custom_id": "request-4", "method": "POST", "url": "/v1/embeddings", "body": {"model": "NonExistModel", "input": "Hello world!"}}"""


def test_e2e():
def test_empty_file():
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
input_file.write("")
input_file.flush()
proc = subprocess.Popen([
sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i",
input_file.name, "-o", output_file.name, "--model",
"intfloat/e5-mistral-7b-instruct"
], )
proc.communicate()
proc.wait()
assert proc.returncode == 0, f"{proc=}"

contents = output_file.read()
assert contents.strip() == ""


def test_completions():
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
@@ -35,7 +61,7 @@ def test_e2e():
BatchRequestOutput.model_validate_json(line)


def test_e2e_invalid_input():
def test_completions_invalid_input():
"""
Ensure that we fail when the input doesn't conform to the openai api.
"""
@@ -52,3 +78,25 @@ def test_e2e_invalid_input():
proc.communicate()
proc.wait()
assert proc.returncode != 0, f"{proc=}"


def test_embeddings():
with tempfile.NamedTemporaryFile(
"w") as input_file, tempfile.NamedTemporaryFile(
"r") as output_file:
input_file.write(INPUT_EMBEDDING_BATCH)
input_file.flush()
proc = subprocess.Popen([
sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i",
input_file.name, "-o", output_file.name, "--model",
"intfloat/e5-mistral-7b-instruct"
], )
proc.communicate()
proc.wait()
assert proc.returncode == 0, f"{proc=}"

contents = output_file.read()
for line in contents.strip().split("\n"):
# Ensure that the output format conforms to the openai api.
# Validation should throw if the schema is wrong.
BatchRequestOutput.model_validate_json(line)
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -672,7 +672,7 @@ class BatchRequestInput(OpenAIBaseModel):
url: str

# The parameters of the request.
body: ChatCompletionRequest
body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
Expand All @@ -683,7 +683,7 @@ class BatchResponseData(OpenAIBaseModel):
request_id: str

# The body of the response.
body: Optional[ChatCompletionResponse] = None
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
48 changes: 37 additions & 11 deletions vllm/entrypoints/openai/run_batch.py
@@ -1,18 +1,21 @@
import asyncio
from io import StringIO
from typing import Awaitable, List
from typing import Awaitable, Callable, List

import aiohttp

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
ErrorResponse)
EmbeddingResponse, ErrorResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
@@ -82,27 +85,26 @@ async def write_file(path_or_url: str, data: str) -> None:
f.write(data)


async def run_request(chat_serving: OpenAIServingChat,
async def run_request(serving_engine_func: Callable,
request: BatchRequestInput) -> BatchRequestOutput:
chat_request = request.body
chat_response = await chat_serving.create_chat_completion(chat_request)
response = await serving_engine_func(request.body)

if isinstance(chat_response, ChatCompletionResponse):
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=chat_response, request_id=f"vllm-batch-{random_uuid()}"),
body=response, request_id=f"vllm-batch-{random_uuid()}"),
error=None,
)
elif isinstance(chat_response, ErrorResponse):
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=chat_response.code,
status_code=response.code,
request_id=f"vllm-batch-{random_uuid()}"),
error=chat_response,
error=response,
)
else:
raise ValueError("Request must not be sent in stream mode")
@@ -128,6 +130,7 @@ async def main(args):
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)

# Create the openai serving objects.
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
Expand All @@ -138,12 +141,35 @@ async def main(args):
request_logger=request_logger,
chat_template=None,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
request_logger=request_logger,
)

# Submit all requests in the file to the engine "concurrently".
response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue

request = BatchRequestInput.model_validate_json(request_json)
response_futures.append(run_request(openai_serving_chat, request))

# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
response_futures.append(
run_request(openai_serving_chat.create_chat_completion,
request))
elif request.url == "/v1/embeddings":
response_futures.append(
run_request(openai_serving_embedding.create_embedding,
request))
else:
raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint.")

responses = await asyncio.gather(*response_futures)

17 changes: 12 additions & 5 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -1,7 +1,8 @@
import asyncio
import base64
import time
from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
Union, cast)

import numpy as np
from fastapi import Request
@@ -11,7 +12,8 @@
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
@@ -71,8 +73,11 @@ def __init__(
request_logger=request_logger)
self._check_embedding_mode(model_config.embedding_mode)

async def create_embedding(self, request: EmbeddingRequest,
raw_request: Request):
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None
) -> Union[ErrorResponse, EmbeddingResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
Expand Down Expand Up @@ -140,7 +145,9 @@ async def create_embedding(self, request: EmbeddingRequest,

result_generator: AsyncIterator[Tuple[
int, EmbeddingRequestOutput]] = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
*generators,
is_cancelled=raw_request.is_disconnected
if raw_request else None)

# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
