Skip to content

Commit

Permalink
coral-web: support multihop tool streaming events (#301)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujessica authored Jun 27, 2024
1 parent 8f40adf commit 51795fe
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { IconButton } from '@/components/IconButton';
import { DocumentIcon, Icon, IconName, Text } from '@/components/Shared';
import { TOOL_FALLBACK_ICON, TOOL_ID_TO_DISPLAY_INFO, TOOL_INTERNET_SEARCH_ID } from '@/constants';
import { TOOL_FALLBACK_ICON, TOOL_ID_TO_DISPLAY_INFO, TOOL_WEB_SEARCH_ID } from '@/constants';
import { cn, getSafeUrl, getWebDomain } from '@/utils';

const getWebSourceName = (toolId?: string) => {
if (!toolId) {
return '';
} else if (toolId === TOOL_INTERNET_SEARCH_ID) {
} else if (toolId === TOOL_WEB_SEARCH_ID) {
return 'from the web';
}
return `from ${toolId}`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@ import { Transition } from '@headlessui/react';
import React, { useMemo } from 'react';

import ButtonGroup from '@/components/ButtonGroup';
import {
TOOL_CALCULATOR_ID,
TOOL_INTERNET_SEARCH_ID,
TOOL_PYTHON_INTERPRETER_ID,
} from '@/constants';
import { TOOL_CALCULATOR_ID, TOOL_PYTHON_INTERPRETER_ID, TOOL_WEB_SEARCH_ID } from '@/constants';
import { useListTools } from '@/hooks/tools';
import { useParamsStore } from '@/stores';
import { ConfigurableParams } from '@/stores/slices/paramsSlice';
Expand All @@ -26,7 +22,7 @@ const SUGGESTED_PROMPTS: Prompt[] = [
tools: [
{ name: TOOL_PYTHON_INTERPRETER_ID },
{ name: TOOL_CALCULATOR_ID },
{ name: TOOL_INTERNET_SEARCH_ID },
{ name: TOOL_WEB_SEARCH_ID },
],
},
message:
Expand Down
8 changes: 1 addition & 7 deletions src/interfaces/coral_web/src/components/MessageRow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ const MessageRow = forwardRef<HTMLDivElement, Props>(function MessageRowInternal

const [isShowing, setIsShowing] = useState(false);
const [isLongPressMenuOpen, setIsLongPressMenuOpen] = useState(false);
const [isStepsExpanded, setIsStepsExpanded] = useState<boolean>(isLast);
const [isStepsExpanded, setIsStepsExpanded] = useState<boolean>(true);
const {
citations: { selectedCitation, hoveredGenerationId },
hoverCitation,
Expand Down Expand Up @@ -82,12 +82,6 @@ const MessageRow = forwardRef<HTMLDivElement, Props>(function MessageRowInternal
}
}, []);

useEffect(() => {
if (isLast) {
setIsStepsExpanded(true);
}
}, [isLast]);

const [highlightMessage, setHighlightMessage] = useState(false);
const prevSelectedCitationGenId = usePreviousDistinct(selectedCitation?.generationId);

Expand Down
38 changes: 25 additions & 13 deletions src/interfaces/coral_web/src/components/ToolEvents.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import { Transition } from '@headlessui/react';
import { Fragment, PropsWithChildren } from 'react';

import { StreamToolInput, ToolInputType } from '@/cohere-client';
import { StreamToolCallsGeneration, ToolCall } from '@/cohere-client';
import { Icon, IconName, Markdown, Text } from '@/components/Shared';
import {
TOOL_CALCULATOR_ID,
TOOL_FALLBACK_ICON,
TOOL_ID_TO_DISPLAY_INFO,
TOOL_INTERNET_SEARCH_ID,
TOOL_PYTHON_INTERPRETER_ID,
TOOL_WEB_SEARCH_ID,
} from '@/constants';
import { cn } from '@/utils';

type Props = {
show: boolean;
events: StreamToolInput[] | undefined;
events: StreamToolCallsGeneration[] | undefined;
};

/**
Expand All @@ -34,7 +35,9 @@ export const ToolEvents: React.FC<Props> = ({ show, events }) => {
{events?.map((toolEvent, i) => (
<Fragment key={i}>
{toolEvent.text && <ToolEvent plan={toolEvent.text} />}
<ToolEvent event={toolEvent} />
{toolEvent.tool_calls?.map((toolCall, j) => (
<ToolEvent key={`event-${j}`} event={toolCall} />
))}
</Fragment>
))}
</Transition>
Expand All @@ -43,7 +46,7 @@ export const ToolEvents: React.FC<Props> = ({ show, events }) => {

type ToolEventProps = {
plan?: string;
event?: StreamToolInput;
event?: ToolCall;
};

/**
Expand All @@ -54,15 +57,14 @@ const ToolEvent: React.FC<ToolEventProps> = ({ plan, event }) => {
return <ToolEventWrapper>{plan}</ToolEventWrapper>;
}

const toolName = event.tool_name;
const input = event.input;
const icon = toolName ? TOOL_ID_TO_DISPLAY_INFO[toolName]?.icon : TOOL_FALLBACK_ICON;
const toolName = event.name;
const icon = TOOL_ID_TO_DISPLAY_INFO[toolName]?.icon ?? TOOL_FALLBACK_ICON;

switch (toolName) {
case TOOL_PYTHON_INTERPRETER_ID: {
if (event.input_type === ToolInputType.CODE) {
if (event?.parameters?.code) {
let codeString = '```python\n';
codeString += input;
codeString += event?.parameters?.code;
codeString += '\n```';

return (
Expand All @@ -82,10 +84,18 @@ const ToolEvent: React.FC<ToolEventProps> = ({ plan, event }) => {
}
}

case TOOL_INTERNET_SEARCH_ID: {
case TOOL_CALCULATOR_ID: {
return (
<ToolEventWrapper icon={icon}>
Searching <b className="font-medium">{input}</b>
Calculating <b className="font-medium">{event?.parameters?.expression}</b>
</ToolEventWrapper>
);
}

case TOOL_WEB_SEARCH_ID: {
return (
<ToolEventWrapper icon={icon}>
Searching <b className="font-medium">{event?.parameters?.query}</b>
</ToolEventWrapper>
);
}
Expand All @@ -110,7 +120,9 @@ const ToolEventWrapper: React.FC<PropsWithChildren<{ icon?: IconName }>> = ({
return (
<div className="flex w-full gap-x-2 rounded bg-secondary-50 px-3 py-2 transition-colors ease-in-out group-hover:bg-secondary-100">
<Icon name={icon} kind="outline" className="flex h-[21px] items-center text-secondary-600" />
<Text className="text-secondary-800">{children}</Text>
<Text className="text-secondary-800" styleAs="p-sm">
{children}
</Text>
</div>
);
};
8 changes: 4 additions & 4 deletions src/interfaces/coral_web/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ export const LOCAL_STORAGE_KEYS = {
/**
* Tools
*/
export const TOOL_INTERNET_SEARCH_ID = 'internet_search';
export const TOOL_PYTHON_INTERPRETER_ID = 'python_interpreter';
export const TOOL_CALCULATOR_ID = 'calculator';
export const TOOL_WEB_SEARCH_ID = 'web_search';
export const TOOL_PYTHON_INTERPRETER_ID = 'toolkit_python_interpreter';
export const TOOL_CALCULATOR_ID = 'toolkit_calculator';
export const TOOL_WIKIPEDIA_ID = 'wikipedia';
export const TOOL_SEARCH_FILE_ID = 'search_file';

export const TOOL_FALLBACK_ICON = 'circles-four';
export const TOOL_ID_TO_DISPLAY_INFO: { [id: string]: { icon: IconName } } = {
[TOOL_INTERNET_SEARCH_ID]: { icon: 'search' },
[TOOL_WEB_SEARCH_ID]: { icon: 'search' },
[TOOL_PYTHON_INTERPRETER_ID]: { icon: 'code' },
[TOOL_CALCULATOR_ID]: { icon: 'calculator' },
[TOOL_WIKIPEDIA_ID]: { icon: 'web' },
Expand Down
117 changes: 93 additions & 24 deletions src/interfaces/coral_web/src/hooks/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import {
StreamSearchResults,
StreamStart,
StreamTextGeneration,
StreamToolInput,
StreamToolResult,
StreamToolCallsChunk,
StreamToolCallsGeneration,
isCohereNetworkError,
isSessionUnavailableError,
isStreamError,
Expand Down Expand Up @@ -202,7 +202,11 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => {
let citations: Citation[] = [];
let documentsMap: IdToDocument = {};
let outputFiles: OutputFiles = {};
let toolEvents: StreamToolInput[] = [];
let toolEvents: StreamToolCallsGeneration[] = [];
let currentToolEventIndex = 0;

// Temporarily store the streaming `parameters` partial JSON string for a tool call
let toolCallParamaterStr = '';

try {
clearComposerFiles();
Expand Down Expand Up @@ -248,9 +252,59 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => {
break;
}

case StreamEvent.TOOL_INPUT: {
const data = eventData.data as StreamToolInput;
toolEvents.push(data);
case StreamEvent.TOOL_CALLS_CHUNK: {
const data = eventData.data as StreamToolCallsChunk;

// Initiate an empty tool event if one doesn't already exist at the current index
const toolEvent: StreamToolCallsGeneration = toolEvents[currentToolEventIndex] ?? {
text: '',
tool_calls: [],
};
toolEvent.text += data?.text ?? '';

// A tool call needs to be added/updated if a tool call delta is present in the event
if (data?.tool_call_delta) {
const currentToolCallsIndex = data.tool_call_delta.index ?? 0;
let toolCall = toolEvent.tool_calls?.[currentToolCallsIndex];
if (!toolCall) {
toolCall = {
name: '',
parameters: {},
};
toolCallParamaterStr = '';
}

if (data?.tool_call_delta?.name) {
toolCall.name = data.tool_call_delta.name;
}
if (data?.tool_call_delta?.parameters) {
toolCallParamaterStr += data?.tool_call_delta?.parameters;

// Attempt to parse the partial parameter string as valid JSON to show that the parameters
// are streaming in. To make the partial JSON string valid JSON after the object key comes in,
// we naively try to add `"}` to the end.
try {
const partialParams = JSON.parse(toolCallParamaterStr + `"}`);
toolCall.parameters = partialParams;
} catch (e) {
// Ignore parsing error
}
}

// Update the tool call list with the new/updated tool call
if (toolEvent.tool_calls?.[currentToolCallsIndex]) {
toolEvent.tool_calls[currentToolCallsIndex] = toolCall;
} else {
toolEvent.tool_calls?.push(toolCall);
}
}

// Update the tool event list with the new/updated tool event
if (toolEvents[currentToolEventIndex]) {
toolEvents[currentToolEventIndex] = toolEvent;
} else {
toolEvents.push(toolEvent);
}

setStreamingMessage({
type: MessageType.BOT,
Expand All @@ -264,28 +318,43 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => {
break;
}

// This event only occurs when we're using experimental langchain multihop.
case StreamEvent.TOOL_RESULT: {
const data = eventData.data as StreamToolResult;
if (data.tool_name.toLowerCase() === TOOL_PYTHON_INTERPRETER_ID) {
const resultsWithOutputFile = data.result.filter((r: any) => r.output_file);
outputFiles = { ...mapOutputFiles(resultsWithOutputFile) };
saveOutputFiles(outputFiles);
}

setStreamingMessage({
type: MessageType.BOT,
state: BotState.TYPING,
text: botResponse,
isRAGOn,
generationId,
originalText: botResponse,
toolEvents,
});
case StreamEvent.TOOL_CALLS_GENERATION: {
const data = eventData.data as StreamToolCallsGeneration;

if (toolEvents[currentToolEventIndex]) {
toolEvents[currentToolEventIndex] = data;
currentToolEventIndex += 1;
} else {
toolEvents.push(data);
currentToolEventIndex = toolEvents.length; // double check this is right
}
break;
}

// TODO(@wujessica): temporarily remove support for experimental langchain multihop
// as it diverges from the current implementation.
// This event only occurs when we're using experimental langchain multihop.
// case StreamEvent.TOOL_RESULT: {
// const data = eventData.data as StreamToolResult;
// if (data.tool_name === TOOL_PYTHON_INTERPRETER_ID) {
// const resultsWithOutputFile = data.result.filter((r: any) => r.output_file);
// outputFiles = { ...mapOutputFiles(resultsWithOutputFile) };
// saveOutputFiles(outputFiles);
// }

// setStreamingMessage({
// type: MessageType.BOT,
// state: BotState.TYPING,
// text: botResponse,
// isRAGOn,
// generationId,
// originalText: botResponse,
// toolEvents,
// });

// break;
// }

case StreamEvent.CITATION_GENERATION: {
const data = eventData.data as StreamCitationGeneration;
const newCitations = [...(data?.citations ?? [])];
Expand Down
8 changes: 4 additions & 4 deletions src/interfaces/coral_web/src/types/message.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Citation, File, StreamToolInput } from '@/cohere-client';
import { Citation, File, StreamToolCallsGeneration, StreamToolInput } from '@/cohere-client';

export enum BotState {
LOADING = 'loading',
Expand Down Expand Up @@ -36,7 +36,7 @@ export type FulfilledMessage = BaseMessage & {
citations?: Citation[];
isRAGOn?: boolean;
originalText: string;
toolEvents?: StreamToolInput[];
toolEvents?: StreamToolCallsGeneration[];
};

/**
Expand All @@ -45,7 +45,7 @@ export type FulfilledMessage = BaseMessage & {
export type AbortedMessage = BaseMessage & {
type: MessageType.BOT;
state: BotState.ABORTED;
toolEvents?: StreamToolInput[];
toolEvents?: StreamToolCallsGeneration[];
};

/**
Expand All @@ -67,7 +67,7 @@ export type TypingMessage = BaseMessage & {
originalText: string;
citations?: Citation[];
isRAGOn?: boolean;
toolEvents?: StreamToolInput[];
toolEvents?: StreamToolCallsGeneration[];
};

/**
Expand Down

0 comments on commit 51795fe

Please sign in to comment.