[Trivial] Generalize internal helper getModelStates #548

Merged 1 commit on Aug 14, 2024
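
In short: chatCompletion() and completion() already shared an LLM-only helper, getLLMStates, while embedding() duplicated the same lookup and validation inline. This PR moves the shared logic into a single private getModelStates helper, dispatched on a ModelType enum, and keeps thin, strongly typed wrappers per pipeline kind. Below is a minimal, self-contained sketch of that shape; the types here are simplified stand-ins, not the actual WebLLM sources.

    // Sketch only: the pipeline classes and ChatConfig are reduced to stubs.
    enum ModelType {
      LLM,
      embedding,
    }

    class LLMChatPipeline {
      decode(): void {}
    }
    class EmbeddingPipeline {
      embed(): void {}
    }
    interface ChatConfig {
      temperature?: number;
    }

    class Engine {
      private loadedModelIdToPipeline = new Map<
        string,
        LLMChatPipeline | EmbeddingPipeline
      >();
      private loadedModelIdToChatConfig = new Map<string, ChatConfig>();

      // Shared helper: resolve model id, pipeline, and config in one place.
      private getModelStates(
        requestName: string,
        modelType: ModelType,
        modelId?: string | null,
      ): [string, LLMChatPipeline | EmbeddingPipeline, ChatConfig] {
        const loadedModelIds = Array.from(this.loadedModelIdToPipeline.keys());
        // The real code raises an error when the choice is ambiguous;
        // this sketch just falls back to the first loaded model.
        const selectedModelId = modelId ?? loadedModelIds[0];
        const pipeline = this.loadedModelIdToPipeline.get(selectedModelId);
        const config = this.loadedModelIdToChatConfig.get(selectedModelId);
        if (pipeline === undefined || config === undefined) {
          throw new Error(`${requestName}: model ${selectedModelId} is not loaded`);
        }
        const isLLM = pipeline instanceof LLMChatPipeline;
        if (isLLM !== (modelType === ModelType.LLM)) {
          throw new Error(`${requestName}: wrong pipeline loaded for ${selectedModelId}`);
        }
        return [selectedModelId, pipeline, config];
      }

      // Thin wrappers restore the precise pipeline type for callers.
      private getLLMStates(
        requestName: string,
        modelId?: string | null,
      ): [string, LLMChatPipeline, ChatConfig] {
        return this.getModelStates(requestName, ModelType.LLM, modelId) as [
          string,
          LLMChatPipeline,
          ChatConfig,
        ];
      }

      private getEmbeddingStates(
        requestName: string,
        modelId?: string | null,
      ): [string, EmbeddingPipeline, ChatConfig] {
        return this.getModelStates(requestName, ModelType.embedding, modelId) as [
          string,
          EmbeddingPipeline,
          ChatConfig,
        ];
      }
    }

The wrappers exist so that call sites keep precise tuple types; see the design note after the engine.ts diff for why the casts are needed.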
90 changes: 58 additions & 32 deletions src/engine.ts
@@ -427,7 +427,7 @@
        pipeline.triggerStop();
        break;
      }
      counter += 1;
      await this.decode(pipeline, genConfig);
    }
    return pipeline.getMessage();

[GitHub Actions lint warning on line 430 of src/engine.ts: 'counter' is assigned a value but never used]
@@ -833,7 +833,7 @@
  ): Promise<AsyncIterable<Completion> | Completion> {
    // 0. Check model loaded and preprocess inputs
    const [selectedModelId, selectedPipeline, selectedChatConfig] =
-      this.getLLMStates("ChatCompletionRequest", request.model);
+      this.getLLMStates("CompletionCreateParams", request.model);
    API.postInitAndCheckFieldsCompletion(request, selectedModelId);
    const genConfig: GenerationConfig = {
      frequency_penalty: request.frequency_penalty,
@@ -930,28 +930,10 @@
    request: EmbeddingCreateParams,
  ): Promise<CreateEmbeddingResponse> {
    // 0. Preprocess inputs
-    const loadedModelIds: string[] = Array.from(
-      this.loadedModelIdToPipeline.keys(),
-    );
-    const selectedModelId: string = getModelIdToUse(
-      loadedModelIds,
-      request.model,
-      "EmbeddingCreateParams",
-    );
-    const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
-    if (!(selectedPipeline instanceof EmbeddingPipeline)) {
-      throw new IncorrectPipelineLoadedError(
-        selectedModelId,
-        "EmbeddingPipeline",
-        "EmbeddingCreateParams",
-      );
-    }
-    if (
-      findModelRecord(selectedModelId, this.appConfig).model_type !==
-      ModelType.embedding
-    ) {
-      throw new EmbeddingUnsupportedModelError(selectedModelId);
-    }
+    const [selectedModelId, selectedPipeline] = this.getEmbeddingStates(
+      "EmbeddingCreateParams",
+      request.model,
+    );
    API.postInitAndCheckFieldsEmbedding(request, selectedModelId);

    // 1. Call EmbeddingPipeline to get embeddings
@@ -1031,19 +1013,42 @@
  // Needed due to possibly multiple loaded models.
  //---------------------------------------------------------------

+  private getLLMStates(
+    requestName: string,
+    modelId?: string | null,
+  ): [string, LLMChatPipeline, ChatConfig] {
+    return this.getModelStates(requestName, ModelType.LLM, modelId) as [
+      string,
+      LLMChatPipeline,
+      ChatConfig,
+    ];
+  }
+
+  private getEmbeddingStates(
+    requestName: string,
+    modelId?: string | null,
+  ): [string, EmbeddingPipeline, ChatConfig] {
+    return this.getModelStates(requestName, ModelType.embedding, modelId) as [
+      string,
+      EmbeddingPipeline,
+      ChatConfig,
+    ];
+  }
+
  /**
   * Return the model, its LLMChatPipeline, and ChatConfig to use. Throws error when unclear which
-   * model to load.
+   * model to load. Ensures all loadedModelIdToXXX maps contain an entry for the selected modelId.
   * @param requestName The type of request or API to load the model for. Needed for error throwing.
+   * @param modelType The type of model, determining what type of pipeline to expect.
   * @param modelId Model the user specified to load via the request. Required when multiple
   *     models are loaded
   */
-  private getLLMStates(
+  private getModelStates(
    requestName: string,
+    modelType: ModelType,
    modelId?: string | null,
-  ): [string, LLMChatPipeline, ChatConfig] {
-    // TODO(webllm-team): when more modalities/pipelines are supported, make this method
-    // generic for different pipelines. e.g. currently embedding() does not use this method
+  ): [string, LLMChatPipeline | EmbeddingPipeline, ChatConfig] {
+    // 0. Select model based on request.model and loadedModelIds
    const loadedModelIds: string[] = Array.from(
      this.loadedModelIdToPipeline.keys(),
    );
@@ -1052,14 +1057,35 @@
      modelId,
      requestName,
    );
+
+    // 1. Retrieve pipeline
    const selectedPipeline = this.loadedModelIdToPipeline.get(selectedModelId);
-    if (!(selectedPipeline instanceof LLMChatPipeline)) {
-      throw new IncorrectPipelineLoadedError(
-        selectedModelId,
-        "LLMChatPipeline",
-        requestName,
-      );
+    if (modelType === ModelType.LLM) {
+      if (!(selectedPipeline instanceof LLMChatPipeline)) {
+        throw new IncorrectPipelineLoadedError(
+          selectedModelId,
+          "LLMChatPipeline",
+          requestName,
+        );
+      }
+    } else {
+      // ModelType.Embedding
+      if (!(selectedPipeline instanceof EmbeddingPipeline)) {
+        throw new IncorrectPipelineLoadedError(
+          selectedModelId,
+          "EmbeddingPipeline",
+          requestName,
+        );
+      }
+      if (
+        findModelRecord(selectedModelId, this.appConfig).model_type !==
+        ModelType.embedding
+      ) {
+        throw new EmbeddingUnsupportedModelError(selectedModelId);
+      }
    }
+
+    // 2. Retrieve chat config
    const selectedChatConfig =
      this.loadedModelIdToChatConfig.get(selectedModelId);
    if (selectedChatConfig === undefined) {
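
A design note on the helper above: the wrappers need the "as" casts because TypeScript does not narrow a union-typed return tuple based on the runtime value of the modelType argument. A cast-free alternative, sketched here purely as an illustration (it is not what this PR does), is to give the shared helper overload signatures keyed on the enum member literal types:

    enum ModelType {
      LLM,
      embedding,
    }
    class LLMChatPipeline {
      decode(): void {}
    }
    class EmbeddingPipeline {
      embed(): void {}
    }

    class Engine {
      private pipelines = new Map<string, LLMChatPipeline | EmbeddingPipeline>([
        ["my-llm", new LLMChatPipeline()],
        ["my-embedder", new EmbeddingPipeline()],
      ]);

      // Each enum member is its own literal type, so it can select an overload.
      getModelStates(modelId: string, modelType: ModelType.LLM): [string, LLMChatPipeline];
      getModelStates(modelId: string, modelType: ModelType.embedding): [string, EmbeddingPipeline];
      getModelStates(
        modelId: string,
        modelType: ModelType,
      ): [string, LLMChatPipeline | EmbeddingPipeline] {
        const pipeline = this.pipelines.get(modelId);
        const matches =
          modelType === ModelType.LLM
            ? pipeline instanceof LLMChatPipeline
            : pipeline instanceof EmbeddingPipeline;
        if (pipeline === undefined || !matches) {
          throw new Error(`wrong or missing pipeline for ${modelId}`);
        }
        return [modelId, pipeline];
      }
    }

    // The call site gets the precise type with no cast:
    const engine = new Engine();
    const [, llm] = engine.getModelStates("my-llm", ModelType.LLM); // llm: LLMChatPipeline
    llm.decode();

The cast-based version in the PR is simpler and keeps a single signature; the overload version trades some boilerplate for compiler-checked call sites.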
2 changes: 1 addition & 1 deletion src/error.ts
@@ -524,7 +524,7 @@ export class IncorrectPipelineLoadedError extends Error {
    requestName: string,
  ) {
    super(
-      `${requestName} expects model be loaded with ${expectedPipeline}. However, ` +
+      `${requestName} expects model to be loaded with ${expectedPipeline}. However, ` +
        `${selectedModelId} is not loaded with this pipeline.`,
    );
    this.name = "IncorrectPipelineLoadedError";
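
For reference, a self-contained reconstruction of the corrected error class; the constructor parameter order is inferred from the call sites in the engine.ts diff above, and the model id in the demo is only a stand-in.

    class IncorrectPipelineLoadedError extends Error {
      constructor(
        selectedModelId: string,
        expectedPipeline: string,
        requestName: string,
      ) {
        super(
          `${requestName} expects model to be loaded with ${expectedPipeline}. However, ` +
            `${selectedModelId} is not loaded with this pipeline.`,
        );
        this.name = "IncorrectPipelineLoadedError";
      }
    }

    // Produces, e.g.:
    // "EmbeddingCreateParams expects model to be loaded with EmbeddingPipeline.
    //  However, some-llm-model is not loaded with this pipeline."
    console.log(
      new IncorrectPipelineLoadedError(
        "some-llm-model",
        "EmbeddingPipeline",
        "EmbeddingCreateParams",
      ).message,
    );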