Skip to content

Commit

Permalink
SWC-7091
Browse files Browse the repository at this point in the history
  • Loading branch information
jay-hodgson committed Sep 24, 2024
1 parent 9eabcbd commit c209b06
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import SynapseChatInteraction from './SynapseChatInteraction'
import { SkeletonParagraph } from '../Skeleton'
import {
useCreateAgentSession,
useGetChatAgentTraceEvents,
useSendChatMessageToAgent,
useUpdateAgentSession,
} from '../../synapse-queries/chat/useChat'
Expand All @@ -21,6 +22,8 @@ import { TextField } from '@mui/material'
import { useSynapseContext } from '../../utils'
import AccessLevelMenu from './AccessLevelMenu'
import { displayToast } from '../ToastMessage'
import { SynapseSpinner } from '../LoadingScreen/LoadingScreen'
import { Tooltip } from '@mui/material'

export type SynapseChatProps = {
initialMessage?: string //optional initial message
Expand Down Expand Up @@ -66,26 +69,39 @@ export const SynapseChat: React.FunctionComponent<SynapseChatProps> = ({
const [interactions, setInteractions] = useState<ChatInteraction[]>([])
const [pendingInteraction, setPendingInteraction] =
useState<ChatInteraction>()
const [currentlyProcessingJobId, setCurrentlyProcessingJobId] =
useState<string>()
const [currentResponse, setCurrentResponse] = useState('')
const [currentProgressMessage, setCurrentProgressMessage] = useState<
string | undefined
>()
const [currentResponseError, setCurrentResponseError] = useState('')
// Keep track of the text that the user is currently typing into the textfield
const [userChatTextfieldValue, setUserChatTextfieldValue] = useState('')
const [initialMessageProcessed, setInitialMessageProcessed] = useState(false)
const { mutate: sendChatMessageToAgent } = useSendChatMessageToAgent(
{
onSuccess: data => {
onSuccess: async data => {
// whenever the response is returned, set the last interaction response text
setCurrentResponse(data.responseText)
setCurrentlyProcessingJobId(undefined)
},
onError: err => {
setCurrentResponseError(err.reason)
setCurrentlyProcessingJobId(undefined)
},
},
(status: AsynchronousJobStatus<AgentChatRequest, AgentChatResponse>) => {
setCurrentProgressMessage(status?.progressMessage)
setCurrentlyProcessingJobId(status.jobId)
},
)

const { data: traceEvents } = useGetChatAgentTraceEvents(
{
jobId: currentlyProcessingJobId!,
},
{
//enabled if there is a pending interaction
enabled: !!currentlyProcessingJobId,
refetchInterval: !!currentlyProcessingJobId ? 1000 : false, // Re-fetch every second if enabled
refetchIntervalInBackground: true, // Continue polling even when the tab is not active
},
)

Expand Down Expand Up @@ -122,7 +138,6 @@ export const SynapseChat: React.FunctionComponent<SynapseChatProps> = ({
setInteractions([...interactions, pendingInteraction])
setCurrentResponse('')
setCurrentResponseError('')
setCurrentProgressMessage('')
setPendingInteraction(undefined)
}
}, [currentResponse, currentResponseError, pendingInteraction])
Expand All @@ -144,6 +159,7 @@ export const SynapseChat: React.FunctionComponent<SynapseChatProps> = ({
sendChatMessageToAgent({
chatText: initialMessage,
sessionId: agentSession!.sessionId,
enableTrace: true,
})
setInitialMessageProcessed(true)
}
Expand All @@ -156,6 +172,7 @@ export const SynapseChat: React.FunctionComponent<SynapseChatProps> = ({
sendChatMessageToAgent({
chatText: userChatTextfieldValue,
sessionId: agentSession!.sessionId,
enableTrace: true,
})
}
}
Expand Down Expand Up @@ -245,13 +262,48 @@ export const SynapseChat: React.FunctionComponent<SynapseChatProps> = ({
)
})}
{pendingInteraction && (
<SynapseChatInteraction
userMessage={pendingInteraction.userMessage}
chatResponseText={pendingInteraction.chatResponseText}
chatErrorReason={pendingInteraction.chatErrorReason}
progressMessage={currentProgressMessage}
scrollIntoView
/>
<>
<SynapseChatInteraction
userMessage={pendingInteraction.userMessage}
chatResponseText={pendingInteraction.chatResponseText}
chatErrorReason={pendingInteraction.chatErrorReason}
scrollIntoView
/>
<Box
sx={{
display: 'flex',
flexDirection: 'column',
justifyContent: 'center',
}}
>
<SynapseSpinner size={40} />
{/* Show the current message, as well as the full trace log in a tooltip */}
{traceEvents && traceEvents.page && (
<Tooltip
placement="bottom"
title={
<div style={{ textAlign: 'center' }}>
{traceEvents?.page?.map((event, index) => {
return (
<Typography key={`${index}-${event.message}`}>
{event.message}
</Typography>
)
})}
</div>
}
>
<Typography
sx={{ textAlign: 'center' }}
variant="body1Italic"
>
{traceEvents.page[traceEvents.page.length - 1].message}
</Typography>
</Tooltip>
)}
<SkeletonParagraph numRows={3} />
</Box>
</>
)}
</List>
</Box>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import React, { useEffect, useRef } from 'react'
import { Alert, Box, ListItem, ListItemText, Typography } from '@mui/material'
import { Alert, Box, ListItem, ListItemText } from '@mui/material'
import { useTheme } from '@mui/material'
import { ColorPartial } from '@mui/material/styles/createPalette'
import { SkeletonParagraph } from '../Skeleton'
import { SmartToyTwoTone } from '@mui/icons-material'
import { SynapseSpinner } from '../LoadingScreen/LoadingScreen'
import MarkdownSynapse from '../Markdown/MarkdownSynapse'

export type SynapseChatInteractionProps = {
userMessage: string
progressMessage?: string
chatResponseText?: string
scrollIntoView?: boolean
chatErrorReason?: string
Expand All @@ -19,7 +16,6 @@ export const SynapseChatInteraction: React.FunctionComponent<
SynapseChatInteractionProps
> = ({
userMessage,
progressMessage,
chatResponseText,
chatErrorReason,
scrollIntoView = false,
Expand Down Expand Up @@ -93,21 +89,6 @@ export const SynapseChatInteraction: React.FunctionComponent<
{chatErrorReason}
</Alert>
)}
{!chatResponseText && !chatErrorReason && (
<Box
sx={{
display: 'flex',
flexDirection: 'column',
justifyContent: 'center',
}}
>
<Typography sx={{ textAlign: 'center' }} variant="body1Italic">
{progressMessage ?? 'Processing...'}
</Typography>
<SynapseSpinner size={40} />
<SkeletonParagraph numRows={3} />
</Box>
)}
</>
)
}
Expand Down
33 changes: 33 additions & 0 deletions packages/synapse-react-client/src/mocks/chat/mockChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
ListAgentSessionsRequest,
ListAgentSessionsResponse,
SessionHistoryResponse,
TraceEventsResponse,
} from '@sage-bionetworks/synapse-types'

