[API][Engine] Support loading multiple models in a single engine #542
This PR does not change the behavior of any existing supported API/workflow; it only introduces new behaviors when multiple models are loaded in a single engine. For usage, please refer to `examples/multi-models` and "User-facing changes" below.

## User-facing changes

- The `model` field is required in `ChatCompletionRequest`, `EmbeddingCreateParams`, and `CompletionCreateParams` when multiple models are loaded.
- `getMessage()`, `runtimeStatsText()`, `resetChat()`, and `forwardTokensAndSample()` take an optional `modelId` parameter (required when multiple models are loaded) to select which loaded model to act on.
- `reload()` / `unload()`:
  - `reload([model_id1, model_id2])` loads multiple models into one engine; `reload(model_id1)` behaves as before.
  - `unload()` unloads all loaded models, or interrupts a current `reload()` (as before).
- `CreateMLCEngine()`, `CreateWebWorkerEngine()`, and `CreateServiceWorkerEngine()` are all updated accordingly to allow `string[]` input for `modelId`.
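To make the new workflow concrete, here is a minimal usage sketch. The model IDs are placeholders and the second model is assumed to be an embedding model; see `examples/multi-models` for the real example.

```typescript
import { CreateMLCEngine } from "@mlc-ai/web-llm";

// Load two models into one engine. Both IDs below are placeholders.
const engine = await CreateMLCEngine([
  "chat-model-placeholder-MLC",
  "embedding-model-placeholder-MLC",
]);

// With multiple models loaded, `model` must be specified per request.
const chatReply = await engine.chat.completions.create({
  model: "chat-model-placeholder-MLC",
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(chatReply.choices[0].message.content);

const embeddings = await engine.embeddings.create({
  model: "embedding-model-placeholder-MLC",
  input: ["Hello!"],
});
console.log(embeddings.data[0].embedding.length);

// Pipeline-facing helpers also take the optional modelId.
console.log(await engine.runtimeStatsText("chat-model-placeholder-MLC"));
```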
## Internal code change

Before diving in, we first identify the following APIs in `MLCEngineInterface` that directly interact with a pipeline, as mentioned above: `getMessage()`, `forwardTokensAndSample()`, `runtimeStatsText()`, and `resetChat()`. A sketch of their updated signatures follows.
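The non-`modelId` parameters below are assumptions inferred from the existing single-model API, not copied from the diff:

```typescript
// Hedged fragment of MLCEngineInterface after this PR; only the
// pipeline-facing methods that gained an optional modelId are shown.
interface PipelineFacingAPIs {
  getMessage(modelId?: string): Promise<string>;
  runtimeStatsText(modelId?: string): Promise<string>;
  resetChat(keepStats?: boolean, modelId?: string): Promise<void>;
  forwardTokensAndSample(
    inputIds: Array<number>,
    isPrefill: boolean,
    modelId?: string,
  ): Promise<number>;
}
```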
- `support.ts`, `utils.ts`
  - `getModelIdToUse()`: given the loaded model ids and the requested `model` (which can be undefined), reconcile which model to use (see the sketch after this item). Added unit test.
  - `areChatOptionsListEqual()`: called when the service worker decides whether to skip loading. Added unit test.
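A minimal sketch of the reconciliation rule, assuming an omitted `model` is only valid when it is unambiguous; error messages are illustrative:

```typescript
// Sketch of getModelIdToUse(); the real helper lives in support.ts /
// utils.ts per this PR, with unit tests.
function getModelIdToUse(
  loadedModelIds: string[],
  requestedModel: string | null | undefined,
): string {
  if (loadedModelIds.length === 0) {
    throw new Error("No model loaded; call reload() first.");
  }
  if (requestedModel == null) {
    // `model` may be omitted only when a single model is loaded.
    if (loadedModelIds.length > 1) {
      throw new Error("Multiple models loaded; specify `model` in the request.");
    }
    return loadedModelIds[0];
  }
  if (!loadedModelIds.includes(requestedModel)) {
    throw new Error(`Model ${requestedModel} is not loaded.`);
  }
  return requestedModel;
}
```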
- `src/openai_api_protocols`
  - Remove `model` from the list of unsupported fields and update the documentation.
- `types.ts`
  - `reload()`'s `modelId` is now `string | string[]`. Update the `reload()` and `unload()` documentation (see the sketch after this item).
  - `modelId` is required when multiple models are loaded, as explained in "User-facing changes".
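In sketch form, the widened contract; the pairing of `chatOpts` with a `ChatOptions[]` is an assumption mirroring `modelId`:

```typescript
import type { ChatOptions } from "@mlc-ai/web-llm";

// Sketch of the reload()/unload() contract described above.
interface EngineLoadAPI {
  // A single id loads one model (as before); a string[] loads several.
  reload(
    modelId: string | string[],
    chatOpts?: ChatOptions | ChatOptions[],
  ): Promise<void>;
  // Unloads all loaded models, or interrupts an in-flight reload().
  unload(): Promise<void>;
}
```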
- `src/engine.ts` -- main changes happen here
  - Add `this.loadedModelIdToPipeline` and `this.loadedModelIdToChatConfig` to replace `this.currentModelId`, `this.config`, and `this.pipeline`. Update `reloadInternal()` and `unload()` to reflect this.
  - `reload()`: more preprocessing due to the possible `string[]` input. For reload abort, a single abort stops the loading of all models.
  - Add `getLLMStates()` to query the currently loaded pipeline, model id, and chatConfig.
  - Remove `this.getPipeline()`. Instead, each internal API (`prefill()`, `decode()`, `_generate()`, `asyncGenerate()`) knows which pipeline to execute on from its parameters, passed in by the high-level APIs (`embedding()`, `chatCompletion()`, `completion()`).
  - The user-facing APIs that directly interact with a pipeline resolve their target via `getLLMStates()`. These are: `getMessage()`, `forwardTokensAndSample()`, `runtimeStatsText()`, `resetChat()` (see the sketch after this group).
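A simplified sketch of the new bookkeeping in `MLCEngine`; the pipeline/config types and the `getLLMStates()` body are illustrative, not the exact implementation:

```typescript
// Stand-ins for the real internal types.
type LLMChatPipeline = { modelId: string /* ... */ };
type ChatConfig = Record<string, unknown>;

// Assumed helper from the sketch above (support.ts / utils.ts).
declare function getModelIdToUse(
  loadedModelIds: string[],
  requestedModel?: string | null,
): string;

class MLCEngineSketch {
  // Replaces this.currentModelId, this.config, this.pipeline.
  private loadedModelIdToPipeline = new Map<string, LLMChatPipeline>();
  private loadedModelIdToChatConfig = new Map<string, ChatConfig>();

  // Replaces this.getPipeline(): resolves the pipeline, model id, and
  // chatConfig for a (possibly omitted) modelId, used by getMessage(),
  // forwardTokensAndSample(), runtimeStatsText(), and resetChat().
  protected getLLMStates(modelId?: string) {
    const loadedIds = [...this.loadedModelIdToPipeline.keys()];
    const selectedModelId = getModelIdToUse(loadedIds, modelId);
    return {
      selectedModelId,
      pipeline: this.loadedModelIdToPipeline.get(selectedModelId)!,
      chatConfig: this.loadedModelIdToChatConfig.get(selectedModelId)!,
    };
  }
}
```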
- WebWorker changes: `message.ts`, `extension_service_worker.ts`, `service_worker.ts`, `web_worker.ts`
  - Add a `modelId?` parameter for the user-facing APIs that directly interact with a pipeline (`resetChat()`, etc., as mentioned earlier): update the messages in `message.ts`, their handling in `WebWorkerMLCEngineHandler`, and the sending of these messages in `WebWorkerMLCEngine`.
  - Change `Handler.modelId` and `Handler.chatOpts` to `string[]` and `ChatOptions[]`.
  - In the message params of the OpenAI-style APIs (e.g. `ChatCompletionNonStreamingParams`), also send `string[]` and `chatOptions[]`. That is, we send all the expected loaded models, despite each such API only requiring one. As a result, we reload if `Handler.modelId` and the `modelId` in the API's message params do not strictly match.
  - Update `reload()`, `ReloadParams`, and the reload handling in the Handler in `ServiceWorker`. To skip loading (see the sketch after this list):
    - Before: `this.modelId === params.modelId`; now: `areArraysEqual(this.modelId, params.modelId)`.
    - Before: `areChatOptionsEqual(...)`; now: `areChatOptionsListEqual(...)`.
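A sketch of the updated skip-loading decision; the helper shapes are assumed, and `areChatOptionsListEqual()` is the new helper from `utils.ts` above:

```typescript
import type { ChatOptions } from "@mlc-ai/web-llm";

// Assumed shape of the element-wise array comparison helper.
function areArraysEqual(a?: string[], b?: string[]): boolean {
  if (a === undefined || b === undefined) return a === b;
  return a.length === b.length && a.every((x, i) => x === b[i]);
}

// New helper per this PR; implementation elided here.
declare function areChatOptionsListEqual(
  a?: ChatOptions[],
  b?: ChatOptions[],
): boolean;

// In the ServiceWorker handler: skip reloading only if both the expected
// model list and the chat options strictly match what is already loaded.
function canSkipReload(
  handlerModelId: string[] | undefined,
  handlerChatOpts: ChatOptions[] | undefined,
  params: { modelId: string[]; chatOpts?: ChatOptions[] },
): boolean {
  return (
    areArraysEqual(handlerModelId, params.modelId) &&
    areChatOptionsListEqual(handlerChatOpts, params.chatOpts)
  );
}
```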
## Tested

- `chat.completions.create()` and `embeddings.create()` for MLCEngine, web worker, and service worker.
- `examples/multi-models`; also in web worker and service worker settings.
- `abortController()`, by removing `await` from `reload()` and `asyncInitChat()` in `simple_chat.ts` (a sketch of the pattern follows).
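The abort test follows roughly this pattern. This is a sketch, not the exact `simple_chat.ts` code; `engine` stands for any MLCEngine variant, and the model IDs are placeholders:

```typescript
import type { MLCEngineInterface } from "@mlc-ai/web-llm";

declare const engine: MLCEngineInterface;

// Kick off loading without awaiting it, as in the modified simple_chat.ts.
const loading = engine.reload(["model-A-placeholder", "model-B-placeholder"]);

// A single unload() interrupts the in-flight reload of all models.
await engine.unload();

// Handle a possible rejection from the interrupted reload.
loading.catch((err) => console.log("reload aborted:", err));
```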