diff --git a/src/engine.ts b/src/engine.ts
index 3fe0cbb2..54daa73c 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -833,7 +833,7 @@ export class MLCEngine implements MLCEngineInterface {
   ): Promise<AsyncIterable<Completion> | Completion> {
     // 0. Check model loaded and preprocess inputs
     const [selectedModelId, selectedPipeline, selectedChatConfig] =
-      this.getLLMStates("ChatCompletionRequest", request.model);
+      this.getLLMStates("CompletionCreateParams", request.model);
     API.postInitAndCheckFieldsCompletion(request, selectedModelId);
     const genConfig: GenerationConfig = {
       frequency_penalty: request.frequency_penalty,
@@ -930,28 +930,10 @@ export class MLCEngine implements MLCEngineInterface {
     request: EmbeddingCreateParams,
   ): Promise<CreateEmbeddingResponse> {
     // 0. Preprocess inputs
-    const loadedModelIds: string[] = Array.from(
-      this.loadedModelIdToPipeline.keys(),
-    );
-    const selectedModelId: string = getModelIdToUse(
-      loadedModelIds,
-      request.model,
+    const [selectedModelId, selectedPipeline] = this.getEmbeddingStates(
       "EmbeddingCreateParams",
+      request.model,
     );
-    const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
-    if (!(selectedPipeline instanceof EmbeddingPipeline)) {
-      throw new IncorrectPipelineLoadedError(
-        selectedModelId,
-        "EmbeddingPipeline",
-        "EmbeddingCreateParams",
-      );
-    }
-    if (
-      findModelRecord(selectedModelId, this.appConfig).model_type !==
-      ModelType.embedding
-    ) {
-      throw new EmbeddingUnsupportedModelError(selectedModelId);
-    }
     API.postInitAndCheckFieldsEmbedding(request, selectedModelId);
 
     // 1. Call EmbeddingPipeline to get embeddings
@@ -1031,19 +1013,42 @@ export class MLCEngine implements MLCEngineInterface {
   // Needed due to possibly multiple loaded models.
   //---------------------------------------------------------------
 
+  private getLLMStates(
+    requestName: string,
+    modelId?: string | null,
+  ): [string, LLMChatPipeline, ChatConfig] {
+    return this.getModelStates(requestName, ModelType.LLM, modelId) as [
+      string,
+      LLMChatPipeline,
+      ChatConfig,
+    ];
+  }
+
+  private getEmbeddingStates(
+    requestName: string,
+    modelId?: string | null,
+  ): [string, EmbeddingPipeline, ChatConfig] {
+    return this.getModelStates(requestName, ModelType.embedding, modelId) as [
+      string,
+      EmbeddingPipeline,
+      ChatConfig,
+    ];
+  }
+
   /**
    * Return the model, its LLMChatPipeline, and ChatConfig to use. Throws error when unclear which
-   * model to load.
+   * model to load. Ensures all loadedModelIdToXXX maps contain an entry for the selected modelId.
    * @param requestName The type of request or API to load the model for. Needed for error throwing.
+   * @param modelType The type of model, determining what type of pipeline to expect.
    * @param modelId Model the user specified to load via the request. Required when multiple
    *   models are loaded.
    */
-  private getLLMStates(
+  private getModelStates(
     requestName: string,
+    modelType: ModelType,
     modelId?: string | null,
-  ): [string, LLMChatPipeline, ChatConfig] {
-    // TODO(webllm-team): when more modalities/pipelines are supported, make this method
-    // generic for different pipelines. e.g. currently embedding() does not use this method
+  ): [string, LLMChatPipeline | EmbeddingPipeline, ChatConfig] {
+    // 0. Select model based on request.model and loadedModelIds
     const loadedModelIds: string[] = Array.from(
       this.loadedModelIdToPipeline.keys(),
     );
@@ -1052,14 +1057,35 @@ export class MLCEngine implements MLCEngineInterface {
       modelId,
       requestName,
     );
+
+    // 1. Retrieve pipeline
     const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
-    if (!(selectedPipeline instanceof LLMChatPipeline)) {
-      throw new IncorrectPipelineLoadedError(
-        selectedModelId,
-        "LLMChatPipeline",
-        requestName,
-      );
+    if (modelType === ModelType.LLM) {
+      if (!(selectedPipeline instanceof LLMChatPipeline)) {
+        throw new IncorrectPipelineLoadedError(
+          selectedModelId,
+          "LLMChatPipeline",
+          requestName,
+        );
+      }
+    } else {
+      // ModelType.embedding
+      if (!(selectedPipeline instanceof EmbeddingPipeline)) {
+        throw new IncorrectPipelineLoadedError(
+          selectedModelId,
+          "EmbeddingPipeline",
+          requestName,
+        );
+      }
+      if (
+        findModelRecord(selectedModelId, this.appConfig).model_type !==
+        ModelType.embedding
+      ) {
+        throw new EmbeddingUnsupportedModelError(selectedModelId);
+      }
     }
+
+    // 2. Retrieve chat config
     const selectedChatConfig =
       this.loadedModelIdToChatConfig.get(selectedModelId);
     if (selectedChatConfig === undefined) {
diff --git a/src/error.ts b/src/error.ts
index 0c4bd568..ae09a90f 100644
--- a/src/error.ts
+++ b/src/error.ts
@@ -524,7 +524,7 @@ export class IncorrectPipelineLoadedError extends Error {
     requestName: string,
   ) {
     super(
-      `${requestName} expects model be loaded with ${expectedPipeline}. However, ` +
+      `${requestName} expects model to be loaded with ${expectedPipeline}. However, ` +
        `${selectedModelId} is not loaded with this pipeline.`,
     );
     this.name = "IncorrectPipelineLoadedError";
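Usage sketch (illustrative, not part of the patch): the snippet below shows how the shared getModelStates() path gets exercised once an LLM and an embedding model are loaded side by side. It assumes the multi-model form of CreateMLCEngine (an array of model IDs) and the OpenAI-style engine.chat.completions.create / engine.embeddings.create endpoints; the model IDs are placeholders, not pinned records.

import { CreateMLCEngine } from "@mlc-ai/web-llm";

async function main() {
  // Load one LLM and one embedding model; each gets its own entry in
  // loadedModelIdToPipeline (LLMChatPipeline vs. EmbeddingPipeline).
  const engine = await CreateMLCEngine([
    "Llama-3.1-8B-Instruct-q4f32_1-MLC", // placeholder LLM model ID
    "snowflake-arctic-embed-m-q0f32-MLC", // placeholder embedding model ID
  ]);

  // With more than one model loaded, request.model is required so that
  // getModelIdToUse() can resolve an unambiguous model; omitting it throws.
  const reply = await engine.chat.completions.create({
    model: "Llama-3.1-8B-Instruct-q4f32_1-MLC",
    messages: [{ role: "user", content: "Hello!" }],
  });

  // Pointing an embedding request at the LLM's model ID would trip the
  // EmbeddingPipeline instanceof check in getModelStates() and raise
  // IncorrectPipelineLoadedError.
  const embeddings = await engine.embeddings.create({
    model: "snowflake-arctic-embed-m-q0f32-MLC",
    input: ["Hello!"],
  });

  console.log(reply.choices[0].message.content);
  console.log(embeddings.data[0].embedding.length);
}

main();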