Allow returning token_ids from server + streaming changes + add device_count export option (#884)
# Changes

1. Allow token_ids to be returned from the server
2. Stream token-by-token and add `rid` for the batch scenario
3. Add a `device_block_count` option for longer prompt lengths

Really we need the `Standard API` task in #245, but this serves as a temporary implementation, so that we can handle returning token_ids and handle streaming for batch requests.

# Description

## 1

The [reference implementation](https://github.com/mlcommons/inference/tree/master/language/llama3.1-405b) for mlperf was already set up to process non-decoded tokens coming back from the LLM server. However, shortfin currently doesn't have the ability to return token_ids. This adds an option for requesting raw token_ids back from the server. For a normal request, you receive the bytes of the `json.dumps` of the list (a client-side parsing sketch is included at the end of this description):

```text
# Single
b'[21, 220, 22, 220, 23, 220, 24, 220, 605, 220, 806]'

# Batch
b'[[21, 220, 22, 220, 23, 220, 24, 220, 605, 220, 806], [845, 220, 1114, 220, 972, 220, 777, 220, 508, 220, 1691]]'
```

## 2

I also change the way that we stream. Currently, we do this:

```text
1. b'data: Hello\n\n'
2. b'data: Hello how\n\n'
3. b'data: Hello how are you\n\n'
4. b'data: Hello how are you today?\n\n'
```

Each time we push a chunk to the stream, we repeat the entire contents generated so far. This is pretty inefficient, especially for a request with a high `max_completion_tokens`. Instead, we now stream token-by-token, like this:

```text
# Text
1. b'data(rid1): Hello\n\n'
2. b'data(rid1): how\n\n'
3. b'data(rid1): are you\n\n'
4. b'data(rid1): today?\n\n'

# Tokens
1. b'data(rid1): 21\n\n'
2. b'data(rid1): 220\n\n'
3. b'data(rid1): 22\n\n'
4. b'data(rid1): 220\n\n'
```

## Why data(rid):?

Another issue that was already present: we had no way to tell which response aligned with which prompt when streaming a batch from the server. Each chunk coming back could carry a token from response 1 or a token from response 2, with no way to know which request it belonged to. As a simple patch for this, I add the `rid` to streamed chunks, so we receive (a demultiplexing sketch is included at the end of this description):

```text
b'data(rid1): I\n\n'
b"data(rid2): I'm\n\n"
b'data(rid1): am\n\n'
b'data(rid2): sure\n\n'
```

## 3

Finally, I add an option in `export_paged_llm_v1` to make the `device_block_count` configurable. This is needed for longer input prompts, since the paged KV cache must have enough blocks to cover them (a rough capacity sketch is included at the end of this description).
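## Sketches (illustration only)

As a rough illustration of the token_ids option from section 1, a client might do something like the following. The endpoint path, the payload shape, and especially the `return_token_ids` field name are assumptions for illustration, not the exact API added here; check the server's request schema for the real names.

```python
import json

import requests

# Hypothetical request against a local shortfin LLM server. The /generate path
# and the request fields below, in particular the flag asking for raw token_ids,
# are assumed names for this sketch.
response = requests.post(
    "http://localhost:8000/generate",
    json={
        "text": ["6 7 8 9 10 11", "16 17 18 19 20 21"],  # batch of two prompts
        "sampling_params": {"max_completion_tokens": 16},
        "return_token_ids": True,  # hypothetical name for the new option
    },
)

# The body is the bytes of json.dumps over the token id list(s), so json.loads
# recovers either a flat list (single prompt) or a list of lists (batch).
token_ids = json.loads(response.content)
if token_ids and isinstance(token_ids[0], list):
    for i, ids in enumerate(token_ids):
        print(f"prompt {i}: {ids}")
else:
    print(f"single prompt: {token_ids}")
```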
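For the token-by-token streaming in section 2 and the `data(rid):` prefix above, a consumer can demultiplex chunks by rid. The chunk framing (`data(<rid>): <piece>\n\n`) follows the examples above; the request fields and the regex are illustrative assumptions, not the exact wire format.

```python
import re
from collections import defaultdict

import requests

# Matches chunks shaped like b"data(rid1): Hello\n\n" from the examples above.
CHUNK_RE = re.compile(rb"data\((?P<rid>[^)]+)\):\s?(?P<piece>.*)")

# Hypothetical streaming request; field names are assumptions for this sketch.
response = requests.post(
    "http://localhost:8000/generate",
    json={
        "text": ["Tell me a joke", "Tell me another joke"],
        "sampling_params": {"max_completion_tokens": 32},
        "stream": True,
    },
    stream=True,
)

# Accumulate each prompt's output separately, keyed by the rid in the chunk prefix.
per_request = defaultdict(list)
for chunk in response.iter_lines(delimiter=b"\n\n"):
    if not chunk:
        continue
    match = CHUNK_RE.match(chunk)
    if match is None:
        continue
    rid = match.group("rid").decode()
    piece = match.group("piece").decode()
    per_request[rid].append(piece)

for rid, pieces in per_request.items():
    print(rid, "->", " ".join(pieces))
```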
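For section 3, a back-of-the-envelope way to think about `device_block_count`: the paged KV cache is divided into fixed-size blocks, so its token capacity scales with the number of device blocks, and a long prompt consumes proportionally more of them. The numbers below are made up, and the exact accounting depends on the exported model config (e.g. its block stride); this is only a sizing sketch.

```python
# Sizing sketch with assumed numbers; consult the exported model config for real values.
block_seq_stride = 16       # tokens per KV-cache block (assumed)
device_block_count = 8192   # blocks allocated on the device (assumed)

total_token_capacity = block_seq_stride * device_block_count
print(f"~{total_token_capacity} KV-cache token positions shared across in-flight requests")

# A single prompt of N tokens needs roughly ceil(N / block_seq_stride) blocks.
prompt_len = 20_000
blocks_needed = -(-prompt_len // block_seq_stride)  # ceiling division
print(f"a {prompt_len}-token prompt needs ~{blocks_needed} blocks")
```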