Skip to content

Commit

Permalink
adjust
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Feb 3, 2025
1 parent 9da69dd commit f03ac6d
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 42 deletions.
74 changes: 56 additions & 18 deletions packages/react/src/use-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ import type {
CreateMessage,
JSONValue,
Message,
ReasoningUIPart,
TextUIPart,
ToolInvocationUIPart,
UIMessage,
UseChatOptions,
} from '@ai-sdk/ui-utils';
import {
Expand All @@ -21,7 +25,7 @@ export type { CreateMessage, Message, UseChatOptions };

export type UseChatHelpers = {
/** Current messages in the chat */
messages: Message[];
messages: UIMessage[];
/** The error object of the API request */
error: undefined | Error;
/**
Expand Down Expand Up @@ -163,14 +167,19 @@ By default, it's set to 1, which means that only a single LLM call is made.
const [initialMessagesFallback] = useState([]);

// Store the chat state in SWR, using the chatId as the key to share states.
const { data: messages, mutate } = useSWR<Message[]>(
const { data: messages, mutate } = useSWR<UIMessage[]>(
[chatKey, 'messages'],
null,
{ fallbackData: initialMessages ?? initialMessagesFallback },
{
fallbackData:
initialMessages != null
? fillInUIMessageParts(initialMessages)
: initialMessagesFallback,
},
);

// Keep the latest messages in a ref.
const messagesRef = useRef<Message[]>(messages || []);
const messagesRef = useRef<UIMessage[]>(messages || []);
useEffect(() => {
messagesRef.current = messages || [];
}, [messages]);
Expand Down Expand Up @@ -215,9 +224,11 @@ By default, it's set to 1, which means that only a single LLM call is made.

const triggerRequest = useCallback(
async (chatRequest: ChatRequest) => {
const messageCount = chatRequest.messages.length;
const chatMessages = fillInUIMessageParts(chatRequest.messages);

const messageCount = chatMessages.length;
const maxStep = extractMaxToolInvocationStep(
chatRequest.messages[chatRequest.messages.length - 1]?.toolInvocations,
chatMessages[chatMessages.length - 1]?.toolInvocations,
);

try {
Expand All @@ -235,11 +246,11 @@ By default, it's set to 1, which means that only a single LLM call is made.

// Do an optimistic update to the chat state to show the updated messages immediately:
const previousMessages = messagesRef.current;
throttledMutate(chatRequest.messages, false);
throttledMutate(chatMessages, false);

const constructedMessagesPayload = sendExtraMessageFields
? chatRequest.messages
: chatRequest.messages.map(
? chatMessages
: chatMessages.map(
({
role,
content,
Expand All @@ -265,7 +276,7 @@ By default, it's set to 1, which means that only a single LLM call is made.
api,
body: experimental_prepareRequestBody?.({
id: chatId,
messages: chatRequest.messages,
messages: chatMessages,
requestData: chatRequest.data,
requestBody: chatRequest.body,
}) ?? {
Expand All @@ -292,11 +303,8 @@ By default, it's set to 1, which means that only a single LLM call is made.
throttledMutate(
[
...(replaceLastMessage
? chatRequest.messages.slice(
0,
chatRequest.messages.length - 1,
)
: chatRequest.messages),
? chatMessages.slice(0, chatMessages.length - 1)
: chatMessages),
message,
],
false,
Expand All @@ -313,7 +321,7 @@ By default, it's set to 1, which means that only a single LLM call is made.
onFinish,
generateId,
fetch,
lastMessage: chatRequest.messages[chatRequest.messages.length - 1],
lastMessage: chatMessages[chatMessages.length - 1],
});

abortControllerRef.current = null;
Expand Down Expand Up @@ -403,6 +411,7 @@ By default, it's set to 1, which means that only a single LLM call is made.
createdAt: message.createdAt ?? new Date(),
experimental_attachments:
attachmentsForRequest.length > 0 ? attachmentsForRequest : undefined,
parts: getMessageParts(message),
});

return triggerRequest({ messages, headers, body, data });
Expand Down Expand Up @@ -444,8 +453,9 @@ By default, it's set to 1, which means that only a single LLM call is made.
messages = messages(messagesRef.current);
}

mutate(messages, false);
messagesRef.current = messages;
const messagesWithParts = fillInUIMessageParts(messages);
mutate(messagesWithParts, false);
messagesRef.current = messagesWithParts;
},
[mutate],
);
Expand Down Expand Up @@ -574,3 +584,31 @@ function isAssistantMessageWithCompletedToolCalls(
message.toolInvocations.every(toolInvocation => 'result' in toolInvocation)
);
}

/**
 * Normalizes an array of `Message` inputs into `UIMessage`s by ensuring
 * each message carries an explicit `parts` array (derived from legacy
 * fields when absent). Does not mutate the input messages.
 */
function fillInUIMessageParts(messages: Message[]): UIMessage[] {
  const uiMessages: UIMessage[] = [];

  for (const message of messages) {
    uiMessages.push({
      ...message,
      parts: getMessageParts(message),
    });
  }

  return uiMessages;
}

/**
 * Returns the UI parts for a message. If the message already has an
 * explicit `parts` array, it is returned unchanged. Otherwise the parts
 * are derived from the legacy fields, in this fixed order:
 * reasoning, then text content, then tool invocations.
 */
function getMessageParts(
  message: Message | CreateMessage | UIMessage,
): (TextUIPart | ReasoningUIPart | ToolInvocationUIPart)[] {
  // Explicit parts take precedence over derived ones.
  if (message.parts != null) {
    return message.parts;
  }

  const derivedParts: (TextUIPart | ReasoningUIPart | ToolInvocationUIPart)[] =
    [];

  // Falsy check (not just nullish) matches the original behavior:
  // empty strings produce no part.
  if (message.reasoning) {
    derivedParts.push({ type: 'reasoning', reasoning: message.reasoning });
  }

  if (message.content) {
    derivedParts.push({ type: 'text', text: message.content });
  }

  for (const toolInvocation of message.toolInvocations ?? []) {
    derivedParts.push({ type: 'tool-invocation', toolInvocation });
  }

  return derivedParts;
}
1 change: 1 addition & 0 deletions packages/react/src/use-chat.ui.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ describe('prepareRequestBody', () => {
id: expect.any(String),
experimental_attachments: undefined,
createdAt: expect.any(Date),
parts: [{ type: 'text', text: 'hi' }],
},
],
requestData: { 'test-data-key': 'test-data-value' },
Expand Down
10 changes: 2 additions & 8 deletions packages/solid/src/use-chat.ui.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,8 @@ import { useChat } from './use-chat';

describe('file attachments with data url', () => {
const TestComponent = () => {
const {
messages,
handleSubmit,
handleInputChange,
isLoading,
input,
setInput,
} = useChat();
const { messages, handleSubmit, handleInputChange, isLoading, input } =
useChat();

const [attachments, setAttachments] = createSignal<FileList | undefined>();
let fileInputRef: HTMLInputElement | undefined;
Expand Down
6 changes: 3 additions & 3 deletions packages/ui-utils/src/call-chat-api.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { processChatResponse } from './process-chat-response';
import { processChatTextResponse } from './process-chat-text-response';
import { IdGenerator, JSONValue, Message, UseChatOptions } from './types';
import { IdGenerator, JSONValue, UIMessage, UseChatOptions } from './types';

// use function to allow for mocking in tests:
const getOriginalFetch = () => fetch;
Expand Down Expand Up @@ -30,15 +30,15 @@ export async function callChatApi({
restoreMessagesOnFailure: () => void;
onResponse: ((response: Response) => void | Promise<void>) | undefined;
onUpdate: (options: {
message: Message;
message: UIMessage;
data: JSONValue[] | undefined;
replaceLastMessage: boolean;
}) => void;
onFinish: UseChatOptions['onFinish'];
onToolCall: UseChatOptions['onToolCall'];
generateId: IdGenerator;
fetch: ReturnType<typeof getOriginalFetch> | undefined;
lastMessage: Message | undefined;
lastMessage: UIMessage | undefined;
}) {
const response = await fetch(api, {
method: 'POST',
Expand Down
12 changes: 6 additions & 6 deletions packages/ui-utils/src/process-chat-response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ import { parsePartialJson } from './parse-partial-json';
import { processDataStream } from './process-data-stream';
import type {
JSONValue,
Message,
ReasoningUIPart,
TextUIPart,
ToolInvocation,
ToolInvocationUIPart,
UIMessage,
UseChatOptions,
} from './types';
import { LanguageModelV1FinishReason } from '@ai-sdk/provider';
Expand All @@ -27,19 +27,19 @@ export async function processChatResponse({
}: {
stream: ReadableStream<Uint8Array>;
update: (options: {
message: Message;
message: UIMessage;
data: JSONValue[] | undefined;
replaceLastMessage: boolean;
}) => void;
onToolCall?: UseChatOptions['onToolCall'];
onFinish?: (options: {
message: Message | undefined;
message: UIMessage | undefined;
finishReason: LanguageModelV1FinishReason;
usage: LanguageModelUsage;
}) => void;
generateId?: () => string;
getCurrentDate?: () => Date;
lastMessage: Message | undefined;
lastMessage: UIMessage | undefined;
}) {
const replaceLastMessage = lastMessage?.role === 'assistant';
let step = replaceLastMessage
Expand All @@ -50,7 +50,7 @@ export async function processChatResponse({
}, 0) ?? 0)
: 0;

const message: Message = replaceLastMessage
const message: UIMessage = replaceLastMessage
? structuredClone(lastMessage)
: {
id: generateId(),
Expand Down Expand Up @@ -123,7 +123,7 @@ export async function processChatResponse({
// is updated with SWR (without it, the changes get stuck in SWR and are not
// forwarded to rendering):
revisionId: generateId(),
} as Message;
} as UIMessage;

update({
message: copiedMessage,
Expand Down
6 changes: 3 additions & 3 deletions packages/ui-utils/src/process-chat-text-response.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { JSONValue } from '@ai-sdk/provider';
import { generateId as generateIdFunction } from '@ai-sdk/provider-utils';
import { processTextStream } from './process-text-stream';
import { Message, TextUIPart, UseChatOptions } from './types';
import { TextUIPart, UIMessage, UseChatOptions } from './types';

export async function processChatTextResponse({
stream,
Expand All @@ -12,7 +12,7 @@ export async function processChatTextResponse({
}: {
stream: ReadableStream<Uint8Array>;
update: (options: {
message: Message;
message: UIMessage;
data: JSONValue[] | undefined;
replaceLastMessage: boolean;
}) => void;
Expand All @@ -22,7 +22,7 @@ export async function processChatTextResponse({
}) {
const textPart: TextUIPart = { type: 'text', text: '' };

const resultMessage: Message = {
const resultMessage: UIMessage = {
id: generateId(),
createdAt: getCurrentDate(),
role: 'assistant' as const,
Expand Down
7 changes: 6 additions & 1 deletion packages/ui-utils/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,14 @@ that the assistant made as part of this message.
*/
toolInvocations?: Array<ToolInvocation>;

parts: Array<TextUIPart | ReasoningUIPart | ToolInvocationUIPart>;
// note: optional on the Message type (which serves as input)
parts?: Array<TextUIPart | ReasoningUIPart | ToolInvocationUIPart>;
}

export type UIMessage = Message & {
parts: Array<TextUIPart | ReasoningUIPart | ToolInvocationUIPart>;
};

export type TextUIPart = {
type: 'text';
text: string;
Expand Down
6 changes: 3 additions & 3 deletions packages/ui-utils/src/update-tool-call-result.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Message, ToolInvocationUIPart } from './types';
import { ToolInvocationUIPart, UIMessage } from './types';

/**
* Updates the result of a specific tool invocation in the last message of the given messages array.
*
* @param {object} params - The parameters object.
* @param {Message[]} params.messages - An array of messages, from which the last one is updated.
* @param {UIMessage[]} params.messages - An array of messages, from which the last one is updated.
* @param {string} params.toolCallId - The unique identifier for the tool invocation to update.
* @param {unknown} params.toolResult - The result object to attach to the tool invocation.
* @returns {void} This function does not return anything.
Expand All @@ -14,7 +14,7 @@ export function updateToolCallResult({
toolCallId,
toolResult: result,
}: {
messages: Message[];
messages: UIMessage[];
toolCallId: string;
toolResult: unknown;
}) {
Expand Down

0 comments on commit f03ac6d

Please sign in to comment.