Allow returning token_ids from server + streaming changes + add device_count export option (#884)

# Changes
1. Allow token_ids to be returned from server
2. Stream token-by-token and add rid for batch scenario
3. Add device_block_count option for longer prompt lengths

Really we need the `Standard API` task in #245, but this serves as a
temporary implementation so that we can return token_ids and handle
streaming for batch requests.

# Description

## 1

The [reference
implementation](https://github.com/mlcommons/inference/tree/master/language/llama3.1-405b)
for MLPerf was already set up to process non-decoded tokens coming back
from the LLM server. However, shortfin currently doesn't have the
ability to return token_ids.

This adds an option for requesting raw token_ids back from the server.

For a normal (non-streaming) request, you receive the bytes of the
`json.dumps` of the token list:

```text
# Single
b'[21, 220, 22, 220, 23, 220, 24, 220, 605, 220, 806]'

# Batch
b'[[21, 220, 22, 220, 23, 220, 24, 220, 605, 220, 806], [845, 220, 1114, 220, 972, 220, 777, 220, 508, 220, 1691]]'
```
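
For reference, here is a minimal client sketch of how the token_ids path could be exercised. It is illustrative only: the server address, the `/generate` endpoint path, and the `text`/`sampling_params` payload shape are assumptions for this example; `return_input_ids` is the new flag added in this PR.

```python
import json

import requests  # assumed to be available in the client environment

BASE_URL = "http://localhost:8000"  # assumed server address; adjust as needed


def generate_token_ids(prompt: str) -> list[int]:
    payload = {
        "text": prompt,  # prompt field name is an assumption for illustration
        "sampling_params": {"max_completion_tokens": 16},
        "return_input_ids": True,  # the new flag: ask for raw token_ids back
    }
    resp = requests.post(f"{BASE_URL}/generate", json=payload)
    resp.raise_for_status()
    # The body is the bytes of json.dumps of the token list
    # (a list of lists when the request is a batch).
    return json.loads(resp.content)


print(generate_token_ids("1 2 3 4 5"))
```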

## 2

I also change the way that we stream. Currently, we do this:

```text
1. b'data: Hello\n\n'
2. b'data: Hello how\n\n'
3. b'data: Hello how are you\n\n'
4. b'data: Hello how are you today?\n\n'
```

Each time we return a response to the stream, we repeat the entire
contents of the response so far. This is pretty inefficient, especially
for requests with a high `max_completion_tokens`.

Instead we stream token-by-token, like this:

```text
# Text
1. b'data(rid1): Hello\n\n'
2. b'data(rid1): how\n\n'
3. b'data(rid1): are you\n\n'
4. b'data(rid1): today?\n\n'

# Tokens
1. b'data(rid1): 21\n\n'
2. b'data(rid1): 220\n\n'
3. b'data(rid1): 22\n\n'
4. b'data(rid1): 220\n\n'
```
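
Under the hood, each generate item now tracks a `streamed_tokens_index` and only emits the tokens produced since the last callback. Here is a simplified, standalone sketch of that delta idea (not the actual shortfin code; the chunk formatting is condensed, since the real server emits per-token chunks):

```python
# Simplified sketch of the delta-streaming idea (not the actual shortfin code):
# keep a per-request index of tokens already streamed and emit only the new ones.
class StreamState:
    def __init__(self, rid: str):
        self.rid = rid
        self.result_token_ids: list[int] = []
        self.streamed_tokens_index = 0


def stream_delta(state: StreamState) -> bytes:
    # Slice off only the tokens that have not been streamed yet.
    new_tokens = state.result_token_ids[state.streamed_tokens_index:]
    state.streamed_tokens_index += len(new_tokens)
    # One SSE-style chunk per callback, prefixed with the request id.
    return f"data({state.rid}): {' '.join(map(str, new_tokens))}\n\n".encode()


state = StreamState("rid1")
state.result_token_ids.extend([21, 220])
print(stream_delta(state))  # b'data(rid1): 21 220\n\n'
state.result_token_ids.append(22)
print(stream_delta(state))  # b'data(rid1): 22\n\n'
```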

## Why `data(rid):`?

Another issue that was already present: when streaming from the server,
we didn't have a way to tell which response aligned with which prompt.
Each chunk coming back could carry a token from response 1 or a token
from response 2, but there was no way to know which token belonged to
which request.

As a simple patch for this, I add the `rid` prefix to each streamed
chunk. So we receive:

```text
b'data(rid1):  I\n\n'
b"data(rid2):  I'm\n\n"
b'data(rid1):  am\n\n'
b'data(rid2):  sure\n\n'
```
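
On the client side, the `rid` prefix makes it possible to regroup interleaved chunks into per-request responses. Here is a rough parser sketch; the exact chunk framing and whitespace handling are assumptions:

```python
import re
from collections import defaultdict
from typing import Iterable

# Matches chunks of the form b'data(<rid>): <payload>\n\n'.
CHUNK_RE = re.compile(rb"data\((?P<rid>[^)]+)\):\s?(?P<payload>.*)")


def demux_stream(chunks: Iterable[bytes]) -> dict[str, str]:
    """Group streamed chunks by request id and rebuild each response."""
    per_request: dict[str, list[str]] = defaultdict(list)
    for chunk in chunks:
        match = CHUNK_RE.match(chunk.strip())
        if match is None:
            continue  # ignore anything that isn't a data(rid) chunk
        rid = match.group("rid").decode()
        per_request[rid].append(match.group("payload").decode())
    return {rid: "".join(parts) for rid, parts in per_request.items()}


chunks = [b"data(rid1):  I\n\n", b"data(rid2):  I'm\n\n",
          b"data(rid1):  am\n\n", b"data(rid2):  sure\n\n"]
print(demux_stream(chunks))  # {'rid1': ' I am', 'rid2': " I'm sure"}
```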

## 3

Finally, I add an option in `export_paged_llm_v1` to make the
`device_block_count` configurable. This is needed for longer input
prompts.
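
As a rough illustration of why this matters: assuming each KV-cache block holds `block_seq_stride` tokens, the device-side cache capacity is roughly `device_block_count * block_seq_stride` tokens shared across in-flight requests. That relationship, and the stride value of 32 used below, are assumptions for illustration rather than guarantees from the exporter.

```python
# Back-of-the-envelope capacity estimate. The relationship below (and the
# block_seq_stride value of 32) is an assumption for illustration, not a
# guarantee from the exporter.
def approx_cacheable_tokens(device_block_count: int, block_seq_stride: int = 32) -> int:
    return device_block_count * block_seq_stride


print(approx_cacheable_tokens(256))  # previous hard-coded value -> 8192 tokens
print(approx_cacheable_tokens(512))  # new CLI default           -> 16384 tokens
```
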
stbaione authored Jan 29, 2025
1 parent 2e27a97 commit 7c273b5
Showing 4 changed files with 48 additions and 16 deletions.
2 changes: 1 addition & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -115,7 +115,7 @@ def generate_params_json(
"paged_kv_cache": {
"attention_head_count_kv": hp.attention_head_count_kv,
"block_seq_stride": llama_config.block_seq_stride,
"device_block_count": 256, # so that this makes its way into the config file & can be edited.
"device_block_count": args.device_block_count, # so that this makes its way into the config file & can be edited.
},
}

6 changes: 6 additions & 0 deletions sharktank/sharktank/utils/cli.py
@@ -114,6 +114,12 @@ def add_model_options(parser: argparse.ArgumentParser):
type=int,
default=32,
)
parser.add_argument(
"--device-block-count",
help="Block per device for paged KV cache",
type=int,
default=512,
)


def add_quantization_options(parser: argparse.ArgumentParser):
52 changes: 38 additions & 14 deletions shortfin/python/shortfin_apps/llm/components/generate.py
@@ -6,6 +6,7 @@

import asyncio
import io
import json
import logging

import shortfin as sf
@@ -48,6 +49,8 @@ def __init__(
self.max_completion_tokens = max_completion_tokens
self.eos_token_id = eos_token_id

self.streamed_tokens_index = 0

async def run(self):
exec = InferenceExecRequest(
phase=InferencePhase.PREFILL,
@@ -133,14 +136,17 @@ async def run(self):
else:
input_batch = self.tokenize()
for index, input_tokens in enumerate(input_batch):
max_completion_tokens = (
self.gen_req.sampling_params["max_completion_tokens"]
if self.gen_req.is_single
else self.gen_req.sampling_params[index]["max_completion_tokens"]
)
gen_process = GenerateItemProcess(
self,
self.gen_req,
index,
input_tokens if is_pretokenized else input_tokens.ids,
max_completion_tokens=self.gen_req.sampling_params[
"max_completion_tokens"
],
max_completion_tokens=max_completion_tokens,
eos_token_id=self.tokenizer.eos_token_id,
)
gen_processes.append(gen_process)
@@ -155,26 +161,44 @@
else:
logging.debug("Responding to one shot batch")
out = io.BytesIO()
result_texts = self.tokenizer.decode(
[p.result_token_ids for p in gen_processes]
)
for result_text in result_texts:
out.write(b"data: ")
out.write(result_text.encode())
out.write(b"\n\n")
result_tokens = [p.result_token_ids for p in gen_processes]
if self.gen_req.return_input_ids:
if self.gen_req.is_single:
result_tokens = result_tokens[0]
out.write(bytes(json.dumps(result_tokens), "utf-8"))
else:
result_texts = self.tokenizer.decode(result_tokens)
for result_text in result_texts:
out.write(b"data: ")
out.write(result_text.encode())
out.write(b"\n\n")
self.responder.send_response(out.getvalue())
finally:
self.responder.ensure_response()

def stream_results(self, gen_process: GenerateItemProcess):
if not self.gen_req.stream:
return
(result_text,) = self.tokenizer.decode([gen_process.result_token_ids])
out = io.BytesIO()
out.write(b"data: ")
out.write(result_text.encode())
out.write(b"\n\n")
result_tokens = gen_process.result_token_ids[
gen_process.streamed_tokens_index :
]
rid = (
gen_process.gen_req.rid
if gen_process.gen_req.is_single
else gen_process.gen_req.rid[gen_process.index]
)
if not self.gen_req.return_input_ids:
(result_text,) = self.tokenizer.decode([result_tokens])
out.write(f"data({rid}): ".encode())
out.write(result_text.encode())
out.write(b"\n\n")
else:
out.write(f"data({rid}): ".encode())
out.write(str(result_tokens[0]).encode())
out.write(b"\n\n")
self.responder.stream_part(out.getvalue())
gen_process.streamed_tokens_index += len(result_tokens)

def tokenize(self) -> list[Encoding]:
gen_req = self.gen_req
4 changes: 3 additions & 1 deletion shortfin/python/shortfin_apps/llm/components/io_struct.py
@@ -11,8 +11,8 @@
sglang: Copyright 2023-2024 SGLang Team, Licensed under the Apache License, Version 2.0
"""

from typing import Dict, List, Optional, Union
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import uuid


@@ -31,6 +31,8 @@ class GenerateReqInput:
sampling_params: Union[List[Dict], Dict] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
# Whether to decode the response before returning it.
return_input_ids: bool = False
# Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None
# If return logprobs, the start location in the prompt for returning logprobs.
