Skip to content

Commit

Permalink
Chat: add agent_id (#319)
Browse files Browse the repository at this point in the history
* Chat: add agent_id

* fix be

* fix non streamed chat

* logs

* remove logs

* frontend: agent_id in request body

---------

Co-authored-by: Abigail Mackenzie-Armes <[email protected]>
  • Loading branch information
lusmoura and abimacarmes authored Jun 28, 2024
1 parent 6f5761c commit 0d19337
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 17 deletions.
2 changes: 0 additions & 2 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ def call_chat(self, chat_request, deployment_model, **kwargs: Any):
kwargs.get("user_id"),
)

print(f"Chat history: {chat_request.chat_history}")

# Loop until there are no new tool calls
for step in range(MAX_STEPS):
logger.info(f"Step {step + 1}")
Expand Down
4 changes: 2 additions & 2 deletions src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def is_available(cls) -> bool:
@collect_metrics_chat
def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids"}),
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)
yield to_dict(response)
Expand All @@ -75,7 +75,7 @@ def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
) -> Generator[StreamedChatResponse, None, None]:
stream = self.client.chat_stream(
**chat_request.model_dump(exclude={"stream", "file_ids"}),
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)

Expand Down
4 changes: 2 additions & 2 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def is_available(cls) -> bool:
@collect_metrics_chat
def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids"}),
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)
yield to_dict(response)
Expand All @@ -82,7 +82,7 @@ def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
) -> Generator[StreamedChatResponse, None, None]:
stream = self.client.chat_stream(
**chat_request.model_dump(exclude={"stream", "file_ids"}),
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
**kwargs,
)

Expand Down
8 changes: 6 additions & 2 deletions src/backend/model_deployments/single_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def is_available(cls) -> bool:
@collect_metrics_chat
def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids", "model"}),
**chat_request.model_dump(
exclude={"stream", "file_ids", "model", "agent_id"}
),
**kwargs,
)
yield to_dict(response)
Expand All @@ -60,7 +62,9 @@ def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
) -> Generator[StreamedChatResponse, None, None]:
stream = self.client.chat_stream(
**chat_request.model_dump(exclude={"stream", "file_ids", "model"}),
**chat_request.model_dump(
exclude={"stream", "file_ids", "model", "agent_id"}
),
**kwargs,
)

Expand Down
6 changes: 2 additions & 4 deletions src/backend/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ async def chat_stream(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
agent_id: str | None = None,
) -> Generator[ChatResponseEvent, Any, None]:
"""
Stream chat endpoint to handle user messages and return chatbot responses.
Expand All @@ -40,7 +39,6 @@ async def chat_stream(
session (DBSessionDep): Database session.
chat_request (CohereChatRequest): Chat request data.
request (Request): Request object.
agent_id (str | None): Agent ID.
Returns:
EventSourceResponse: Server-sent event response with chatbot responses.
Expand All @@ -50,6 +48,7 @@ async def chat_stream(
trace_id = request.state.trace_id

user_id = request.headers.get("User-Id", None)
agent_id = chat_request.agent_id

(
session,
Expand Down Expand Up @@ -96,7 +95,6 @@ async def chat(
session: DBSessionDep,
chat_request: CohereChatRequest,
request: Request,
agent_id: str | None = None,
) -> NonStreamedChatResponse:
"""
Chat endpoint to handle user messages and return chatbot responses.
Expand All @@ -105,7 +103,6 @@ async def chat(
chat_request (CohereChatRequest): Chat request data.
session (DBSessionDep): Database session.
request (Request): Request object.
agent_id (str | None): Agent ID.
Returns:
NonStreamedChatResponse: Chatbot response.
Expand All @@ -115,6 +112,7 @@ async def chat(
trace_id = request.state.trace_id

user_id = request.headers.get("User-Id", None)
agent_id = chat_request.agent_id

(
session,
Expand Down
4 changes: 4 additions & 0 deletions src/backend/schemas/cohere_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,7 @@ class CohereChatRequest(BaseChatRequest):
default=None,
title="If set to true, the model will generate a single response in a single step. This is useful for generating a response to a single message.",
)
agent_id: str | None = Field(
default=None,
title="The agent ID to use for the chat.",
)
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ export type CohereChatRequest = {
prompt_truncation?: CohereChatPromptTruncation;
tool_results?: null;
force_single_step?: boolean | null;
agent_id?: string | null;
};
1 change: 1 addition & 0 deletions src/interfaces/coral_web/src/cohere-client/mappings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { CohereChatRequest } from './generated';

export const mapToChatRequest = (request: CohereChatRequest): CohereChatRequest => {
return {
agent_id: request.agent_id,
message: request.message,
model: request.model,
temperature: request.temperature ?? DEFAULT_CHAT_TEMPERATURE,
Expand Down
4 changes: 1 addition & 3 deletions src/interfaces/coral_web/src/hooks/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => {
newMessages: ChatMessage[];
request: CohereChatRequest;
headers: Record<string, string>;
agentId?: string;
streamConverse: UseMutateAsyncFunction<
StreamEnd | undefined,
CohereNetworkError,
Expand Down Expand Up @@ -214,7 +213,6 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => {
await streamConverse({
request,
headers,
agentId,
onRead: (eventData: ChatResponseEvent) => {
switch (eventData.event) {
case StreamEvent.STREAM_START: {
Expand Down Expand Up @@ -548,6 +546,7 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => {
file_ids: fileIds && fileIds.length > 0 ? fileIds : undefined,
temperature,
model,
agent_id: agentId,
...restOverrides,
};
};
Expand Down Expand Up @@ -585,7 +584,6 @@ export const useChat = (config?: { onSend?: (msg: string) => void }) => {
newMessages,
request,
headers,
agentId,
streamConverse: streamChat,
});
};
Expand Down
3 changes: 1 addition & 2 deletions src/interfaces/coral_web/src/hooks/streamChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ interface StreamingParams {
export interface StreamingChatParams extends StreamingParams {
request: CohereChatRequest;
headers: Record<string, string>;
agentId?: string;
}

const getUpdatedConversations =
Expand Down Expand Up @@ -116,7 +115,7 @@ export const useStreamChat = () => {
if (experimentalFeatures?.USE_EXPERIMENTAL_LANGCHAIN) {
await cohereClient.langchainChat(chatStreamParams);
} else {
await cohereClient.chat({ ...chatStreamParams, agentId: params.agentId });
await cohereClient.chat({ ...chatStreamParams });
}
} catch (e) {
if (isUnauthorizedError(e)) {
Expand Down

0 comments on commit 0d19337

Please sign in to comment.