[Trivial] Generalize internal helper getModelStates (#548)
This PR reorganizes internal code to improve reuse. We rename the
internal helper method `getLLMStates()` to `getModelStates()` and add
`getEmbeddingStates()`, so the engine's methods `embedding()`,
`completion()`, and `chatCompletion()` can share the same preprocessing
helper method.
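
For context, the call sites after this change look roughly as follows. This is a sketch simplified from the diff below, not exact source; the variable names follow the diff.

// Inside completion() (and analogously chatCompletion()):
// one shared preprocessing path for LLM requests.
const [selectedModelId, selectedPipeline, selectedChatConfig] =
  this.getLLMStates("CompletionCreateParams", request.model);

// Inside embedding(): the same helper, specialized to embedding pipelines.
const [selectedModelId, selectedPipeline] = this.getEmbeddingStates(
  "EmbeddingCreateParams",
  request.model,
);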
CharlieFRuan authored Aug 14, 2024
1 parent 15277a5 commit c7a7285
Showing 2 changed files with 59 additions and 33 deletions.
90 changes: 58 additions & 32 deletions src/engine.ts
@@ -833,7 +833,7 @@ export class MLCEngine implements MLCEngineInterface {
   ): Promise<AsyncIterable<Completion> | Completion> {
     // 0. Check model loaded and preprocess inputs
     const [selectedModelId, selectedPipeline, selectedChatConfig] =
-      this.getLLMStates("ChatCompletionRequest", request.model);
+      this.getLLMStates("CompletionCreateParams", request.model);
     API.postInitAndCheckFieldsCompletion(request, selectedModelId);
     const genConfig: GenerationConfig = {
       frequency_penalty: request.frequency_penalty,
@@ -930,28 +930,10 @@ export class MLCEngine implements MLCEngineInterface {
     request: EmbeddingCreateParams,
   ): Promise<CreateEmbeddingResponse> {
     // 0. Preprocess inputs
-    const loadedModelIds: string[] = Array.from(
-      this.loadedModelIdToPipeline.keys(),
-    );
-    const selectedModelId: string = getModelIdToUse(
-      loadedModelIds,
-      request.model,
+    const [selectedModelId, selectedPipeline] = this.getEmbeddingStates(
       "EmbeddingCreateParams",
+      request.model,
     );
-    const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
-    if (!(selectedPipeline instanceof EmbeddingPipeline)) {
-      throw new IncorrectPipelineLoadedError(
-        selectedModelId,
-        "EmbeddingPipeline",
-        "EmbeddingCreateParams",
-      );
-    }
-    if (
-      findModelRecord(selectedModelId, this.appConfig).model_type !==
-      ModelType.embedding
-    ) {
-      throw new EmbeddingUnsupportedModelError(selectedModelId);
-    }
     API.postInitAndCheckFieldsEmbedding(request, selectedModelId);
 
     // 1. Call EmbeddingPipeline to get embeddings
@@ -1031,19 +1013,42 @@
   // Needed due to possibly multiple loaded models.
   //---------------------------------------------------------------
 
+  private getLLMStates(
+    requestName: string,
+    modelId?: string | null,
+  ): [string, LLMChatPipeline, ChatConfig] {
+    return this.getModelStates(requestName, ModelType.LLM, modelId) as [
+      string,
+      LLMChatPipeline,
+      ChatConfig,
+    ];
+  }
+
+  private getEmbeddingStates(
+    requestName: string,
+    modelId?: string | null,
+  ): [string, EmbeddingPipeline, ChatConfig] {
+    return this.getModelStates(requestName, ModelType.embedding, modelId) as [
+      string,
+      EmbeddingPipeline,
+      ChatConfig,
+    ];
+  }
+
   /**
    * Return the model, its LLMChatPipeline, and ChatConfig to use. Throws error when unclear which
-   * model to load.
+   * model to load. Ensure all loadedModelIdToXXX maps contain entry for the selected modelId.
    * @param requestName The type of request or API to load the model for. Needed for error throwing.
+   * @param modelType The type of model, determining what type of pipeline to expect.
    * @param modelId Model the user specified to load via the request. Required when multiple
    * models are loaded
    */
-  private getLLMStates(
+  private getModelStates(
     requestName: string,
+    modelType: ModelType,
     modelId?: string | null,
-  ): [string, LLMChatPipeline, ChatConfig] {
-    // TODO(webllm-team): when more modalities/pipelines are supported, make this method
-    // generic for different pipelines. e.g. currently embedding() does not use this method
+  ): [string, LLMChatPipeline | EmbeddingPipeline, ChatConfig] {
+    // 0. Select model based on request.model and loadedModelIds
     const loadedModelIds: string[] = Array.from(
       this.loadedModelIdToPipeline.keys(),
     );
@@ -1052,14 +1057,35 @@
       modelId,
       requestName,
     );
+
+    // 1. Retrieve pipeline
     const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
-    if (!(selectedPipeline instanceof LLMChatPipeline)) {
-      throw new IncorrectPipelineLoadedError(
-        selectedModelId,
-        "LLMChatPipeline",
-        requestName,
-      );
+    if (modelType === ModelType.LLM) {
+      if (!(selectedPipeline instanceof LLMChatPipeline)) {
+        throw new IncorrectPipelineLoadedError(
+          selectedModelId,
+          "LLMChatPipeline",
+          requestName,
+        );
+      }
+    } else {
+      // ModelType.Embedding
+      if (!(selectedPipeline instanceof EmbeddingPipeline)) {
+        throw new IncorrectPipelineLoadedError(
+          selectedModelId,
+          "EmbeddingPipeline",
+          requestName,
+        );
+      }
+      if (
+        findModelRecord(selectedModelId, this.appConfig).model_type !==
+        ModelType.embedding
+      ) {
+        throw new EmbeddingUnsupportedModelError(selectedModelId);
+      }
     }
+
+    // 2. Retrieve chat config
     const selectedChatConfig =
       this.loadedModelIdToChatConfig.get(selectedModelId);
     if (selectedChatConfig === undefined) {
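
A note on the design: getModelStates() returns the union type [string, LLMChatPipeline | EmbeddingPipeline, ChatConfig], and the thin wrappers narrow it back for callers with an `as` cast. The following self-contained TypeScript sketch mirrors that pattern; every name in it is an illustrative stand-in, not web-llm's actual API.

// Stand-in pipeline classes; illustrative only.
class LLMChatPipelineLike {
  chat(): string {
    return "chat";
  }
}
class EmbeddingPipelineLike {
  embed(): number[] {
    return [0];
  }
}
type AnyPipeline = LLMChatPipelineLike | EmbeddingPipelineLike;

// A registry standing in for the engine's map of loaded models.
const loaded = new Map<string, AnyPipeline>([
  ["llm-model", new LLMChatPipelineLike()],
  ["embed-model", new EmbeddingPipelineLike()],
]);

// Shared helper: looks up the pipeline and checks its class against the
// requested kind before returning the union type, like getModelStates().
function getStatesLike(
  kind: "llm" | "embedding",
  modelId: string,
): AnyPipeline {
  const pipeline = loaded.get(modelId);
  if (pipeline === undefined) {
    throw new Error(`${modelId} is not loaded`);
  }
  const matches =
    kind === "llm"
      ? pipeline instanceof LLMChatPipelineLike
      : pipeline instanceof EmbeddingPipelineLike;
  if (!matches) {
    throw new Error(`${modelId} is not loaded with the expected pipeline`);
  }
  return pipeline;
}

// Thin wrappers narrow the union so call sites stay strongly typed.
function getLLMStatesLike(modelId: string): LLMChatPipelineLike {
  return getStatesLike("llm", modelId) as LLMChatPipelineLike;
}
function getEmbeddingStatesLike(modelId: string): EmbeddingPipelineLike {
  return getStatesLike("embedding", modelId) as EmbeddingPipelineLike;
}

console.log(getLLMStatesLike("llm-model").chat()); // "chat"
console.log(getEmbeddingStatesLike("embed-model").embed()); // [ 0 ]

The cast is safe because the shared helper has already verified the pipeline's class with instanceof before returning.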
2 changes: 1 addition & 1 deletion src/error.ts
@@ -524,7 +524,7 @@ export class IncorrectPipelineLoadedError extends Error {
     requestName: string,
   ) {
     super(
-      `${requestName} expects model be loaded with ${expectedPipeline}. However, ` +
+      `${requestName} expects model to be loaded with ${expectedPipeline}. However, ` +
       `${selectedModelId} is not loaded with this pipeline.`,
     );
     this.name = "IncorrectPipelineLoadedError";
