[API] Deprecate engine.generate() (#541)
This PR deprecates `generate()` from all `MLCEngineInterface`. Its usage
is fully covered by `engine.chat.completions.create()` for
conversation-style generation and `engine.completions.create()` for raw text
completion.

Tested with streaming/non-streaming on MLCEngine/WebWorkerMLCEngine to
ensure other APIs are not affected.
CharlieFRuan committed Aug 12, 2024
1 parent 552ec95 commit 4e018b9
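
For reference, a minimal migration sketch (not part of this commit; the model id is an arbitrary pick from the prebuilt list):

```ts
import { CreateMLCEngine } from "@mlc-ai/web-llm";

// Any prebuilt model id works here; this one is only an example.
const engine = await CreateMLCEngine("Llama-3-8B-Instruct-q4f32_1-MLC");

// Before (deprecated):
//   const reply = await engine.generate("What is the capital of France?");

// After: conversation-style generation.
const chat = await engine.chat.completions.create({
  messages: [{ role: "user", content: "What is the capital of France?" }],
});
console.log(chat.choices[0].message.content);

// After: raw text completion.
const completion = await engine.completions.create({
  prompt: "The capital of France is",
});
console.log(completion.choices[0].text);
```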
Showing 4 changed files with 5 additions and 159 deletions.
38 changes: 2 additions & 36 deletions src/engine.ts
@@ -40,7 +40,6 @@ import * as API from "./openai_api_protocols/index";
import {
InitProgressCallback,
MLCEngineInterface,
GenerateProgressCallback,
LogitProcessor,
LogLevel,
} from "./types";
@@ -373,8 +372,6 @@ export class MLCEngine implements MLCEngineInterface {
| string
| ChatCompletionRequestNonStreaming
| CompletionCreateParamsNonStreaming,
progressCallback?: GenerateProgressCallback,
streamInterval = 1,
genConfig?: GenerationConfig,
): Promise<string> {
this.interruptSignal = false;
@@ -391,9 +388,6 @@
}
counter += 1;

[GitHub Actions lint annotation, src/engine.ts line 389: 'counter' is assigned a value but never used]
await this.decode(genConfig);
if (counter % streamInterval == 0 && progressCallback !== undefined) {
progressCallback(counter, await this.getMessage());
}
}
return await this.getMessage();
}
@@ -635,24 +629,6 @@
// 3. High-level generation APIs
//------------------------------

/**
* A legacy E2E generation API. Functionally equivalent to `chatCompletion()`.
*/
async generate(
input: string | ChatCompletionRequestNonStreaming,
progressCallback?: GenerateProgressCallback,
streamInterval = 1,
genConfig?: GenerationConfig,
): Promise<string> {
log.warn(
"WARNING: `generate()` will soon be deprecated. " +
"Please use `engine.chat.completions.create()` instead. " +
"For multi-round chatting, see `examples/multi-round-chat` on how to use " +
"`engine.chat.completions.create()` to achieve the same effect.",
);
return this._generate(input, progressCallback, streamInterval, genConfig);
}

/**
* Completes a single ChatCompletionRequest.
*
@@ -714,12 +690,7 @@
this.getPipeline().triggerStop();
outputMessage = "";
} else {
outputMessage = await this._generate(
request,
/*progressCallback=*/ undefined,
/*streamInterval=*/ 1,
/*genConfig=*/ genConfig,
);
outputMessage = await this._generate(request, genConfig);
}
let finish_reason = this.getFinishReason()!;

@@ -846,12 +817,7 @@
this.getPipeline().triggerStop();
outputMessage = "";
} else {
outputMessage = await this._generate(
request,
/*progressCallback=*/ undefined,
/*streamInterval=*/ 1,
/*genConfig=*/ genConfig,
);
outputMessage = await this._generate(request, genConfig);
}
const finish_reason = this.getFinishReason()!;

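
The removed `generate()` also took a `progressCallback`/`streamInterval` pair for intermediate results; a sketch of the equivalent with the streaming API (assuming the already-loaded `engine` from the first snippet):

```ts
// Streaming replacement for generate()'s progressCallback. Sketch only;
// `engine` is a loaded MLCEngine as in the previous snippet.
const stream = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Tell me a short story." }],
  stream: true,
});

let current = "";
for await (const chunk of stream) {
  current += chunk.choices[0]?.delta?.content ?? "";
  console.log(current); // plays the role of progressCallback(step, currentMessage)
}
console.log(await engine.getMessage()); // final message, as generate() returned
```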
28 changes: 1 addition & 27 deletions src/message.ts
@@ -17,7 +17,6 @@ import {
*/
type RequestKind =
| "reload"
| "generate"
| "runtimeStatsText"
| "interruptGenerate"
| "unload"
@@ -38,28 +37,15 @@ type RequestKind =
| "setAppConfig";

// eslint-disable-next-line @typescript-eslint/no-unused-vars
type ResponseKind =
| "return"
| "throw"
| "initProgressCallback"
| "generateProgressCallback";
type ResponseKind = "return" | "throw" | "initProgressCallback";

export interface ReloadParams {
modelId: string;
chatOpts?: ChatOptions;
}
export interface GenerateParams {
input: string | ChatCompletionRequestNonStreaming;
streamInterval?: number;
genConfig?: GenerationConfig;
}
export interface ResetChatParams {
keepStats: boolean;
}
export interface GenerateProgressCallbackParams {
step: number;
currentMessage: string;
}
export interface ForwardTokensAndSampleParams {
inputIds: Array<number>;
isPrefill: boolean;
@@ -110,9 +96,7 @@ export interface CustomRequestParams {
requestMessage: string;
}
export type MessageContent =
| GenerateProgressCallbackParams
| ReloadParams
| GenerateParams
| ResetChatParams
| ForwardTokensAndSampleParams
| ChatCompletionNonStreamingParams
@@ -160,17 +144,7 @@ type InitProgressWorkerResponse = {
content: InitProgressReport;
};

type GenerateProgressWorkerResponse = {
kind: "generateProgressCallback";
uuid: string;
content: {
step: number;
currentMessage: string;
};
};

export type WorkerResponse =
| OneTimeWorkerResponse
| InitProgressWorkerResponse
| GenerateProgressWorkerResponse
| HeartbeatWorkerResponse;
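
After this change, the handler/engine pair speaks a smaller protocol; a simplified sketch of the remaining round-trip shape (types abridged, not the real declarations in src/message.ts):

```ts
// Abridged view of the worker protocol after this commit. The real
// WorkerRequest/WorkerResponse unions carry more variants than shown.
type SketchRequest = {
  kind: "reload"; // one of the remaining RequestKind values
  uuid: string; // keys the pending promise on the engine side
  content: { modelId: string; chatOpts?: unknown }; // ReloadParams
};

type SketchResponse =
  | { kind: "initProgressCallback"; uuid: string; content: unknown } // streamed
  | { kind: "return"; uuid: string; content: unknown } // terminal: resolve
  | { kind: "throw"; uuid: string; content: unknown }; // terminal: reject

// Each request gets zero or more initProgressCallback responses followed by
// exactly one terminal "return" or "throw". With "generateProgressCallback"
// gone, init progress is the only remaining callback-style response kind
// (heartbeats aside).
```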
30 changes: 1 addition & 29 deletions src/types.ts
@@ -1,4 +1,4 @@
import { AppConfig, ChatOptions, GenerationConfig } from "./config";
import { AppConfig, ChatOptions } from "./config";
import {
ChatCompletionRequest,
ChatCompletionRequestBase,
@@ -30,14 +30,6 @@ export interface InitProgressReport {
*/
export type InitProgressCallback = (report: InitProgressReport) => void;

/**
* Callbacks used to report initialization process.
*/
export type GenerateProgressCallback = (
step: number,
currentMessage: string,
) => void;

/**
* A stateful logitProcessor used to post-process logits after forwarding the input and before
* sampling the next token. If used with `GenerationConfig.logit_bias`, logit_bias is applied after
@@ -114,26 +106,6 @@ export interface MLCEngineInterface {
*/
reload: (modelId: string, chatOpts?: ChatOptions) => Promise<void>;

/**
* Generate a response for a given input.
*
* @param input The input prompt or a non-streaming ChatCompletionRequest.
* @param progressCallback Callback that is being called to stream intermediate results.
* @param streamInterval callback interval to call progresscallback
* @param genConfig Configuration for this single generation that overrides pre-existing configs.
* @returns The final result.
*
* @note This will be deprecated soon. Please use `engine.chat.completions.create()` instead.
* For multi-round chatting, see `examples/multi-round-chat` on how to use
* `engine.chat.completions.create()` to achieve the same effect.
*/
generate: (
input: string | ChatCompletionRequestNonStreaming,
progressCallback?: GenerateProgressCallback,
streamInterval?: number,
genConfig?: GenerationConfig,
) => Promise<string>;

/**
* OpenAI-style API. Generate a chat completion response for the given conversation and
* configuration. Use `engine.chat.completions.create()` to invoke this API.
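
The removed `@note` pointed at `examples/multi-round-chat`; in that spirit, a sketch of multi-round chatting with the surviving API (not copied from the example; `engine` is a loaded engine as in the first snippet, and the type import assumes web-llm re-exports its OpenAI protocol types):

```ts
import type { ChatCompletionMessageParam } from "@mlc-ai/web-llm";

// Multi-round chat keeps its own message history and replays it each turn.
const history: ChatCompletionMessageParam[] = [
  { role: "user", content: "Name three colors." },
];
const first = await engine.chat.completions.create({ messages: history });
history.push({
  role: "assistant",
  content: first.choices[0].message.content ?? "",
});

history.push({ role: "user", content: "Which of them is the warmest?" });
const second = await engine.chat.completions.create({ messages: history });
console.log(second.choices[0].message.content);
```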
68 changes: 1 addition & 67 deletions src/web_worker.ts
@@ -1,12 +1,6 @@
import {
AppConfig,
ChatOptions,
MLCEngineConfig,
GenerationConfig,
} from "./config";
import { AppConfig, ChatOptions, MLCEngineConfig } from "./config";
import {
MLCEngineInterface,
GenerateProgressCallback,
InitProgressCallback,
InitProgressReport,
LogLevel,
@@ -31,12 +25,10 @@ import * as API from "./openai_api_protocols/index";
import {
MessageContent,
ReloadParams,
GenerateParams,
ForwardTokensAndSampleParams,
ChatCompletionNonStreamingParams,
ChatCompletionStreamInitParams,
ResetChatParams,
GenerateProgressCallbackParams,
WorkerResponse,
WorkerRequest,
CompletionNonStreamingParams,
@@ -153,31 +145,6 @@ export class WebWorkerMLCEngineHandler {
});
return;
}
case "generate": {
this.handleTask(msg.uuid, async () => {
const params = msg.content as GenerateParams;
const progressCallback = (step: number, currentMessage: string) => {
const cbMessage: WorkerResponse = {
kind: "generateProgressCallback",
uuid: msg.uuid,
content: {
step: step,
currentMessage: currentMessage,
},
};
this.postMessage(cbMessage);
};
const res = await this.engine.generate(
params.input,
progressCallback,
params.streamInterval,
params.genConfig,
);
onComplete?.(res);
return res;
});
return;
}
case "forwardTokensAndSample": {
this.handleTask(msg.uuid, async () => {
const params = msg.content as ForwardTokensAndSampleParams;
@@ -433,10 +400,6 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
chatOpts?: ChatOptions;

private initProgressCallback?: InitProgressCallback;
private generateCallbackRegistry = new Map<
string,
GenerateProgressCallback
>();
private pendingPromise = new Map<string, (msg: WorkerResponse) => void>();

constructor(worker: ChatWorker, engineConfig?: MLCEngineConfig) {
@@ -559,27 +522,6 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
return await this.getPromise<string>(msg);
}

async generate(
input: string | ChatCompletionRequestNonStreaming,
progressCallback?: GenerateProgressCallback,
streamInterval?: number,
genConfig?: GenerationConfig,
): Promise<string> {
const msg: WorkerRequest = {
kind: "generate",
uuid: crypto.randomUUID(),
content: {
input: input,
streamInterval: streamInterval,
genConfig: genConfig,
},
};
if (progressCallback !== undefined) {
this.generateCallbackRegistry.set(msg.uuid, progressCallback);
}
return await this.getPromise<string>(msg);
}

async runtimeStatsText(): Promise<string> {
const msg: WorkerRequest = {
kind: "runtimeStatsText",
@@ -791,14 +733,6 @@
}
return;
}
case "generateProgressCallback": {
const params = msg.content as GenerateProgressCallbackParams;
const cb = this.generateCallbackRegistry.get(msg.uuid);
if (cb !== undefined) {
cb(params.step, params.currentMessage);
}
return;
}
case "return": {
const cb = this.pendingPromise.get(msg.uuid);
if (cb === undefined) {
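
With the `generate` plumbing gone from the worker path, the OpenAI-style call is unchanged whether it runs in-page or through a worker; a sketch (the worker path and model id are illustrative assumptions):

```ts
import { CreateWebWorkerMLCEngine } from "@mlc-ai/web-llm";

// The worker script itself would instantiate a WebWorkerMLCEngineHandler.
const engine = await CreateWebWorkerMLCEngine(
  new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
  "Llama-3-8B-Instruct-q4f32_1-MLC", // example model id
);

// Same OpenAI-style surface as the in-page MLCEngine.
const reply = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Hello!" }],
});
console.log(reply.choices[0].message.content);
```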
