diff --git a/src/engine.ts b/src/engine.ts
index 342cc3cf..5cb92f3f 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -40,7 +40,6 @@ import * as API from "./openai_api_protocols/index";
 import {
   InitProgressCallback,
   MLCEngineInterface,
-  GenerateProgressCallback,
   LogitProcessor,
   LogLevel,
 } from "./types";
@@ -373,8 +372,6 @@ export class MLCEngine implements MLCEngineInterface {
       | string
       | ChatCompletionRequestNonStreaming
       | CompletionCreateParamsNonStreaming,
-    progressCallback?: GenerateProgressCallback,
-    streamInterval = 1,
     genConfig?: GenerationConfig,
   ): Promise<string> {
     this.interruptSignal = false;
@@ -391,9 +388,6 @@ export class MLCEngine implements MLCEngineInterface {
       }
       counter += 1;
       await this.decode(genConfig);
-      if (counter % streamInterval == 0 && progressCallback !== undefined) {
-        progressCallback(counter, await this.getMessage());
-      }
     }
     return await this.getMessage();
   }
@@ -635,24 +629,6 @@ export class MLCEngine implements MLCEngineInterface {
 
   // 3. High-level generation APIs
   //------------------------------
-  /**
-   * A legacy E2E generation API. Functionally equivalent to `chatCompletion()`.
-   */
-  async generate(
-    input: string | ChatCompletionRequestNonStreaming,
-    progressCallback?: GenerateProgressCallback,
-    streamInterval = 1,
-    genConfig?: GenerationConfig,
-  ): Promise<string> {
-    log.warn(
-      "WARNING: `generate()` will soon be deprecated. " +
-        "Please use `engine.chat.completions.create()` instead. " +
-        "For multi-round chatting, see `examples/multi-round-chat` on how to use " +
-        "`engine.chat.completions.create()` to achieve the same effect.",
-    );
-    return this._generate(input, progressCallback, streamInterval, genConfig);
-  }
-
   /**
    * Completes a single ChatCompletionRequest.
    *
@@ -714,12 +690,7 @@ export class MLCEngine implements MLCEngineInterface {
       this.getPipeline().triggerStop();
       outputMessage = "";
     } else {
-      outputMessage = await this._generate(
-        request,
-        /*progressCallback=*/ undefined,
-        /*streamInterval=*/ 1,
-        /*genConfig=*/ genConfig,
-      );
+      outputMessage = await this._generate(request, genConfig);
     }
 
     let finish_reason = this.getFinishReason()!;
@@ -846,12 +817,7 @@ export class MLCEngine implements MLCEngineInterface {
       this.getPipeline().triggerStop();
       outputMessage = "";
     } else {
-      outputMessage = await this._generate(
-        request,
-        /*progressCallback=*/ undefined,
-        /*streamInterval=*/ 1,
-        /*genConfig=*/ genConfig,
-      );
+      outputMessage = await this._generate(request, genConfig);
     }
 
     const finish_reason = this.getFinishReason()!;
diff --git a/src/message.ts b/src/message.ts
index 214ef2f4..f9cd5775 100644
--- a/src/message.ts
+++ b/src/message.ts
@@ -17,7 +17,6 @@ import {
  */
 type RequestKind =
   | "reload"
-  | "generate"
   | "runtimeStatsText"
   | "interruptGenerate"
   | "unload"
@@ -38,28 +37,15 @@ type RequestKind =
   | "setAppConfig";
 
 // eslint-disable-next-line @typescript-eslint/no-unused-vars
-type ResponseKind =
-  | "return"
-  | "throw"
-  | "initProgressCallback"
-  | "generateProgressCallback";
+type ResponseKind = "return" | "throw" | "initProgressCallback";
 
 export interface ReloadParams {
   modelId: string;
   chatOpts?: ChatOptions;
 }
-export interface GenerateParams {
-  input: string | ChatCompletionRequestNonStreaming;
-  streamInterval?: number;
-  genConfig?: GenerationConfig;
-}
 export interface ResetChatParams {
   keepStats: boolean;
 }
-export interface GenerateProgressCallbackParams {
-  step: number;
-  currentMessage: string;
-}
 export interface ForwardTokensAndSampleParams {
   inputIds: Array<number>;
   isPrefill: boolean;
@@ -110,9 +96,7 @@ export interface CustomRequestParams {
   requestMessage: string;
 }
 export type MessageContent =
-  | GenerateProgressCallbackParams
   | ReloadParams
-  | GenerateParams
   | ResetChatParams
   | ForwardTokensAndSampleParams
   | ChatCompletionNonStreamingParams
@@ -160,17 +144,7 @@ type InitProgressWorkerResponse = {
   content: InitProgressReport;
 };
 
-type GenerateProgressWorkerResponse = {
-  kind: "generateProgressCallback";
-  uuid: string;
-  content: {
-    step: number;
-    currentMessage: string;
-  };
-};
-
 export type WorkerResponse =
   | OneTimeWorkerResponse
   | InitProgressWorkerResponse
-  | GenerateProgressWorkerResponse
   | HeartbeatWorkerResponse;
diff --git a/src/types.ts b/src/types.ts
index 91cc185b..1dc15899 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -1,4 +1,4 @@
-import { AppConfig, ChatOptions, GenerationConfig } from "./config";
+import { AppConfig, ChatOptions } from "./config";
 import {
   ChatCompletionRequest,
   ChatCompletionRequestBase,
@@ -30,14 +30,6 @@ export interface InitProgressReport {
  */
 export type InitProgressCallback = (report: InitProgressReport) => void;
 
-/**
- * Callbacks used to report initialization process.
- */
-export type GenerateProgressCallback = (
-  step: number,
-  currentMessage: string,
-) => void;
-
 /**
  * A stateful logitProcessor used to post-process logits after forwarding the input and before
  * sampling the next token. If used with `GenerationConfig.logit_bias`, logit_bias is applied after
@@ -114,26 +106,6 @@ export interface MLCEngineInterface {
    */
   reload: (modelId: string, chatOpts?: ChatOptions) => Promise<void>;
 
-  /**
-   * Generate a response for a given input.
-   *
-   * @param input The input prompt or a non-streaming ChatCompletionRequest.
-   * @param progressCallback Callback that is being called to stream intermediate results.
-   * @param streamInterval callback interval to call progresscallback
-   * @param genConfig Configuration for this single generation that overrides pre-existing configs.
-   * @returns The final result.
-   *
-   * @note This will be deprecated soon. Please use `engine.chat.completions.create()` instead.
-   * For multi-round chatting, see `examples/multi-round-chat` on how to use
-   * `engine.chat.completions.create()` to achieve the same effect.
-   */
-  generate: (
-    input: string | ChatCompletionRequestNonStreaming,
-    progressCallback?: GenerateProgressCallback,
-    streamInterval?: number,
-    genConfig?: GenerationConfig,
-  ) => Promise<string>;
-
   /**
    * OpenAI-style API. Generate a chat completion response for the given conversation and
    * configuration. Use `engine.chat.completions.create()` to invoke this API.
diff --git a/src/web_worker.ts b/src/web_worker.ts
index 36626cad..5f4ec6f3 100644
--- a/src/web_worker.ts
+++ b/src/web_worker.ts
@@ -1,13 +1,7 @@
-import {
-  AppConfig,
-  ChatOptions,
-  MLCEngineConfig,
-  GenerationConfig,
-} from "./config";
+import { AppConfig, ChatOptions, MLCEngineConfig } from "./config";
 import {
   MLCEngineInterface,
-  GenerateProgressCallback,
   InitProgressCallback,
   InitProgressReport,
   LogLevel,
 } from "./types";
@@ -31,12 +25,10 @@ import * as API from "./openai_api_protocols/index";
 import {
   MessageContent,
   ReloadParams,
-  GenerateParams,
   ForwardTokensAndSampleParams,
   ChatCompletionNonStreamingParams,
   ChatCompletionStreamInitParams,
   ResetChatParams,
-  GenerateProgressCallbackParams,
   WorkerResponse,
   WorkerRequest,
   CompletionNonStreamingParams,
@@ -153,31 +145,6 @@
         });
         return;
       }
-      case "generate": {
-        this.handleTask(msg.uuid, async () => {
-          const params = msg.content as GenerateParams;
-          const progressCallback = (step: number, currentMessage: string) => {
-            const cbMessage: WorkerResponse = {
-              kind: "generateProgressCallback",
-              uuid: msg.uuid,
-              content: {
-                step: step,
-                currentMessage: currentMessage,
-              },
-            };
-            this.postMessage(cbMessage);
-          };
-          const res = await this.engine.generate(
-            params.input,
-            progressCallback,
-            params.streamInterval,
-            params.genConfig,
-          );
-          onComplete?.(res);
-          return res;
-        });
-        return;
-      }
       case "forwardTokensAndSample": {
         this.handleTask(msg.uuid, async () => {
           const params = msg.content as ForwardTokensAndSampleParams;
@@ -433,10 +400,6 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
   chatOpts?: ChatOptions;
 
   private initProgressCallback?: InitProgressCallback;
-  private generateCallbackRegistry = new Map<
-    string,
-    GenerateProgressCallback
-  >();
   private pendingPromise = new Map<string, (msg: WorkerResponse) => void>();
 
   constructor(worker: ChatWorker, engineConfig?: MLCEngineConfig) {
@@ -559,27 +522,6 @@
     return await this.getPromise(msg);
   }
 
-  async generate(
-    input: string | ChatCompletionRequestNonStreaming,
-    progressCallback?: GenerateProgressCallback,
-    streamInterval?: number,
-    genConfig?: GenerationConfig,
-  ): Promise<string> {
-    const msg: WorkerRequest = {
-      kind: "generate",
-      uuid: crypto.randomUUID(),
-      content: {
-        input: input,
-        streamInterval: streamInterval,
-        genConfig: genConfig,
-      },
-    };
-    if (progressCallback !== undefined) {
-      this.generateCallbackRegistry.set(msg.uuid, progressCallback);
-    }
-    return await this.getPromise(msg);
-  }
-
   async runtimeStatsText(): Promise<string> {
     const msg: WorkerRequest = {
       kind: "runtimeStatsText",
@@ -791,14 +733,6 @@
         }
         return;
       }
-      case "generateProgressCallback": {
-        const params = msg.content as GenerateProgressCallbackParams;
-        const cb = this.generateCallbackRegistry.get(msg.uuid);
-        if (cb !== undefined) {
-          cb(params.step, params.currentMessage);
-        }
-        return;
-      }
       case "return": {
         const cb = this.pendingPromise.get(msg.uuid);
         if (cb === undefined) {
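Migration note for downstream users: the deleted deprecation warning already names the replacement, the OpenAI-style `engine.chat.completions.create()`. Below is a minimal sketch of the equivalent streaming flow; it is illustrative rather than part of this diff. The model id, `prompt`, and `render` are placeholders, and the chunk shape follows the OpenAI-style `ChatCompletionChunk` that web-llm mirrors.

```typescript
import { CreateMLCEngine } from "@mlc-ai/web-llm";

// Placeholder model id; any entry from web-llm's prebuilt model list works.
const engine = await CreateMLCEngine("Llama-3-8B-Instruct-q4f32_1-MLC");

// Removed API: progressCallback fired every `streamInterval` decode steps:
//   const reply = await engine.generate(prompt, (step, msg) => render(msg));

// Replacement: request a streaming completion and consume the chunk iterator.
const chunks = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Tell me about web workers." }],
  stream: true,
});

let reply = "";
for await (const chunk of chunks) {
  // Each chunk carries a newly decoded delta, like one progressCallback step.
  reply += chunk.choices[0]?.delta.content ?? "";
}

// Equivalent of the old generate() return value.
console.log(await engine.getMessage());
```

Multi-round chatting needs no callback registry or `streamInterval` plumbing; keep calling `create()` on the same engine, as shown in `examples/multi-round-chat`.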