Skip to content

Commit

Permalink
solid
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Feb 3, 2025
1 parent 0525428 commit eec6111
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
10 changes: 6 additions & 4 deletions packages/react/src/use-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -530,18 +530,20 @@ By default, it's set to 1, which means that only a single LLM call is made.

const addToolResult = useCallback(
({ toolCallId, result }: { toolCallId: string; result: any }) => {
const currentMessages = messagesRef.current;

updateToolCallResult({
messages: messagesRef.current,
messages: currentMessages,
toolCallId,
toolResult: result,
});

mutate(messagesRef.current, false);
mutate(currentMessages, false);

// auto-submit when all tool calls in the last assistant message have results:
const lastMessage = messagesRef.current[messagesRef.current.length - 1];
const lastMessage = currentMessages[currentMessages.length - 1];
if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
triggerRequest({ messages: messagesRef.current });
triggerRequest({ messages: currentMessages });
}
},
[mutate, triggerRequest],
Expand Down
57 changes: 33 additions & 24 deletions packages/solid/src/use-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import type {
JSONValue,
Message,
UseChatOptions as SharedUseChatOptions,
UIMessage,
} from '@ai-sdk/ui-utils';
import {
callChatApi,
extractMaxToolInvocationStep,
fillMessageParts,
generateId as generateIdFunc,
getMessageParts,
prepareAttachmentsForRequest,
updateToolCallResult,
} from '@ai-sdk/ui-utils';
Expand All @@ -33,7 +36,7 @@ export type UseChatHelpers = {
/**
* Current messages in the chat as a SolidJS store.
*/
messages: () => Store<Message[]>;
messages: () => Store<UIMessage[]>;

/** The error object of the API request */
error: Accessor<undefined | Error>;
Expand Down Expand Up @@ -115,11 +118,11 @@ or to provide a custom fetch implementation for e.g. testing.
const processStreamedResponse = async (
api: string,
chatRequest: ChatRequest,
mutate: (data: Message[]) => void,
mutate: (data: UIMessage[]) => void,
setStreamData: Setter<JSONValue[] | undefined>,
streamData: Accessor<JSONValue[] | undefined>,
extraMetadata: any,
messagesRef: Message[],
messagesRef: UIMessage[],
abortController: AbortController | null,
generateId: IdGenerator,
streamProtocol: UseChatOptions['streamProtocol'] = 'data',
Expand All @@ -134,14 +137,15 @@ const processStreamedResponse = async (
// Do an optimistic update to the chat state to show the updated messages
// immediately.
const previousMessages = messagesRef;
const chatMessages = fillMessageParts(chatRequest.messages);

mutate(chatRequest.messages);
mutate(chatMessages);

const existingStreamData = streamData() ?? [];

const constructedMessagesPayload = sendExtraMessageFields
? chatRequest.messages
: chatRequest.messages.map(
? chatMessages
: chatMessages.map(
({
role,
content,
Expand Down Expand Up @@ -186,8 +190,8 @@ const processStreamedResponse = async (
onUpdate({ message, data, replaceLastMessage }) {
mutate([
...(replaceLastMessage
? chatRequest.messages.slice(0, chatRequest.messages.length - 1)
: chatRequest.messages),
? chatMessages.slice(0, chatMessages.length - 1)
: chatMessages),
message,
]);

Expand All @@ -199,7 +203,7 @@ const processStreamedResponse = async (
onFinish,
generateId,
fetch,
lastMessage: chatRequest.messages[chatRequest.messages.length - 1],
lastMessage: chatMessages[chatMessages.length - 1],
});
};

Expand Down Expand Up @@ -235,12 +239,14 @@ export function useChat(
chatCache.get(chatKey()) ?? useChatOptions().initialMessages?.() ?? [],
);

const [messagesStore, setMessagesStore] = createStore<Message[]>(_messages());
const [messagesStore, setMessagesStore] = createStore<UIMessage[]>(
fillMessageParts(_messages()),
);
createEffect(() => {
setMessagesStore(reconcile(_messages(), { merge: true }));
setMessagesStore(reconcile(fillMessageParts(_messages()), { merge: true }));
});

const mutate = (messages: Message[]) => {
const mutate = (messages: UIMessage[]) => {
chatCache.set(chatKey(), messages);
};

Expand All @@ -250,9 +256,9 @@ export function useChat(
);
const [isLoading, setIsLoading] = createSignal(false);

let messagesRef: Message[] = _messages() || [];
let messagesRef: UIMessage[] = fillMessageParts(_messages()) || [];
createEffect(() => {
messagesRef = _messages() || [];
messagesRef = fillMessageParts(_messages()) || [];
});

let abortController: AbortController | null = null;
Expand Down Expand Up @@ -354,15 +360,17 @@ export function useChat(
experimental_attachments,
);

const newMessage = {
const messages = messagesRef.concat({
...message,
id: message.id ?? generateId()(),
createdAt: message.createdAt ?? new Date(),
experimental_attachments:
attachmentsForRequest.length > 0 ? attachmentsForRequest : undefined,
};
parts: getMessageParts(message),
});

return triggerRequest({
messages: messagesRef.concat(newMessage as Message),
messages,
headers,
body,
data,
Expand Down Expand Up @@ -405,8 +413,9 @@ export function useChat(
messagesArg = messagesArg(messagesRef);
}

mutate(messagesArg);
messagesRef = messagesArg;
const messagesWithParts = fillMessageParts(messagesArg);
mutate(messagesWithParts);
messagesRef = messagesWithParts;
};

const setData = (
Expand Down Expand Up @@ -476,20 +485,20 @@ export function useChat(
toolCallId: string;
result: any;
}) => {
const messagesSnapshot = _messages() ?? [];
const currentMessages = messagesRef ?? [];

updateToolCallResult({
messages: messagesSnapshot,
messages: currentMessages,
toolCallId,
toolResult: result,
});

mutate(messagesSnapshot);
mutate(currentMessages);

// auto-submit when all tool calls in the last assistant message have results:
const lastMessage = messagesSnapshot[messagesSnapshot.length - 1];
const lastMessage = currentMessages[currentMessages.length - 1];
if (isAssistantMessageWithCompletedToolCalls(lastMessage)) {
triggerRequest({ messages: messagesSnapshot });
triggerRequest({ messages: currentMessages });
}
};

Expand Down
2 changes: 2 additions & 0 deletions packages/solid/src/use-chat.ui.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ describe('data protocol stream', () => {
createdAt: expect.any(Date),
role: 'assistant',
content: 'Hello, world.',
parts: [{ text: 'Hello, world.', type: 'text' }],
},
options: {
finishReason: 'stop',
Expand Down Expand Up @@ -511,6 +512,7 @@ describe('text stream', () => {
createdAt: expect.any(Date),
role: 'assistant',
content: 'Hello, world.',
parts: [{ text: 'Hello, world.', type: 'text' }],
},
options: {
finishReason: 'unknown',
Expand Down

0 comments on commit eec6111

Please sign in to comment.