Skip to content

Commit

Permalink
Merge pull request #82 from ant-xuexiao/chore_switch_llm
Browse files Browse the repository at this point in the history
feat: add the search tool tavily to the agent
  • Loading branch information
xingwanying authored Apr 10, 2024
2 parents f07fb4e + 55ce041 commit feaadee
Show file tree
Hide file tree
Showing 19 changed files with 79 additions and 73 deletions.
2 changes: 1 addition & 1 deletion lui/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "petercat-lui",
"version": "0.0.3",
"version": "0.0.4",
"description": "A react library developed with dumi",
"module": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down
2 changes: 1 addition & 1 deletion lui/src/Assistant/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ atomId: Assistant

```tsx
import React from 'react';
import { Assistant } from 'lui';
import { Assistant } from 'petercat-lui';

export default () => (
<Assistant
Expand Down
2 changes: 1 addition & 1 deletion lui/src/Assistant/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const Assistant = (props: AssistantProps) => {
className="fixed right-0 top-0 h-full flex flex-row z-[999] overflow-hidden text-left text-black bg-gradient-to-r from-f2e9ed via-e9eefb to-f0eeea shadow-[0px_0px_1px_#919eab3d]"
style={{ width: drawerWidth, zIndex: 9999 }}
>
<Chat {...props} />
<Chat {...props} drawerWidth={drawerWidth} />
<div className="absolute top-0 right-0 m-1">
<ActionIcon
icon={<CloseCircleFilled />}
Expand Down
2 changes: 1 addition & 1 deletion lui/src/Chat/index.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

```jsx
import React from 'react';
import { Chat } from 'lui';
import { Chat } from 'petercat-lui';