export const mockChatSessionId = 'session-456'
Expand Down Expand Up @@ -38,6 +39,7 @@ export const mockListAgentSessionsResponse: ListAgentSessionsResponse = {
export const mockAgentChatRequest: AgentChatRequest = {
sessionId: mockChatSessionId,
chatText: 'Hello! How can I access My Projects?',
enableTrace: true,
}

export const mockAgentChatResponse: AgentChatResponse = {
Expand Down Expand Up @@ -65,3 +67,34 @@ export const mockSessionHistoryResponse: SessionHistoryResponse = {
],
nextPageToken: undefined,
}

export const mockTraceEventsResponse1: TraceEventsResponse = {
jobId: ':id',
page: [
{
timestamp: 1695567600, // Example timestamp (in seconds)
message: 'Executing search on Synapse',
},
],
}

export const mockTraceEventsResponse2: TraceEventsResponse = {
jobId: ':id',
page: [
...mockTraceEventsResponse1.page,
{
timestamp: 1695567700, // Example timestamp (in seconds)
message: 'Gathering entity metadata',
},
],
}
export const mockTraceEventsResponse3: TraceEventsResponse = {
jobId: ':id',
page: [
...mockTraceEventsResponse2.page,
{
timestamp: 1695567800, // Example timestamp (in seconds)
message: 'Combining search results and entity metadata',
},
],
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { rest } from 'msw'
import {
AGENT_CHAT_TRACE,
AGENT_SESSION,
AGENT_SESSION_HISTORY,
GET_CHAT_ASYNC,
Expand All @@ -16,10 +17,14 @@ import {
mockChatSessionId,
mockListAgentSessionsResponse,
mockSessionHistoryResponse,
mockTraceEventsResponse1,
mockTraceEventsResponse2,
mockTraceEventsResponse3,
} from 'src/mocks/chat/mockChat'
import { generateAsyncJobHandlers } from './asyncJobHandlers'
import { BackendDestinationEnum, getEndpoint } from 'src/utils/functions'

let traceCallCount = 0
export const getChatbotHandlers = (
backendOrigin = getEndpoint(BackendDestinationEnum.REPO_ENDPOINT),
) => [
Expand Down Expand Up @@ -47,6 +52,20 @@ export const getChatbotHandlers = (
mockAgentChatResponse,
backendOrigin,
),

//trace events
rest.post(
`${backendOrigin}${AGENT_CHAT_TRACE(':id')}`,
async (_req, res, ctx) => {
//mock showing progress (increasing number of items)
traceCallCount++
if (traceCallCount == 2) {
return res(ctx.status(201), ctx.json(mockTraceEventsResponse1))
} else if (traceCallCount == 3) {
return res(ctx.status(201), ctx.json(mockTraceEventsResponse2))
} else return res(ctx.status(201), ctx.json(mockTraceEventsResponse3))
},
),
// generateAsyncJobHandlers(
// START_CHAT_ASYNC,
// tokenParam => GET_CHAT_ASYNC(tokenParam),
Expand Down
17 changes: 17 additions & 0 deletions packages/synapse-react-client/src/synapse-client/SynapseClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
ACCESS_REQUIREMENT_STATUS,
ACCESS_REQUIREMENT_WIKI_PAGE_KEY,
ACTIVITY_FOR_ENTITY,
AGENT_CHAT_TRACE,
AGENT_SESSION,
AGENT_SESSION_HISTORY,
ALIAS_AVAILABLE,
Expand Down Expand Up @@ -325,6 +326,8 @@ import {
SessionHistoryRequest,
VerificationState,
UpdateAgentSessionRequest,
TraceEventsRequest,
TraceEventsResponse,
} from '@sage-bionetworks/synapse-types'
import { calculateFriendlyFileSize } from '../utils/functions/calculateFriendlyFileSize'
import {
Expand Down Expand Up @@ -5547,3 +5550,17 @@ export const getSessionHistory = (
{ signal },
)
}

export const getChatAgentTraceEvents = (
request: TraceEventsRequest,
accessToken: string | undefined = undefined,
signal?: AbortSignal,
): Promise<TraceEventsResponse> => {
return doPost<TraceEventsResponse>(
AGENT_CHAT_TRACE(request.jobId),
request,
accessToken,
BackendDestinationEnum.REPO_ENDPOINT,
{ signal },
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
SubmissionSearchRequest,
SubscriptionObjectType,
SubscriptionQuery,
TraceEventsRequest,
TYPE_FILTER,
ViewColumnModelRequest,
ViewEntityType,
Expand Down Expand Up @@ -869,6 +870,10 @@ export class KeyFactory {
return this.getKey('fileBatch', request)
}

public getChatAgentTraceKey(request: TraceEventsRequest) {
return this.getKey('chatbotTraceEvents', request)
}

public getPaginatedDockerTagQueryKey(
id: string,
offset: string,
Expand Down
17 changes: 17 additions & 0 deletions packages/synapse-react-client/src/synapse-queries/chat/useChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import {
UseInfiniteQueryOptions,
useMutation,
UseMutationOptions,
useQuery,
UseQueryOptions,
} from '@tanstack/react-query'
import SynapseClient from '../../synapse-client'
import { SynapseClientError, useSynapseContext } from '../../utils'
Expand All @@ -19,6 +21,8 @@ import {
CreateAgentSessionRequest,
SessionHistoryRequest,
SessionHistoryResponse,
TraceEventsRequest,
TraceEventsResponse,
UpdateAgentSessionRequest,
} from '@sage-bionetworks/synapse-types'

Expand Down Expand Up @@ -142,3 +146,16 @@ export function useGetAgentChatSessionHistoryInfinite<
getNextPageParam: page => page.nextPageToken,
})
}

export function useGetChatAgentTraceEvents(
request: TraceEventsRequest,
options?: Partial<UseQueryOptions<TraceEventsResponse, SynapseClientError>>,
) {
const { accessToken, keyFactory } = useSynapseContext()
return useQuery({
...options,
queryKey: keyFactory.getChatAgentTraceKey(request),

queryFn: () => SynapseClient.getChatAgentTraceEvents(request, accessToken),
})
}
3 changes: 3 additions & 0 deletions packages/synapse-react-client/src/utils/APIConstants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ export const START_CHAT_ASYNC = `${AGENT}/chat/async/start`
export const GET_CHAT_ASYNC = (jobId: string | number) =>
`${AGENT}/chat/async/get/${jobId}`

export const AGENT_CHAT_TRACE = (jobId: string | number) =>
`${AGENT}/chat/async/get/${jobId}`

export const DOI = `${REPO}/doi`
export const DOI_ASSOCIATION = `${DOI}/association`

Expand Down
Loading

0 comments on commit c209b06

Please sign in to comment.