Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add doRawStream #1639

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/clean-brooms-fold.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
'@ai-sdk/provider-utils': patch
'@ai-sdk/provider': patch
'@ai-sdk/openai': patch
'ai': patch
---

Add experimental `doRawStream` support for streaming raw provider responses (prototype).
59 changes: 59 additions & 0 deletions packages/core/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ import { runToolsTransformation } from './run-tools-transformation';
import { TokenUsage } from './token-usage';
import { ToToolCall } from './tool-call';
import { ToToolResult } from './tool-result';
import { LanguageModelV1CallWarning } from '@ai-sdk/provider';

/**
Generate a text and call tools for a given prompt using a language model.
@@ -110,6 +111,64 @@ The tools that the model can call. The model needs to support calling tools.
});
}

/**
Streams the raw (unparsed) response from the language model for a given prompt.

Unlike the parsed streaming path, the returned stream carries the provider's
raw response bytes (e.g. the SSE body) passed through without parsing.

@param model - The language model to use. Must implement `doRawStream`.
@param tools - The tools that the model can call. The model needs to support calling tools.
@param system - A system message that will be part of the prompt.
@param prompt - A simple text prompt. You can either use `prompt` or `messages` but not both.
@param messages - A list of messages. You can either use `prompt` or `messages` but not both.
@param maxRetries - Maximum number of retries for the API call.
@param abortSignal - An optional abort signal that can be used to cancel the call.

@returns The raw byte stream, any call warnings, and the raw response metadata (headers).

@throws Error if the model does not implement `doRawStream`.
*/
export async function streamResponse<TOOLS extends Record<string, CoreTool>>({
  model,
  tools,
  system,
  prompt,
  messages,
  maxRetries,
  abortSignal,
  ...settings
}: CallSettings &
  Prompt & {
    /**
The language model to use.
     */
    model: LanguageModel;

    /**
The tools that the model can call. The model needs to support calling tools.
     */
    tools?: TOOLS;
  }): Promise<{
  stream: ReadableStream<Uint8Array>;
  warnings: LanguageModelV1CallWarning[] | undefined;
  rawResponse:
    | {
        headers?: Record<string, string>;
      }
    | undefined;
}> {
  // Check the capability up front: a missing implementation is a permanent
  // condition and must not be retried with exponential backoff.
  const doRawStream = model.doRawStream;
  if (doRawStream == null) {
    throw new Error('The model does not support raw streaming.');
  }

  const retry = retryWithExponentialBackoff({ maxRetries });
  const validatedPrompt = getValidatedPrompt({ system, prompt, messages });

  const { stream, warnings, rawResponse } = await retry(() =>
    // `.call(model, ...)` preserves the provider's `this` binding.
    doRawStream.call(model, {
      mode: {
        type: 'regular',
        tools:
          tools == null
            ? undefined
            : Object.entries(tools).map(([name, tool]) => ({
                type: 'function',
                name,
                description: tool.description,
                parameters: convertZodToJSONSchema(tool.parameters),
              })),
      },
      ...prepareCallSettings(settings),
      inputFormat: validatedPrompt.type,
      prompt: convertToLanguageModelPrompt(validatedPrompt),
      abortSignal,
    }),
  );

  return { stream, warnings, rawResponse };
}

export type TextStreamPart<TOOLS extends Record<string, CoreTool>> =
| {
type: 'text-delta';
40 changes: 40 additions & 0 deletions packages/openai/src/openai-chat-language-model.ts
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import {
} from '@ai-sdk/provider';
import {
ParseResult,
createEventSourcePassThroughHandler,
createEventSourceResponseHandler,
createJsonResponseHandler,
generateId,
@@ -188,6 +189,45 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
};
}

/**
Streams the raw SSE byte stream from the OpenAI chat completions endpoint,
passing the response body through without parsing it into stream parts.
*/
async doRawStream(
  options: Parameters<LanguageModelV1['doStream']>[0],
): Promise<
  Omit<Awaited<ReturnType<LanguageModelV1['doStream']>>, 'stream'> & {
    stream: ReadableStream<Uint8Array>;
  }
> {
  const args = this.getArgs(options);

  const requestBody = {
    ...args,
    stream: true,

    // only include stream_options when in strict compatibility mode:
    stream_options:
      this.config.compatibility === 'strict'
        ? { include_usage: true }
        : undefined,
  };

  const response = await postJsonToApi({
    url: `${this.config.baseURL}/chat/completions`,
    headers: this.config.headers(),
    body: requestBody,
    failedResponseHandler: openaiFailedResponseHandler,
    successfulResponseHandler: createEventSourcePassThroughHandler(
      openaiChatChunkSchema,
    ),
    abortSignal: options.abortSignal,
  });

  // Split the prepared args into prompt vs. settings for observability.
  const { messages: rawPrompt, ...rawSettings } = args;

  return {
    stream: response.value,
    rawCall: { rawPrompt, rawSettings },
    rawResponse: { headers: response.responseHeaders },
    warnings: [],
  };
}

async doStream(
options: Parameters<LanguageModelV1['doStream']>[0],
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
15 changes: 15 additions & 0 deletions packages/provider-utils/src/response-handler.ts
Original file line number Diff line number Diff line change
@@ -118,6 +118,21 @@ export const createEventSourceResponseHandler =
};
};

/**
Creates a response handler that passes the response body through unparsed
as a raw byte stream.

Note: the schema parameter is accepted only for call-site parity with
`createEventSourceResponseHandler` — the chunks are NOT validated against it.

@param _chunkSchema - Unused; present so call sites mirror
`createEventSourceResponseHandler`.

@returns A handler that yields the response headers and the raw
`ReadableStream<Uint8Array>` body.

@throws EmptyResponseBodyError if the response has no body.
*/
export const createEventSourcePassThroughHandler =
  <T>(
    _chunkSchema: ZodSchema<T>,
  ): ResponseHandler<ReadableStream<Uint8Array>> =>
  async ({ response }: { response: Response }) => {
    const responseHeaders = extractResponseHeaders(response);

    if (response.body == null) {
      throw new EmptyResponseBodyError({});
    }

    return {
      responseHeaders,
      value: response.body,
    };
  };

export const createJsonResponseHandler =
<T>(responseSchema: ZodSchema<T>): ResponseHandler<T> =>
async ({ response, url, requestBodyValues }) => {
33 changes: 33 additions & 0 deletions packages/provider/src/language-model/v1/language-model-v1.ts
Original file line number Diff line number Diff line change
@@ -145,6 +145,39 @@ Response headers.

warnings?: LanguageModelV1CallWarning[];
}>;

/**
Streams the raw, unparsed response bytes (e.g. the provider's SSE stream)
instead of parsed stream parts. Optional — providers that do not support
raw streaming can omit this method.
*/
doRawStream?: (options: LanguageModelV1CallOptions) => PromiseLike<{
/**
The raw response body as a byte stream, passed through without parsing.
*/
stream: ReadableStream<Uint8Array>;

/**
Raw prompt and setting information for observability provider integration.
*/
rawCall: {
/**
Raw prompt after expansion and conversion to the format that the
provider uses to send the information to their API.
*/
rawPrompt: unknown;

/**
Raw settings that are used for the API call. Includes provider-specific
settings.
*/
rawSettings: Record<string, unknown>;
};

/**
Optional raw response data.
*/
rawResponse?: {
/**
Response headers.
*/
headers?: Record<string, string>;
};

/**
Warnings generated during the call, if any.
*/
warnings?: LanguageModelV1CallWarning[];
}>;
};

export type LanguageModelV1StreamPart =