export default () => (
Expand Down
36 changes: 23 additions & 13 deletions lui/src/Chat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import type {
} from '@ant-design/pro-chat';
import { ProChat } from '@ant-design/pro-chat';
import { Markdown } from '@ant-design/pro-editor';
import StopBtn from 'lui/StopBtn';
import { theme } from 'lui/Theme';
import ThoughtChain from 'lui/ThoughtChain';
import { Role } from 'lui/interface';
import { BOT_INFO } from 'lui/mock';
import { streamChat } from 'lui/services/ChatController';
import { handleStream } from 'lui/utils';
import React, { ReactNode, memo, useRef, useState, type FC } from 'react';
import StopBtn from '../StopBtn';
import { theme } from '../Theme';
import ThoughtChain from '../ThoughtChain';
import { Role } from '../interface';
import { BOT_INFO } from '../mock';
import { streamChat } from '../services/ChatController';
import { handleStream } from '../utils';
import Actions from './inputArea/actions';

const { getDesignToken } = theme;
Expand All @@ -23,15 +23,19 @@ export interface ChatProps {
assistantMeta?: MetaData;
helloMessage?: string;
host?: string;
drawerWidth?: number;
slot?: {
componentID: string;
renderFunc: (data: any) => React.ReactNode;
}[];
}

const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {
const Chat: FC<ChatProps> = memo(({ helloMessage, host, drawerWidth }) => {
const proChatRef = useRef<ProChatInstance>();
const [chats, setChats] = useState<ChatMessage<Record<string, any>>[]>();
const messageMinWidth = drawerWidth
? `calc(${drawerWidth}px - 90px)`
: '100%';
return (
<div
className="h-full w-full"
Expand Down Expand Up @@ -60,10 +64,16 @@ const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {
},
contentRender: (props: ChatItemProps, defaultDom: ReactNode) => {
const originData = props.originData || {};
if (originData?.role === Role.user) {
return defaultDom;
}
const message = originData.content;
const defaultMessageContent = (
<div style={{ minWidth: messageMinWidth }}>{defaultDom}</div>
);

if (!message || !message.startsWith('<TOOL>')) {
return defaultDom;
return defaultMessageContent;
}

const [toolStr, answerStr] = message.split('<ANSWER>');
Expand All @@ -75,23 +85,23 @@ const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {

if (!match) {
console.error('No valid JSON found in input');
return defaultDom;
return defaultMessageContent;
}

try {
const config = JSON.parse(match[1]);
const { type, extra } = config;

if (![Role.knowledge, Role.tool].includes(type)) {
return defaultDom;
return defaultMessageContent;
}

const { status, source } = extra;

return (
<div
className="p-2 bg-white rounded-md "
style={{ minWidth: 'calc(375px - 90px)' }}
style={{ minWidth: messageMinWidth }}
>
<div className="mb-1">
<ThoughtChain
Expand All @@ -107,7 +117,7 @@ const Chat: FC<ChatProps> = memo(({ helloMessage, host }) => {
);
} catch (error) {
console.error(`JSON parse error: ${error}`);
return defaultDom;
return defaultMessageContent;
}
},
}}
Expand Down
2 changes: 1 addition & 1 deletion lui/src/StopBtn/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ atomId: StopBtn
# StopBtn
``` tsx
import React from 'react';
import { StopBtn } from 'lui';
import { StopBtn } from 'petercat-lui';

export default () => <StopBtn visible={true} />;
```
2 changes: 1 addition & 1 deletion lui/src/ThoughtChain/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ atomId: ThoughtChain

```tsx
import React from 'react';
import { ThoughtChain } from 'lui';
import { ThoughtChain } from 'petercat-lui';

export default () => (
<ThoughtChain
Expand Down
37 changes: 6 additions & 31 deletions lui/src/ThoughtChain/index.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import {
ApiOutlined,
CheckCircleOutlined,
CloseCircleOutlined,
DownOutlined,
ExclamationCircleOutlined,
FileTextOutlined,
LoadingOutlined,
UnorderedListOutlined,
UpOutlined,
Expand Down Expand Up @@ -72,35 +70,12 @@ const ThoughtChain: React.FC<ThoughtChainProps> = (params) => {
<DownOutlined className={`${getColorClass(status!)}`} />
</span>
),
children: (
<Collapse
ghost
size="small"
expandIcon={(panelProps) => {
const {
status: itemStatus,
knowledgeName,
pluginName,
} = (panelProps as IExtraInfo) || {};

if (itemStatus === Status.loading) {
return <LoadingOutlined className="text-blue-600 text-xs" />;
} else if (knowledgeName) {
return <FileTextOutlined className="text-gray-900 text-xs" />;
} else if (pluginName) {
return <ApiOutlined className="text-gray-900 text-xs" />;
}
return <></>;
}}
>
{safeJsonParse(content?.data) ? (
<Highlight language="json" theme="light" type="block">
{JSON.stringify(safeJsonParse(content?.data), null, 2)}
</Highlight>
) : (
<>{content?.data}</>
)}
</Collapse>
children: safeJsonParse(content?.data) ? (
<Highlight language="json" theme="light" type="block">
{JSON.stringify(safeJsonParse(content?.data), null, 2)}
</Highlight>
) : (
<>{content?.data}</>
),
},
];
Expand Down
2 changes: 1 addition & 1 deletion lui/src/mock/inputArea.mock.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { IBot } from 'lui/interface';
import { IBot } from '../interface';

export const DEFAULT_HELLO_MESSAGE =
'我是你的私人助理Kate, 我有许多惊人的能力,比如你可以对我说我想创建一个机器人';
Expand Down
2 changes: 1 addition & 1 deletion lui/src/services/ChatController.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { IPrompt } from 'lui/interface';
import { IPrompt } from '../interface';

/**
* Chat api
Expand Down
2 changes: 1 addition & 1 deletion lui/src/utils/chatTranslator.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { map } from 'lodash';
import { Role } from 'lui/interface';
import { Role } from '../interface';

export const convertChunkToJson = (rawData: string) => {
const regex = /data:(.*)/;
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"axios": "^1.6.7",
"concurrently": "^8.2.2",
"dayjs": "^1.11.10",
"petercat-lui": "^0.0.3",
"petercat-lui": "^0.0.4",
"eslint": "8.46.0",
"eslint-config-next": "13.4.12",
"framer-motion": "^10.16.15",
Expand Down
5 changes: 3 additions & 2 deletions server/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ DOCKER_SOCKET_LOCATION=/var/run/docker.sock
GOOGLE_PROJECT_ID=GOOGLE_PROJECT_ID
GOOGLE_PROJECT_NUMBER=GOOGLE_PROJECT_NUMBER

# GitHub Access Token
GITHUB_TOKEN=GITHUB_TOKEN

#TAVILY_API_KEY
TAVILY_API_KEY=TAVILY_API_KEY
18 changes: 12 additions & 6 deletions server/agent/stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import json
import os
import uuid
from langchain.tools import tool
from typing import AsyncIterator
Expand All @@ -12,11 +13,15 @@
from langchain.prompts import MessagesPlaceholder
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.prompts import ChatPromptTemplate
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
from langchain.tools.tavily_search import TavilySearchResults
from langchain_openai import ChatOpenAI
from uilts.env import get_env_variable
from tools import issue
from tools import sourcecode
from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage

TAVILY_API_KEY = get_env_variable("TAVILY_API_KEY")

prompt = ChatPromptTemplate.from_messages(
[
Expand Down Expand Up @@ -56,10 +61,11 @@ def get_datetime() -> datetime:
TOOLS = ["get_datetime", "create_issue", "get_issues", "search_issues", "search_code"]


def _create_agent_with_tools(openai_api_key: str ) -> AgentExecutor:
openai_api_key=openai_api_key
llm = ChatOpenAI(model="gpt-4", temperature=0.2, streaming=True)
tools = []
def _create_agent_with_tools(open_api_key: str) -> AgentExecutor:
llm = ChatOpenAI(model="gpt-4-1106-preview", temperature=0.2, streaming=True, max_tokens=1500, openai_api_key=open_api_key)
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search)
tools = [tavily_tool]

for requested_tool in TOOLS:
if requested_tool not in TOOL_MAPPING:
Expand Down Expand Up @@ -104,10 +110,10 @@ def chat_history_transform(messages: list[Message]):
return transformed_messages


async def agent_chat(input_data: ChatData, openai_api_key) -> AsyncIterator[str]:
async def agent_chat(input_data: ChatData, open_api_key: str) -> AsyncIterator[str]:
try:
messages = input_data.messages
agent_executor = _create_agent_with_tools(openai_api_key)
agent_executor = _create_agent_with_tools(open_api_key)
print(chat_history_transform(messages))
async for event in agent_executor.astream_events(
{
Expand Down
6 changes: 3 additions & 3 deletions server/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import os
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from data_class import DalleData, ChatData
from openai_api import dalle
from langchain_api import chat
from agent import stream
from uilts.env import get_env_variable
import uvicorn

open_api_key = os.getenv("OPENAI_API_KEY")
open_api_key = get_env_variable("OPENAI_API_KEY")

app = FastAPI(
title="Bo-meta Server",
Expand Down Expand Up @@ -46,4 +46,4 @@ def run_agent_chat(input_data: ChatData):
return StreamingResponse(result, media_type="text/event-stream")

if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
1 change: 1 addition & 0 deletions server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ langchain-openai
PyGithub
python-multipart
httpx[socks]
load_dotenv
6 changes: 2 additions & 4 deletions server/tools/issue.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import json
import os
from typing import Optional
from github import Github
from langchain.tools import tool

GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
from uilts.env import get_env_variable

DEFAULT_REPO_NAME = "ant-design/ant-design"

g = Github(GITHUB_TOKEN)
g = Github()

@tool
def create_issue(repo_name, title, body):
Expand Down
6 changes: 2 additions & 4 deletions server/tools/sourcecode.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import os
from typing import List, Optional
from github import Github
from github.ContentFile import ContentFile
from langchain.tools import tool
from uilts.env import get_env_variable


GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')

DEFAULT_REPO_NAME = "ant-design/ant-design"

g = Github(GITHUB_TOKEN)
g = Github()

@tool
def search_code(
Expand Down
17 changes: 17 additions & 0 deletions server/uilts/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dotenv import load_dotenv
import os

# Define a method to load an environmental variable and return its value
def get_env_variable(key: str, default=None):
"""
Retrieve the specified environment variable. Return the specified default value if the variable does not exist.
:param key: The name of the environment variable to retrieve.
:param default: The default value to return if the environment variable does not exist.
:return: The value of the environment variable, or the default value if it does not exist.
"""
# Load the .env file
load_dotenv(verbose=True, override=True)

# Get the environment variable, returning the default value if it does not exist
return os.getenv(key, default)

0 comments on commit feaadee

Please sign in to comment.