Support concurrent requests to a single model instance #522
Comments
Thanks for reporting this! I'll look into fixing this, perhaps by blocking subsequent requests until the current one finishes. However, if you instantiate multiple engines, two requests can be processed concurrently. We will soon support having multiple models loaded in a single engine, and in that case the same thing applies.
Thank you. I don't strictly need this to run in parallel (though that would be nice). The concurrency bug is very nonintuitive, though, and worth fixing. I did some investigation and this is a gross start, but I think it would work if you separate out the per-request state: some sort of key to identify the specific request and keep track of its specific resources.
A model cannot handle more than one concurrent request (e.g. more than one call to `chat.completions.create()`) since we do not support continuous batching, and each request requires its own resources such as the KV cache. (Concurrent requests to *different* models in the same engine are supported.) As a result, as pointed out in #522, when users try something like the following code:

```typescript
const engine = await CreateMLCEngine("Phi-3-mini-4k-instruct-q4f16_1-MLC");

async function sendRequest() {
  const reply = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Hello!" }],
    max_tokens: 64,
  });
  console.log(reply.choices[0].message.content);
}

await Promise.all([sendRequest(), sendRequest()]);
```

the model's state and the generation result get corrupted. To resolve this, we implement `CustomLock` using Promises, maintaining a queue to ensure FCFS ordering of incoming requests to a model, so that for a single model, a request only starts after all previous requests have finished. The code above now works.

### Implementation Details

- We add `loadedModelIdToLock` to MLCEngine, maintaining a lock for each loaded model
  - Reminder: the critical section is only needed per model, since each loaded model has its own `LLMChatPipeline` / `EmbeddingPipeline`
  - `loadedModelIdToLock` is cleared in `unload()` and set in `reloadInternal()`
- We acquire the lock at the very beginning of `completion()`, `chatCompletion()`, and `embedding()`, once we know which model the current call will use
- We release the lock at the end of `embedding()`, `completion()`, and `chatCompletion()` (for non-streaming cases), and `asyncGenerate()` (for streaming cases)
- Since we also want to release the lock when errors occur, we wrap the code in a big `try` / `finally`
  - Since `asyncGenerate()` is an async generator, we add `try` / `catch` fine-grainedly, only in places that can throw errors
  - This makes the code less readable, but it is not clear whether there is a better solution
- For WebWorkerMLCEngine, no special handling is needed, since WebWorkerMLCEngineHandler calls the underlying engine's APIs (e.g. `chatCompletion()`), which will block

### Tested

- `CustomLock` implementation with a unit test (implementation follows [this blog post](https://jackpordi.com/posts/locks-in-js-because-why-not))
- The example above now works
- [get-started, get-started-web-worker] x [streaming, non-streaming] x [concurrent requests, single request]
- examples/simple-chat-ts
- examples/multi-models
- WebLLMChat (with generation interrupts, manual termination of service worker)
- Opening two WebLLMChat tabs and sending concurrent requests: the latter request waits for the previous one to finish (prior to this PR, garbage output was generated, just like in the simple example above, since the two WebLLMChat tabs share the same service worker and hence the same engine)
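For reference, a minimal sketch of what a Promise-based FCFS lock along these lines could look like. This is illustrative only and not the exact `CustomLock` from the PR; names and details are assumptions, following the queue-of-promises idea from the linked blog post:

```typescript
// Illustrative sketch; the actual CustomLock in the PR may differ in detail.
class CustomLock {
  // Resolvers for requests waiting to acquire the lock, in FCFS order.
  private waiting: Array<() => void> = [];
  private locked = false;

  async acquire(): Promise<void> {
    if (!this.locked) {
      this.locked = true;
      return;
    }
    // Queue up and wait until a previous holder releases the lock.
    await new Promise<void>((resolve) => this.waiting.push(resolve));
  }

  release(): void {
    const next = this.waiting.shift();
    if (next !== undefined) {
      // Hand the lock directly to the next waiter; it stays locked.
      next();
    } else {
      this.locked = false;
    }
  }
}

// Usage pattern analogous to what the PR describes for chatCompletion():
// acquire before touching the model, release in `finally` so errors
// cannot leave the lock held forever. `withModelLock` is a hypothetical helper.
async function withModelLock<T>(lock: CustomLock, fn: () => Promise<T>): Promise<T> {
  await lock.acquire();
  try {
    return await fn();
  } finally {
    lock.release();
  }
}
```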
Hi @LEXNY this should be fixed in #549 and reflected in npm 0.2.61. You can check out the PR description for the specifics of the problem and the solution. Your example now works, though the second request does not start until the first request is finished, as we maintain a FCFS schedule with only one request running per model. However, there can be multiple models running in an engine, hence multiple requests can be running per engine. For more, you can try examples/multi-models.
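A hedged sketch of that per-engine concurrency across two models. It assumes `CreateMLCEngine` accepts a list of model ids and that a request selects a loaded model via its `model` field, as demonstrated in examples/multi-models; the second model id is only an example, so check that example for the actual API:

```typescript
import { CreateMLCEngine } from "@mlc-ai/web-llm";

// Assumption: passing a list of model ids loads both into one engine.
const engine = await CreateMLCEngine([
  "Phi-3-mini-4k-instruct-q4f16_1-MLC",
  "Llama-3-8B-Instruct-q4f16_1-MLC", // example second model id
]);

async function ask(model: string, content: string) {
  const reply = await engine.chat.completions.create({
    model, // route the request to one of the loaded models
    messages: [{ role: "user", content }],
    max_tokens: 64,
  });
  console.log(`${model}: ${reply.choices[0].message.content}`);
}

// Requests to *different* models can run concurrently within one engine;
// requests to the *same* model are serialized FCFS by the per-model lock.
await Promise.all([
  ask("Phi-3-mini-4k-instruct-q4f16_1-MLC", "Hello!"),
  ask("Llama-3-8B-Instruct-q4f16_1-MLC", "Hello!"),
]);
```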
Powerhouse!
Closing this issue as completed. Feel free to reopen/open new ones if issues arise!
If you make multiple requests with the same engine without awaiting, you get garbage.
I would like to make multiple concurrent (ideally parallel) requests to the same engine, without loading the same model into memory multiple times.
Even with `stream: false`, the engine uses streaming internally; those streams interleave, and the engine gets confused.

Reproduction:
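The exact reproduction snippet is not preserved here; a minimal equivalent, mirroring the example quoted in the PR description above, is two un-awaited requests against the same engine:

```typescript
import { CreateMLCEngine } from "@mlc-ai/web-llm";

const engine = await CreateMLCEngine("Phi-3-mini-4k-instruct-q4f16_1-MLC");

async function sendRequest() {
  const reply = await engine.chat.completions.create({
    stream: false,
    messages: [{ role: "user", content: "Hello!" }],
    max_tokens: 64,
  });
  console.log(reply.choices[0].message.content);
}

// Without awaiting one request before starting the next, the two
// generations interleave and the output is garbage (before 0.2.61).
await Promise.all([sendRequest(), sendRequest()]);
```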
Output: