From 766a4aaa853ca575a20ed3bcec781ea15ec18221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E6=9C=9D=E6=99=96?= Date: Sat, 23 Mar 2024 00:30:50 +0800 Subject: [PATCH] Refactor --- bin/action.php | 8 +- bin/agent.php | 83 -------- bin/interactive.php | 43 +++- bin/unittest.php | 2 +- composer.json | 5 +- config/autoload/odin.php | 73 ++++--- data/response.txt | 40 ++++ src/Action/AbstractAction.php | 49 ----- src/Action/ActionFactory.php | 18 -- src/Action/ActionTemplate.php | 85 -------- src/Action/CalculatorAction.php | 20 -- src/Action/SearchAction.php | 32 --- src/Agent/AbstractAgent.php | 36 ++++ src/Agent/Agent.php | 122 ----------- src/Agent/OpenAIToolsAgent.php | 183 +++++++++++++++++ src/Apis/AzureOpenAI/AzureOpenAI.php | 4 +- src/Apis/AzureOpenAI/AzureOpenAIConfig.php | 22 +- src/Apis/AzureOpenAI/Client.php | 53 ++--- ...nCallDefinition.php => ToolDefinition.php} | 42 ++-- ...ionCallParameter.php => ToolParameter.php} | 2 +- ...nCallParameters.php => ToolParameters.php} | 15 +- .../OpenAI/Response/ChatCompletionChoice.php | 4 +- .../Response/ChatCompletionResponse.php | 5 +- src/Apis/OpenAI/Response/FunctionCall.php | 113 ----------- src/Apis/OpenAI/Response/ToolCall.php | 94 +++++++++ src/Conversation/Conversation.php | 192 ------------------ src/Conversation/ConversationBak.php | 139 ------------- src/Conversation/Option.php | 76 ------- src/Interpreter/CodeRunner.php | 22 +- src/Logger.php | 35 ++-- src/Memory/AbstractMemory.php | 23 +-- src/Memory/MemoryInterface.php | 23 +++ src/Memory/MessageHistory.php | 61 +++--- src/Message/AbstractMessage.php | 25 ++- src/Message/AssistantMessage.php | 36 +++- src/Message/FunctionMessage.php | 18 ++ src/Message/MessageBuffer.php | 27 +++ src/Message/Role.php | 1 + src/Model/AzureOpenAIModel.php | 36 ++++ src/Model/ModelInterface.php | 16 ++ src/Model/OpenAIModel.php | 24 +++ src/{LLM.php => ModelFacade.php} | 2 +- src/ModelMapper.php | 17 +- src/Observer.php | 41 ++++ src/Prompt/CodeInterpreter.prompt | 13 +- src/Prompt/OpenAIToolsAgentPrompt.php | 67 ++++++ src/Prompt/PromptInterface.php | 17 ++ src/Tools/AbstractTool.php | 26 +++ src/Tools/TavilySearchResults.php | 124 +++++++++++ src/Tools/ToolInterface.php | 12 ++ src/Wrapper/TavilySearchApiWrapper.php | 66 ++++++ 51 files changed, 1160 insertions(+), 1132 deletions(-) delete mode 100644 bin/agent.php create mode 100644 data/response.txt delete mode 100644 src/Action/AbstractAction.php delete mode 100644 src/Action/ActionFactory.php delete mode 100644 src/Action/ActionTemplate.php delete mode 100644 src/Action/CalculatorAction.php delete mode 100644 src/Action/SearchAction.php create mode 100644 src/Agent/AbstractAgent.php delete mode 100644 src/Agent/Agent.php create mode 100644 src/Agent/OpenAIToolsAgent.php rename src/Apis/OpenAI/Request/{FunctionCallDefinition.php => ToolDefinition.php} (52%) rename src/Apis/OpenAI/Request/{FunctionCallParameter.php => ToolParameter.php} (98%) rename src/Apis/OpenAI/Request/{FunctionCallParameters.php => ToolParameters.php} (76%) delete mode 100644 src/Apis/OpenAI/Response/FunctionCall.php create mode 100644 src/Apis/OpenAI/Response/ToolCall.php delete mode 100644 src/Conversation/Conversation.php delete mode 100644 src/Conversation/ConversationBak.php delete mode 100644 src/Conversation/Option.php create mode 100644 src/Memory/MemoryInterface.php create mode 100644 src/Message/FunctionMessage.php create mode 100644 src/Message/MessageBuffer.php create mode 100644 src/Model/AzureOpenAIModel.php create mode 100644 src/Model/ModelInterface.php create mode 100644 src/Model/OpenAIModel.php rename src/{LLM.php => ModelFacade.php} (99%) create mode 100644 src/Observer.php create mode 100644 src/Prompt/OpenAIToolsAgentPrompt.php create mode 100644 src/Prompt/PromptInterface.php create mode 100644 src/Tools/AbstractTool.php create mode 100644 src/Tools/TavilySearchResults.php create mode 100644 src/Tools/ToolInterface.php create mode 100644 src/Wrapper/TavilySearchApiWrapper.php diff --git a/bin/action.php b/bin/action.php index fc58c5d..8d22939 100644 --- a/bin/action.php +++ b/bin/action.php @@ -8,7 +8,7 @@ use Hyperf\Odin\Apis\OpenAI\OpenAI; use Hyperf\Odin\Apis\OpenAI\OpenAIConfig; use Hyperf\Odin\Apis\RWKV\RWKVConfig; -use Hyperf\Odin\Conversation\Conversation; +use Hyperf\Odin\Conversation\ConversationBak; use Hyperf\Odin\Memory\MessageHistory; use function Hyperf\Support\env as env; @@ -44,15 +44,15 @@ function getClient(string $type = 'azure') $client = getClient('azure'); $conversionId = uniqid(); -$conversation = new Conversation(); +$conversation = new ConversationBak(); $memory = new MessageHistory(); $input = '1+2=?,然后帮我查查东莞的明天的天气情况'; //$input = '东莞明天的最高多少度?以及 1+1=?,并将计算结果赋值给x用于下一次计算,x+10=?'; -$response = $conversation->chat($client, $input, 'gpt-3.5-turbo', $conversionId, $memory, [ +$response = $conversation->chat($input, 'gpt-3.5-turbo', $conversionId, $memory, [ new CalculatorAction(), new WeatherAction(), new SearchAction() -]); +], client: $client); echo PHP_EOL . PHP_EOL; echo '[FINAL] AI: ' . $response; \ No newline at end of file diff --git a/bin/agent.php b/bin/agent.php deleted file mode 100644 index 2e3cb12..0000000 --- a/bin/agent.php +++ /dev/null @@ -1,83 +0,0 @@ -agent = new Agent(); - $this->memory = new MessageHistory(); - $this->actions = [new CalculatorAction(), new WeatherAction(), new SearchAction()]; - } - - public function chat(string $input, string $conversionId, string $llmType = 'azure'): string - { - $client = $this->getClient($llmType); - $client->setDebug($this->debug); - return $this->agent->chat($client, $input, $this->model, $conversionId, $this->memory, $this->actions); - } - - public function getClient(string $type = 'azure') - { - switch ($type) { - case 'openai': - $openAI = new OpenAI(); - $config = new OpenAIConfig(env('OPENAI_API_KEY'),); - $client = $openAI->getClient($config); - break; - case 'azure': - $openAI = new AzureOpenAI(); - $config = new AzureOpenAIConfig(apiKey: env('AZURE_OPENAI_API_KEY'), baseUrl: env('AZURE_OPENAI_API_BASE'), apiVersion: env('AZURE_OPENAI_API_VERSION'), deploymentName: env('AZURE_OPENAI_DEPLOYMENT_NAME'),); - $client = $openAI->getClient($config); - break; - case 'rwkv': - $rwkv = new Hyperf\Odin\Apis\RWKV\RWKV(); - $config = new RWKVConfig(env('RWKV_HOST'),); - $client = $rwkv->getClient($config); - break; - default: - throw new \RuntimeException('Invalid type'); - } - return $client; - } -} - -$llm = new LLM(true); - -$inputs = [ - '1+12=?,以及东莞明天的天气如何?', - '我刚才询问天气的是哪个城市?', - '能见度如何?', - '12加上22等于多少', - '我都询问过哪些数学计算,列出所有', -]; - -$conversionId = uniqid(); - -foreach ($inputs as $input) { - echo '[Human]: ' . $input . PHP_EOL; - echo '[AI]: ' . $llm->chat($input, $conversionId, llmType: 'azure') . PHP_EOL; -} \ No newline at end of file diff --git a/bin/interactive.php b/bin/interactive.php index ba999b3..76f10c1 100644 --- a/bin/interactive.php +++ b/bin/interactive.php @@ -11,17 +11,50 @@ * @license https://github.com/hyperf/hyperf/blob/master/LICENSE */ -use Hyperf\Odin\Conversation\Option; +use Hyperf\Odin\Agent\OpenAIToolsAgent; use Hyperf\Odin\Memory\MessageHistory; -use Hyperf\Odin\Prompt\Prompt; +use Hyperf\Odin\ModelMapper; +use Hyperf\Odin\Observer; +use Hyperf\Odin\Prompt\OpenAIToolsAgentPrompt; +use Hyperf\Odin\Tools\TavilySearchResults; $container = require_once dirname(dirname(__FILE__)) . '/bin/init.php'; -$llm = $container->get(\Hyperf\Odin\LLM::class); -$conversation = $llm->createConversation()->generateConversationId()->withMemory(new MessageHistory()); +/** @var \Hyperf\Odin\ModelMapper $modelMapper */ +$modelMapper = $container->get(ModelMapper::class); +$llm = $modelMapper->getDefaultModel(); +$prompt = new OpenAIToolsAgentPrompt(); +/** @var TavilySearchResults $tavilySearchResults */ +$tavilySearchResults = $container->get(TavilySearchResults::class); +$tools = [ + $tavilySearchResults->setUseAnswerDirectly(false)->setSearchDepth('advanced'), +]; +$observer = $container->get(Observer::class); +/** @var MessageHistory $memory */ +$memory = $container->get(MessageHistory::class); +$conversationId = uniqid('agent_', true); +$agent = new OpenAIToolsAgent(model: $llm, prompt: $prompt, memory: $memory, observer: $observer, tools: $tools); while (true) { echo 'Human: '; $input = trim(fgets(STDIN, 1024)); - $response = $conversation->chat(Prompt::input($input), '', new Option()); + $isCommand = false; + switch ($input) { + case 'dump-messages': + var_dump($memory->getConversations($conversationId)); + $isCommand = true; + break; + case 'enable-debug': + $agent->setDebug(true); + $isCommand = true; + break; + case 'disable-debug': + $agent->setDebug(false); + $isCommand = true; + break; + } + if ($isCommand) { + continue; + } + $response = $agent->invoke(['input' => $input], $conversationId); echo 'AI: ' . $response . PHP_EOL; } diff --git a/bin/unittest.php b/bin/unittest.php index 164d06a..465ecab 100644 --- a/bin/unittest.php +++ b/bin/unittest.php @@ -35,7 +35,7 @@ public function handle(string $a, string $b): string ``` PROMPT; -$llm = $container->get(\Hyperf\Odin\LLM::class); +$llm = $container->get(\Hyperf\Odin\ModelFacade::class); echo '[AI]: ' . $llm->chat([ 'system' => new SystemMessage('You are a unit test generation robot developed by the Hyperf organization. You must return content strictly in accordance with the format requirements.'), diff --git a/composer.json b/composer.json index f00aeaa..0fc4bc8 100644 --- a/composer.json +++ b/composer.json @@ -21,9 +21,10 @@ "php": ">=8.0", "ext-bcmath": "*", "guzzlehttp/guzzle": "^7.0", - "hyperf/di": "~2.2.0 || 3.0.*", - "hyperf/config": "~2.2.0 || 3.0.*", + "hyperf/di": "^3.0.0", + "hyperf/config": "^3.0.0", "hyperf/qdrant-client": "*", + "hyperf/support": "^3.0.0", "yethee/tiktoken": "^0.1.2" }, "require-dev": { diff --git a/config/autoload/odin.php b/config/autoload/odin.php index c0b3996..d218290 100644 --- a/config/autoload/odin.php +++ b/config/autoload/odin.php @@ -1,55 +1,67 @@ [ - 'default' => 'gpt-3.5-turbo', + 'default' => 'gpt-4', // Modify this according to your needs 'models' => [ + 'gpt-35-turbo' => [ + 'name' => 'gpt-35-turbo', + 'implementation' => OpenAIModel::class, + 'config' => [ + 'api_key' => env('AZURE_OPENAI_35_TURBO_API_KEY'), + 'api_base' => env('AZURE_OPENAI_35_TURBO_API_BASE'), + 'api_version' => env('AZURE_OPENAI_35_TURBO_API_VERSION', '2023-08-01-preview'), + 'deployment_name' => env('AZURE_OPENAI_35_TURBO_DEPLOYMENT_NAME'), + ], + ], 'gpt-3.5-turbo' => [ 'name' => 'gpt-3.5-turbo', - 'api_type' => 'azure', + 'implementation' => AzureOpenAIModel::class, + 'config' => [ + 'api_key' => env('AZURE_OPENAI_35_TURBO_API_KEY'), + 'api_base' => env('AZURE_OPENAI_35_TURBO_API_BASE'), + 'api_version' => env('AZURE_OPENAI_35_TURBO_API_VERSION', '2023-08-01-preview'), + 'deployment_name' => env('AZURE_OPENAI_35_TURBO_DEPLOYMENT_NAME'), + ], ], 'gpt-3.5-turbo-16k' => [ 'name' => 'gpt-3.5-turbo-16k', - 'api_type' => 'azure', + 'implementation' => AzureOpenAIModel::class, + 'config' => [ + 'api_key' => env('AZURE_OPENAI_35_TURBO_16K_API_KEY'), + 'api_base' => env('AZURE_OPENAI_35_TURBO_16K_API_BASE'), + 'api_version' => env('AZURE_OPENAI_35_TURBO_16K_API_VERSION', '2023-08-01-preview'), + 'deployment_name' => env('AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_NAME'), + ], ], 'gpt-4' => [ 'name' => 'gpt-4', - 'api_type' => 'azure', + 'implementation' => AzureOpenAIModel::class, + 'config' => [ + 'api_key' => env('AZURE_OPENAI_4_API_KEY'), + 'api_base' => env('AZURE_OPENAI_4_API_BASE'), + 'api_version' => env('AZURE_OPENAI_4_API_VERSION', '2023-08-01-preview'), + 'deployment_name' => env('AZURE_OPENAI_4_DEPLOYMENT_NAME'), + ], ], 'gpt-4-32k' => [ 'name' => 'gpt-4-32k', - 'api_type' => 'azure', + 'implementation' => AzureOpenAIModel::class, + 'config' => [ + 'api_key' => env('AZURE_OPENAI_4_32K_API_KEY'), + 'api_base' => env('AZURE_OPENAI_4_32K_API_BASE'), + 'api_version' => env('AZURE_OPENAI_4_32K_API_VERSION', '2023-08-01-preview'), + 'deployment_name' => env('AZURE_OPENAI_4_32K_DEPLOYMENT_NAME'), + ], ], ], ], 'azure' => [ - 'gpt-3.5-turbo' => [ - 'api_key' => env('AZURE_OPENAI_35_TURBO_API_KEY'), - 'api_base' => env('AZURE_OPENAI_35_TURBO_API_BASE'), - 'api_version' => env('AZURE_OPENAI_35_TURBO_API_VERSION', '2023-08-01-preview'), - 'deployment_name' => env('AZURE_OPENAI_35_TURBO_DEPLOYMENT_NAME'), - ], - 'gpt-3.5-turbo-16k' => [ - 'api_key' => env('AZURE_OPENAI_35_TURBO_16K_API_KEY'), - 'api_base' => env('AZURE_OPENAI_35_TURBO_16K_API_BASE'), - 'api_version' => env('AZURE_OPENAI_35_TURBO_16K_API_VERSION', '2023-08-01-preview'), - 'deployment_name' => env('AZURE_OPENAI_35_TURBO_16K_DEPLOYMENT_NAME'), - ], - 'gpt-4' => [ - 'api_key' => env('AZURE_OPENAI_4_API_KEY'), - 'api_base' => env('AZURE_OPENAI_4_API_BASE'), - 'api_version' => env('AZURE_OPENAI_4_API_VERSION', '2023-08-01-preview'), - 'deployment_name' => env('AZURE_OPENAI_4_DEPLOYMENT_NAME'), - ], - 'gpt-4-32k' => [ - 'api_key' => env('AZURE_OPENAI_4_32K_API_KEY'), - 'api_base' => env('AZURE_OPENAI_4_32K_API_BASE'), - 'api_version' => env('AZURE_OPENAI_4_32K_API_VERSION', '2023-08-01-preview'), - 'deployment_name' => env('AZURE_OPENAI_4_32K_DEPLOYMENT_NAME'), - ], 'text-embedding-ada-002' => [ 'api_key' => env('AZURE_OPENAI_TEXT_EMBEDDING_ADA_002_API_KEY'), 'api_base' => env('AZURE_OPENAI_TEXT_EMBEDDING_ADA_002_API_BASE'), @@ -60,4 +72,7 @@ 'openai' => [ 'api_key' => env('OPENAI_API_KEY'), ], + 'tavily' => [ + 'api_key' => env('TAVILY_API_KEY'), + ], ]; \ No newline at end of file diff --git a/data/response.txt b/data/response.txt new file mode 100644 index 0000000..8c9f6fa --- /dev/null +++ b/data/response.txt @@ -0,0 +1,40 @@ +Thought: The user wants to group the "建店信息表" (Store Information Table) by "大区" (Region). + +Action: ModifyViews + +Observation: I need to modify the view of the "建店信息表-主表" (Store Information Table - Main Table) to include the "大区" (Region) column in the groups section. + +Action: + +```json +{ + "action": "modify_views", + "original_views": [ + { + "view_type": "table", + "name": "建店信息表-主表" + } + ], + "new_views": [ + { + "view_type": "table", + "table_name": "建店信息表", + "name": "建店信息表-主表", + "functions": { + "filters": [], + "sorts": [], + "groups": [ + { + "column": "大区" + } + ] + }, + "columns": [] + } + ] +} +``` + +Thought: The view has been modified to group the "建店信息表" (Store Information Table) by "大区" (Region). + +Final Answer: "建店信息表" (Store Information Table) has been grouped by "大区" (Region). \ No newline at end of file diff --git a/src/Action/AbstractAction.php b/src/Action/AbstractAction.php deleted file mode 100644 index 345e20f..0000000 --- a/src/Action/AbstractAction.php +++ /dev/null @@ -1,49 +0,0 @@ -name; - } - - public function setName(string $name): static - { - $this->name = $name; - return $this; - } - - public function getDesc(): string - { - return $this->desc; - } - - public function setDesc(string $desc): static - { - $this->desc = $desc; - return $this; - } - - public function getClient(): ClientInterface - { - return $this->client; - } - - public function setClient(ClientInterface $client): static - { - $this->client = $client; - return $this; - } - -} \ No newline at end of file diff --git a/src/Action/ActionFactory.php b/src/Action/ActionFactory.php deleted file mode 100644 index 0090a16..0000000 --- a/src/Action/ActionFactory.php +++ /dev/null @@ -1,18 +0,0 @@ - $result) { - $resultPrompt .= sprintf("%s: %s\n", $actionName, $result); - } - return <<getName(), $action->getDesc()); - } - } - return << trim($actionName), - 'args' => $actionArgs, - ]; - } - } - } - } - return $actions; - } - -} \ No newline at end of file diff --git a/src/Action/CalculatorAction.php b/src/Action/CalculatorAction.php deleted file mode 100644 index 2e2ad23..0000000 --- a/src/Action/CalculatorAction.php +++ /dev/null @@ -1,20 +0,0 @@ -get($path, [ - 'query' => [ - 'engine' => 'baidu', - 'q' => $keyword, - ], - ]); - $webContent = $response->getBody()->getContents(); - var_dump($webContent); - exit(); - return '搜索结果:' . $webContent; - } - -} \ No newline at end of file diff --git a/src/Agent/AbstractAgent.php b/src/Agent/AbstractAgent.php new file mode 100644 index 0000000..6e9f738 --- /dev/null +++ b/src/Agent/AbstractAgent.php @@ -0,0 +1,36 @@ +defaultPrompt) { + $prompt = $this->defaultPrompt; + } else { + $prompt = Prompt::getPrompt($this->name); + } + } + $this->prompt = $prompt; + } +} diff --git a/src/Agent/Agent.php b/src/Agent/Agent.php deleted file mode 100644 index 91cd536..0000000 --- a/src/Agent/Agent.php +++ /dev/null @@ -1,122 +0,0 @@ -buildPrompt($input, $agentThoughtAndObservation, $memory, $conversationId, $actions); - $response = $client->chat([ - 'system' => $this->buildSystemMessage(), - 'user' => new UserMessage($prompt), - ], $model, temperature: 0, stop: ['Observation:', 'Observation:\n', 'Observation:\n\n']); - [$result, $isFinalAnswer] = $this->parseAgentOutput($response, $client); - if (! $isFinalAnswer) { - foreach ($result as $item) { - $agentThoughtAndObservation .= " " . $item . "\n"; - } - return $this->chat($client, $input, $model, $conversationId, $memory, $actions, $agentThoughtAndObservation); - } else { - $finalAnswer = end($result); - $finalAnswer = trim(Str::replaceFirst('Final Answer:', '', $finalAnswer)); - } - return trim($finalAnswer); - } - - protected function parseAgentOutput(ChatCompletionResponse $response, ClientInterface $client): array - { - $result = []; - $isFinalAnswer = false; - $currentChoice = current($response->getChoices()); - $currentChoiceMessage = trim($currentChoice?->getMessage()?->getContent() ?? ''); - $observationPrefix = 'Observation: '; - $actionPrefix = 'Action:'; - $thoughtPrefix = 'Thought:'; - $finalAnswerPrefix = 'Final Answer:'; - $lines = explode("\n", $currentChoiceMessage); - foreach ($lines as $line) { - if (str_starts_with($line, $thoughtPrefix)) { - $result[] = trim($line); - } elseif (str_starts_with($line, $actionPrefix) && $line !== ($actionPrefix . ' null')) { - $line = trim($line); - $result[] = $line; - $actionJson = trim(Str::replaceFirst($actionPrefix, '', $line)); - try { - $action = Json::decode($actionJson); - $actionResult = $this->handleAction($action, $client); - if ($actionResult) { - $result[] = $observationPrefix . $actionResult; - } - } catch (InvalidArgumentException $exception) { - var_dump($exception->getMessage(), $line); - exit(); - } - } elseif (str_starts_with($line, $finalAnswerPrefix)) { - $isFinalAnswer = true; - $result[] = trim($line); - } else { - continue; - } - } - - return [$result, $isFinalAnswer]; - } - - protected function handleAction(array $action, ClientInterface $client) - { - if (! isset($action['name'], $action['args'])) { - return null; - } - $name = $action['name']; - $args = $action['args']; - $actionFactory = new ActionFactory(); - try { - $instance = $actionFactory->create($name); - $instance->setClient($client); - if (! method_exists($instance, 'handle')) { - return null; - } - return $instance->handle(...$args); - } catch (\InvalidArgumentException $exception) { - return null; - } - } - - protected function buildPrompt( - string $input, - string $agentThoughtAndObservation, - ?AbstractMemory $memory, - ?string $conversationId, - array $actions - ): string { - $template = new AgentPromptTemplate(); - return $template->build($input, $agentThoughtAndObservation, $actions); - } - - protected function buildSystemMessage(): SystemMessage - { - return new SystemMessage('You are a robot developed by the Hyperf organization, you must return content in strict format requirements.'); - } - -} \ No newline at end of file diff --git a/src/Agent/OpenAIToolsAgent.php b/src/Agent/OpenAIToolsAgent.php new file mode 100644 index 0000000..ec0dca8 --- /dev/null +++ b/src/Agent/OpenAIToolsAgent.php @@ -0,0 +1,183 @@ +prompt->getUserPrompt($inputs['input'] ?? ''); + $this->memory->setSystemMessage($this->prompt->getSystemPrompt(), $conversationId); + $conversationMessages = $this->memory->getConversations($conversationId); + $response = $this->chat(array_merge($conversationMessages, [$currentStageUserMessage]), conversationId: $conversationId, tools: $this->tools); + // Handle tool calls + if ($response->getChoices()) { + foreach ($response->getChoices() as $choice) { + if (! $choice instanceof ChatCompletionChoice || ! $choice->isFinishedByToolCall() || ! $choice->getMessage() instanceof AssistantMessage) { + continue; + } + $toolCalls = $choice->getMessage()->getToolCalls(); + if ($toolCalls) { + $toolCallsResults = []; + $toolsWithKey = []; + foreach ($this->tools as $tool) { + if ($tool instanceof ToolInterface) { + $toolDefinition = $tool->toToolDefinition(); + $toolsWithKey[$toolDefinition->getName()] = $toolDefinition; + } elseif ($tool instanceof ToolDefinition) { + $toolsWithKey[$tool->getName()] = $tool; + } + } + foreach ($toolCalls as $toolCall) { + if (! $toolCall instanceof ToolCall) { + continue; + } + $targetTool = $toolsWithKey[$toolCall->getName()] ?? null; + if (! $targetTool) { + continue; + } + $toolHandler = $targetTool->getToolHandler(); + if (is_callable($toolHandler)) { + $this->observer?->info(sprintf('Invoking tool %s with arguments %s', $toolCall->getName(), json_encode($toolCall->getArguments(), JSON_UNESCAPED_UNICODE))); + $result = call_user_func($toolHandler, ...$toolCall->getArguments()); + if ($result) { + $toolCallsResults[$toolCall->getId()] = [ + 'call' => sprintf('%s(%s)', $toolCall->getName(), implode(', ', $toolCall->getArguments())), + 'result' => $result, + ]; + if ($this->isDebug()) { + $this->observer?->debug(sprintf('Tool %s returned %s', $toolCall->getName(), json_encode($result, JSON_UNESCAPED_UNICODE))); + } else { + $this->observer?->info(sprintf('Tool %s returned', $toolCall->getName())); + } + } else { + $this->observer?->info(sprintf('Tool %s returned nothing', $toolCall->getName())); + } + } + } + if ($toolCallsResults) { + $toolCallsMessages = []; + foreach ($toolCallsResults as $toolCallResult) { + if (! $toolCallResult['call'] || ! $toolCallResult['result']) { + continue; + } + $toolCallsMessages[] = new FunctionMessage("Tool Call: {call}\nObservation: {observation}", [ + 'call' => $toolCallResult['call'], + 'observation' => json_encode($toolCallResult['result'], JSON_UNESCAPED_UNICODE), + ]); + } + $messages = $this->memory->getConversations($conversationId); + return $this->innerChat(array_merge($messages, $toolCallsMessages, [$currentStageUserMessage]), conversationId: $conversationId); + } + } else { + $this->memory->addMessages([$currentStageUserMessage], $conversationId); + } + } + } + return $response; + } + + public function isDebug(): bool + { + return $this->debug; + } + + public function setDebug(bool $debug): static + { + $this->debug = $debug; + $this->observer?->setDebug($debug); + return $this; + } + + protected function innerChat( + array $messages, + string $conversationId, + float $temperature = 0.5, + int $maxTokens = 0, + array $stop = [], + array $tools = [], + ) { + if ($this->currentIteration >= $this->maxIterations) { + throw new InvalidArgumentException('The maximum iterations has been reached.'); + } + $response = $this->model->chat($messages, $temperature, $maxTokens, $stop, $tools); + if ($response instanceof ChatCompletionResponse) { + + } + } + + /** + * @param \Hyperf\Odin\Message\MessageInterface[] $messages + */ + protected function chat( + array $messages, + string $conversationId, + float $temperature = 0.5, + int $maxTokens = 0, + array $stop = [], + array $tools = [], + ) { + if ($this->currentIteration >= $this->maxIterations) { + throw new InvalidArgumentException('The maximum iterations has been reached.'); + } + if ($this->isDebug()) { + $this->observer?->debug('Chatting to the model with messages ' . implode("\r", array_map(function ($message + ) { + return sprintf('%s Prompt: %s', $message->getRole()->name, $message->getContent()); + }, $messages))); + } else { + $this->observer?->info('Chatting to the model'); + } + var_dump($messages); + $response = $this->model->chat($messages, $temperature, $maxTokens, $stop, $tools); + if ($response instanceof ChatCompletionResponse) { + $message = $response->getFirstChoice()->getMessage(); + if ($this->isDebug()) { + $this->observer?->debug(sprintf('Model response %s message: %s', $message->getRole()->value, $message->getContent())); + } else { + $this->observer?->info('Model has responded'); + } + } + ++$this->currentIteration; + return $response; + } +} diff --git a/src/Apis/AzureOpenAI/AzureOpenAI.php b/src/Apis/AzureOpenAI/AzureOpenAI.php index 45ff376..fdfa4b2 100644 --- a/src/Apis/AzureOpenAI/AzureOpenAI.php +++ b/src/Apis/AzureOpenAI/AzureOpenAI.php @@ -24,12 +24,12 @@ class AzureOpenAI public function getClient(AzureOpenAIConfig $config, string $modelName): Client { - $apiKey = $config->getApiKey($modelName); + $apiKey = $config->getApiKey(); $storageKey = $apiKey . '-' . $modelName; if ($apiKey && isset($this->clients[$apiKey])) { return $this->clients[$storageKey]; } - $client = new Client($config, new Logger()); + $client = new Client($config, new Logger(), $modelName); $this->clients[$storageKey] = $client; return $client; } diff --git a/src/Apis/AzureOpenAI/AzureOpenAIConfig.php b/src/Apis/AzureOpenAI/AzureOpenAIConfig.php index 0097c96..57f12ab 100644 --- a/src/Apis/AzureOpenAI/AzureOpenAIConfig.php +++ b/src/Apis/AzureOpenAI/AzureOpenAIConfig.php @@ -15,33 +15,33 @@ class AzureOpenAIConfig { public function __construct( - protected array $mapper = [], + protected array $config = [], ) { } - public function getApiKey(string $model): ?string + public function getApiKey(): ?string { - return $this->mapper[$model]['api_key'] ?? null; + return $this->config['api_key'] ?? null; } - public function getBaseUrl(string $model): string + public function getBaseUrl(): string { - return $this->mapper[$model]['api_base'] ?? ''; + return $this->config['api_base'] ?? ''; } - public function getApiVersion(string $model): ?string + public function getApiVersion(): ?string { - return $this->mapper[$model]['api_version'] ?? null; + return $this->config['api_version'] ?? null; } - public function getDeploymentName(string $model): ?string + public function getDeploymentName(): ?string { - return $this->mapper[$model]['deployment_name'] ?? null; + return $this->config['deployment_name'] ?? null; } - public function getMapper(): array + public function getConfig(): array { - return $this->mapper; + return $this->config; } } diff --git a/src/Apis/AzureOpenAI/Client.php b/src/Apis/AzureOpenAI/Client.php index b257d90..d786e10 100644 --- a/src/Apis/AzureOpenAI/Client.php +++ b/src/Apis/AzureOpenAI/Client.php @@ -14,12 +14,13 @@ use GuzzleHttp\Client as GuzzleClient; use Hyperf\Odin\Apis\ClientInterface; -use Hyperf\Odin\Apis\OpenAI\Request\FunctionCallDefinition; +use Hyperf\Odin\Apis\OpenAI\Request\ToolDefinition; use Hyperf\Odin\Apis\OpenAI\Response\ChatCompletionResponse; use Hyperf\Odin\Apis\OpenAI\Response\ListResponse; use Hyperf\Odin\Apis\OpenAI\Response\TextCompletionResponse; use Hyperf\Odin\Exception\NotImplementedException; use Hyperf\Odin\Message\MessageInterface; +use Hyperf\Odin\Tools\ToolInterface; use InvalidArgumentException; use Psr\Log\LoggerInterface; @@ -35,10 +36,12 @@ class Client implements ClientInterface protected ?LoggerInterface $logger; protected bool $debug = false; + protected string $model; - public function __construct(AzureOpenAIConfig $config, LoggerInterface $logger = null) + public function __construct(AzureOpenAIConfig $config, LoggerInterface $logger, string $model) { $this->logger = $logger; + $this->model = $model; $this->initConfig($config); } @@ -48,7 +51,7 @@ public function chat( float $temperature = 0.9, int $maxTokens = 1000, array $stop = [], - array $functions = [], + array $tools = [], ): ChatCompletionResponse { $deploymentPath = $this->buildDeploymentPath($model); $messagesArr = []; @@ -65,29 +68,33 @@ public function chat( if ($maxTokens) { $json['max_tokens'] = $maxTokens; } - if ($functions) { - $functionsArray = []; - foreach ($functions as $function) { - if ($function instanceof FunctionCallDefinition) { - $functionsArray[] = $function->toArray(); + if (! empty($tools)) { + $toolsArray = []; + foreach ($tools as $tool) { + if ($tool instanceof ToolInterface) { + $toolsArray[] = $tool->toToolDefinition()->toArray(); + } elseif ($tool instanceof ToolDefinition) { + $toolsArray[] = $tool->toArray(); } else { - $functionsArray[] = $function; + $toolsArray[] = $tool; } } - - $json['functions'] = $functionsArray; - $json['function_call'] = 'auto'; + if (! empty($toolsArray)) { + $json['tools'] = $toolsArray; + $json['tool_choice'] = 'auto'; + } } if ($stop) { $json['stop'] = $stop; } - $this->debug && $this->logger?->debug(sprintf("Send: \nSystem Message: %s\nUser Message: %s", $messages['system'] ?? '', $messages['user'] ?? '')); + $this->debug && $this->logger?->debug(sprintf("Send: \nSystem Message: %s\nUser Message: %s\nTools: %s", $messages['system'] ?? '', $messages['user'] ?? '', json_encode($tools))); try { $response = $this->getClient($model)->post($deploymentPath . '/chat/completions', [ 'query' => [ 'api-version' => $this->config->getApiVersion($model), ], 'json' => $json, + 'verify' => false, ]); } catch (\Exception $exception) { var_dump($json); @@ -160,17 +167,15 @@ protected function initConfig(AzureOpenAIConfig $config): static throw new InvalidArgumentException('AzureOpenAIConfig is required'); } $this->config = $config; - foreach ($config->getMapper() as $model => $modelConfig) { - $headers = [ - 'api-key' => $config->getApiKey($model), - 'Content-Type' => 'application/json', - 'User-Agent' => 'Hyperf-Odin/1.0', - ]; - $this->clients[$model] = new GuzzleClient([ - 'base_uri' => $config->getBaseUrl($model), - 'headers' => $headers, - ]); - } + $headers = [ + 'api-key' => $config->getApiKey(), + 'Content-Type' => 'application/json', + 'User-Agent' => 'Hyperf-Odin/1.0', + ]; + $this->clients[$this->model] = new GuzzleClient([ + 'base_uri' => $config->getBaseUrl(), + 'headers' => $headers, + ]); return $this; } diff --git a/src/Apis/OpenAI/Request/FunctionCallDefinition.php b/src/Apis/OpenAI/Request/ToolDefinition.php similarity index 52% rename from src/Apis/OpenAI/Request/FunctionCallDefinition.php rename to src/Apis/OpenAI/Request/ToolDefinition.php index 710cf66..7182ae4 100644 --- a/src/Apis/OpenAI/Request/FunctionCallDefinition.php +++ b/src/Apis/OpenAI/Request/ToolDefinition.php @@ -15,56 +15,54 @@ use Hyperf\Contract\Arrayable; use InvalidArgumentException; -class FunctionCallDefinition implements Arrayable +class ToolDefinition implements Arrayable { protected string $name; protected string $description; - protected ?FunctionCallParameters $parameters; + protected ?ToolParameters $parameters; /** * @var callable[] */ - protected array $functionCallHandlers = []; + protected array $toolHandler = []; public function __construct( string $name, string $description = '', - ?FunctionCallParameters $parameters = null, - callable|array $functionHandlers = [] + ?ToolParameters $parameters = null, + callable|array $toolHandler = [] ) { $this->name = $name; $this->description = $description; $this->parameters = $parameters; - $this->setFunctionCallHandlers($functionHandlers); + $this->setToolHandler($toolHandler); } public function toArray(): array { return [ - 'name' => $this->getName(), - 'description' => $this->getDescription(), - 'parameters' => $this->getParameters()?->toArray(), + 'type' => 'function', + 'function' => [ + 'name' => $this->getName(), + 'description' => $this->getDescription(), + 'parameters' => $this->getParameters()?->toArray(), + ] ]; } - public function getFunctionCallHandlers(): array + public function getToolHandler(): array { - return $this->functionCallHandlers; + return $this->toolHandler; } - public function setFunctionCallHandlers(array|callable $functionCallHandlers): static + public function setToolHandler(array|callable $toolHandler): static { - if (! is_array($functionCallHandlers)) { - $functionCallHandlers = [$functionCallHandlers]; + if (! is_callable($toolHandler)) { + throw new InvalidArgumentException('Tool handler must be callable.'); } - foreach ($functionCallHandlers as $functionCallHandler) { - if (! is_callable($functionCallHandler)) { - throw new InvalidArgumentException('Function call handler must be callable.'); - } - } - $this->functionCallHandlers = $functionCallHandlers; + $this->toolHandler = $toolHandler; return $this; } @@ -90,12 +88,12 @@ public function setDescription(string $description) return $this; } - public function getParameters(): ?FunctionCallParameters + public function getParameters(): ?ToolParameters { return $this->parameters; } - public function setParameters(FunctionCallParameters $parameters): static + public function setParameters(ToolParameters $parameters): static { $this->parameters = $parameters; return $this; diff --git a/src/Apis/OpenAI/Request/FunctionCallParameter.php b/src/Apis/OpenAI/Request/ToolParameter.php similarity index 98% rename from src/Apis/OpenAI/Request/FunctionCallParameter.php rename to src/Apis/OpenAI/Request/ToolParameter.php index 48a8d65..f535d90 100644 --- a/src/Apis/OpenAI/Request/FunctionCallParameter.php +++ b/src/Apis/OpenAI/Request/ToolParameter.php @@ -12,7 +12,7 @@ namespace Hyperf\Odin\Apis\OpenAI\Request; -class FunctionCallParameter +class ToolParameter { protected string $name; diff --git a/src/Apis/OpenAI/Request/FunctionCallParameters.php b/src/Apis/OpenAI/Request/ToolParameters.php similarity index 76% rename from src/Apis/OpenAI/Request/FunctionCallParameters.php rename to src/Apis/OpenAI/Request/ToolParameters.php index 075ad77..05dddbf 100644 --- a/src/Apis/OpenAI/Request/FunctionCallParameters.php +++ b/src/Apis/OpenAI/Request/ToolParameters.php @@ -5,7 +5,7 @@ use Hyperf\Contract\Arrayable; -class FunctionCallParameters implements Arrayable +class ToolParameters implements Arrayable { protected string $type; @@ -17,7 +17,7 @@ public function __construct(array $properties = [], string $type = 'object') $this->properties = $properties; $this->type = $type; foreach ($properties as $property) { - if (! $property instanceof FunctionCallParameter) { + if (! $property instanceof ToolParameter) { continue; } if ($property->isRequired()) { @@ -30,7 +30,7 @@ public function toArray(): array { $properties = []; foreach ($this->getProperties() as $property) { - if (! $property instanceof FunctionCallParameter) { + if (! $property instanceof ToolParameter) { continue; } $item = [ @@ -49,6 +49,15 @@ public function toArray(): array ]; } + public static function fromArray(array $parameters): ToolParameters + { + $properties = []; + foreach ($parameters as $name => $property) { + $properties[] = new ToolParameter($name, $property['description'], $property['type'] ?? 'string', $property['required'] ?? false, $property['enum'] ?? null); + } + return new ToolParameters($properties); + } + public function getType(): string { return $this->type; diff --git a/src/Apis/OpenAI/Response/ChatCompletionChoice.php b/src/Apis/OpenAI/Response/ChatCompletionChoice.php index af39f9e..f5597e3 100644 --- a/src/Apis/OpenAI/Response/ChatCompletionChoice.php +++ b/src/Apis/OpenAI/Response/ChatCompletionChoice.php @@ -50,8 +50,8 @@ public function getFinishReason(): ?string return $this->finishReason; } - public function isFinishedByFunctionCall(): bool + public function isFinishedByToolCall(): bool { - return $this->getFinishReason() === 'function_call'; + return $this->getFinishReason() === 'tool_calls'; } } diff --git a/src/Apis/OpenAI/Response/ChatCompletionResponse.php b/src/Apis/OpenAI/Response/ChatCompletionResponse.php index 97a4d88..8f23a5a 100644 --- a/src/Apis/OpenAI/Response/ChatCompletionResponse.php +++ b/src/Apis/OpenAI/Response/ChatCompletionResponse.php @@ -12,7 +12,9 @@ namespace Hyperf\Odin\Apis\OpenAI\Response; -class ChatCompletionResponse extends AbstractResponse +use Stringable; + +class ChatCompletionResponse extends AbstractResponse implements Stringable { protected ?string $id = null; @@ -134,4 +136,5 @@ protected function buildChoices(mixed $choices): array } return $result; } + } diff --git a/src/Apis/OpenAI/Response/FunctionCall.php b/src/Apis/OpenAI/Response/FunctionCall.php deleted file mode 100644 index d798954..0000000 --- a/src/Apis/OpenAI/Response/FunctionCall.php +++ /dev/null @@ -1,113 +0,0 @@ -setOriginalName($functionCall['name'] ?? ''); - $static->setOriginalArguments($functionCall['arguments'] ?? ''); - $shouldFix && $static->setShouldFix(true); - return $static; - } - - public function toArray(): array - { - return [ - 'name' => $this->name, - 'arguments' => $this->arguments, - ]; - } - - public function getName(): string - { - return $this->name; - } - - public function setName(string $name): static - { - $this->name = $name; - return $this; - } - - public function getArguments(): array - { - return $this->arguments; - } - - public function setArguments(array $arguments): static - { - $this->arguments = $arguments; - return $this; - } - - public function isShouldFix(): bool - { - return $this->shouldFix; - } - - public function setShouldFix(bool $shouldFix): static - { - $this->shouldFix = $shouldFix; - return $this; - } - - public function getOriginalName(): string - { - return $this->originalName; - } - - public function setOriginalName(string $originalName): static - { - $this->originalName = $originalName; - return $this; - } - - public function getOriginalArguments(): string - { - return $this->originalArguments; - } - - public function setOriginalArguments(string $originalArguments): static - { - $this->originalArguments = $originalArguments; - return $this; - } -} diff --git a/src/Apis/OpenAI/Response/ToolCall.php b/src/Apis/OpenAI/Response/ToolCall.php new file mode 100644 index 0000000..de2035b --- /dev/null +++ b/src/Apis/OpenAI/Response/ToolCall.php @@ -0,0 +1,94 @@ + $this->getId(), + 'name' => $this->getName(), + 'arguments' => $this->getArguments(), + ]; + } + + public function getName(): string + { + return $this->name; + } + + public function setName(string $name): static + { + $this->name = $name; + return $this; + } + + public function getArguments(): array + { + return $this->arguments; + } + + public function setArguments(array $arguments): static + { + $this->arguments = $arguments; + return $this; + } + + public function getId(): string + { + return $this->id; + } + + public function setId(string $id): static + { + $this->id = $id; + return $this; + } +} diff --git a/src/Conversation/Conversation.php b/src/Conversation/Conversation.php deleted file mode 100644 index 568f850..0000000 --- a/src/Conversation/Conversation.php +++ /dev/null @@ -1,192 +0,0 @@ - $messages[0] ?? '', - 'user' => $messages[1] ?? '', - ]; - } - if (! isset($messages['user']) || ! isset($messages['system'])) { - throw new InvalidArgumentException('The messages must contain user and system.'); - } - if (! $messages['user'] instanceof UserMessage || ! $messages['system'] instanceof SystemMessage) { - throw new InvalidArgumentException('The messages must be UserMessage and SystemMessage.'); - } - if (! $messages['user']->getContext('original_user_message')) { - $messages['user']->setContext('original_user_message', $messages['user']->getContent()); - } - $originalUserMessage = $messages['user']->getContext('original_user_message'); - // Memory, handle the user message. - if ($this->memory && $conversationId) { - $memoryPrompt = $this->memory->buildPrompt($messages['user'], $conversationId); - if (is_string($memoryPrompt)) { - $messages['user']->setContent($memoryPrompt); - } - if ($chatType === 'user') { - $this->memory->addHumanMessage($originalUserMessage, $conversationId, 'User Input: '); - } elseif ($chatType === 'ai') { - $this->memory->addAIMessage($originalUserMessage, $conversationId, 'AI Input: '); - } - } - // Select Model - if ($model instanceof ModelSelector) { - $model = $model->select($messages['user']->getContent(), $this->modelMapper->getModels()); - } elseif (is_string($model)) { - $model = $this->modelMapper->getModel($model); - } elseif (! $model instanceof Model) { - throw new InvalidArgumentException('The model must be a ModelSelector, Model or string.'); - } - // Chat with the model. - $response = $this->chatWithModel($messages, $model, $option); - // Function Call Handlers - if ($response->getChoices()) { - foreach ($response->getChoices() as $choice) { - if (! $choice instanceof ChatCompletionChoice || ! $choice->isFinishedByFunctionCall()) { - continue; - } - $message = $choice->getMessage(); - if (! $message instanceof AssistantMessage) { - continue; - } - $functionCall = $message->getFunctionCall(); - if (! $functionCall instanceof FunctionCall) { - continue; - } - /** @var FunctionCallDefinition[] $functionDefinitions */ - $functionDefinitions = ArrFilter::filterInstance(FunctionCallDefinition::class, $option->getFunctions()); - if ($functionCall->isShouldFix()) { - $functionCall = $this->fixFunctionCall($functionCall, $functionDefinitions); - } - foreach ($functionDefinitions as $functionDefinition) { - if ($functionDefinition->getName() === $functionCall->getName()) { - $functionCallHandlers = $functionDefinition->getFunctionCallHandlers(); - if (isset($functionCallHandlers[0]) && is_callable($functionCallHandlers[0])) { - $functionCallHandler = $functionCallHandlers[0]; - $result = $functionCallHandler($functionCall); - $code = $functionCall->getArguments()['code'] ?? ''; - $prompt = Prompt::getPrompt('AfterCodeExecuted', [ - 'userRequirement' => $messages['user']->getContext('original_user_message'), - 'code' => $code, - 'codeExecutedResult' => $result, - ]); - $messages = [ - 'system' => $messages['system'], - 'user' => $messages['user']->setContent($prompt), - ]; - $response = $this->chat($messages, $model, $option); - } - break; - } - } - } - } - // Memory, handle the AI message. - if ($this->memory && $conversationId) { - $this->memory->addAIMessage((string)$response, $conversationId, 'AI Response: '); - } - return $response; - } - - public function withClient(ClientInterface $client): static - { - $static = clone $this; - $static->client = $client; - return $static; - } - - public function withMemory(AbstractMemory $memory): static - { - $static = clone $this; - $static->memory = $memory; - return $static; - } - - public function withActions(array $actions): static - { - $static = clone $this; - $static->actions = $actions; - return $static; - } - - public function createConversationId(): string - { - return uniqid(); - } - - protected function chatWithModel(array $messages, Model $model, Option $option = null): ChatCompletionResponse - { - var_dump($messages['user']->getContent()); - return $this->llm->chat(messages: $messages, temperature: $option->getTemperature(), maxTokens: $option->getMaxTokens(), stop: $option->getStop(), functions: $option->getFunctions(), model: $model->getName(),); - } - - protected function fixFunctionCall(FunctionCall $functionCall, array $functionDefinitions): FunctionCall - { - foreach ($functionDefinitions as $functionDefinition) { - if (! $functionDefinition instanceof FunctionCallDefinition) { - continue; - } - $functionCallHandlers = $functionDefinition->getFunctionCallHandlers(); - if (isset($functionCallHandlers[1]) && is_callable($functionCallHandlers[1])) { - $functionCallHandler = $functionCallHandlers[1]; - $functionCall = $functionCallHandler($functionCall); - } - break; - } - return $functionCall; - } -} diff --git a/src/Conversation/ConversationBak.php b/src/Conversation/ConversationBak.php deleted file mode 100644 index cd2be68..0000000 --- a/src/Conversation/ConversationBak.php +++ /dev/null @@ -1,139 +0,0 @@ -thoughtActions($client, $input, $model, $actions); - if ($matchedActions) { - $actionsResults = $this->handleActions($matchedActions); - $actionsResults && $prompt = (new ActionTemplate())->buildAfterActionExecutedPrompt($input, $actionsResults); - if ($memory) { - $prompt = $memory->buildPrompt($prompt, $conversationId); - } - $response = $this->answer($client, $prompt, $model); - if ($response->getContent()) { - $finalAnswer = Str::replaceFirst('Answer:', '', (string)$response); - } - } - } - if (! $finalAnswer) { - if ($memory) { - $prompt = $memory->buildPrompt($input, $conversationId); - } - $response = $this->answer($client, $prompt, $model); - $finalAnswer = (string)$response; - } - if ($memory) { - $memory->addHumanMessage($input, $conversationId); - foreach ($actionsResults ?? [] as $actionName => $actionResult) { - $memory->addMessage(sprintf('%s Action Result: %s', $actionName, $actionResult), $conversationId); - } - $memory->addAIMessage($finalAnswer, $conversationId); - } - return trim($finalAnswer); - } - - public function createConversationId(): string - { - return uniqid(); - } - - protected function thoughtActions( - ClientInterface $client, - string $userInput, - string $model, - array $actions, - ): array { - $actionTemplate = new ActionTemplate(); - $prompt = $actionTemplate->buildThoughtActionsPrompt($userInput, $actions); - $messages = [ - 'system' => $this->buildSystemMessage(), - 'user' => new UserMessage($prompt), - ]; - $response = $client->chat($messages, $model, temperature: 0); - return $actionTemplate->parseActions($response); - } - - protected function buildSystemMessage(): SystemMessage - { - return new SystemMessage('你是一个由 Hyperf 组织开发的聊天机器人,你必须严格按照格式要求返回内容'); - } - - protected function handleActions(array $matchedActions): array - { - // 匹配到了 Actions,按顺序执行 Actions - $actionsResults = []; - foreach ($matchedActions as $action) { - if (! isset($action['name'], $action['args'])) { - continue; - } - $actionName = $action['name']; - $actionArgs = $action['args']; - $actionInstance = match ($actionName) { - 'Calculator' => new CalculatorAction(), - 'Weather' => new WeatherAction(), - 'Search' => new SearchAction(), - default => null, - }; - if (! $actionInstance) { - continue; - } - $actionResult = $actionInstance->handle(...$actionArgs); - if ($actionResult) { - $actionsResults[$actionName] = $actionResult; - } - } - return $actionsResults; - } - - protected function answer( - ClientInterface $client, - string $prompt, - string $model, - float $temperature = 0, - ): ChatCompletionResponse { - $messages = [ - 'system' => $this->buildSystemMessage(), - 'user' => new UserMessage($prompt), - ]; - return $client->chat($messages, $model, temperature: $temperature); - } -} diff --git a/src/Conversation/Option.php b/src/Conversation/Option.php deleted file mode 100644 index 430a6e8..0000000 --- a/src/Conversation/Option.php +++ /dev/null @@ -1,76 +0,0 @@ -setTemperature($temperature); - $this->setMaxTokens($maxTokens); - $this->setStop($stop); - $this->setFunctions($functions); - } - - public function getTemperature(): float - { - return $this->temperature; - } - - public function setTemperature(float $temperature): static - { - $this->temperature = $temperature; - return $this; - } - - public function getMaxTokens(): int - { - return $this->maxTokens; - } - - public function setMaxTokens(int $maxTokens): static - { - $this->maxTokens = $maxTokens; - return $this; - } - - public function getStop(): array - { - return $this->stop; - } - - public function setStop(array $stop): static - { - $this->stop = $stop; - return $this; - } - - public function getFunctions(): array - { - return $this->functions; - } - - public function setFunctions(array $functions): static - { - $this->functions = $functions; - return $this; - } -} diff --git a/src/Interpreter/CodeRunner.php b/src/Interpreter/CodeRunner.php index 7ad49b6..03c040a 100644 --- a/src/Interpreter/CodeRunner.php +++ b/src/Interpreter/CodeRunner.php @@ -14,10 +14,10 @@ use Closure; use Exception; -use Hyperf\Odin\Apis\OpenAI\Request\FunctionCallDefinition; -use Hyperf\Odin\Apis\OpenAI\Request\FunctionCallParameter; -use Hyperf\Odin\Apis\OpenAI\Request\FunctionCallParameters; -use Hyperf\Odin\Apis\OpenAI\Response\FunctionCall; +use Hyperf\Odin\Apis\OpenAI\Request\ToolDefinition; +use Hyperf\Odin\Apis\OpenAI\Request\ToolParameter; +use Hyperf\Odin\Apis\OpenAI\Request\ToolParameters; +use Hyperf\Odin\Apis\OpenAI\Response\ToolCall; class CodeRunner { @@ -42,7 +42,7 @@ class CodeRunner */ public static function handlers(): array { - $handler = function (FunctionCall $functionCall) { + $handler = function (ToolCall $functionCall) { $arguments = $functionCall->getArguments(); if (! isset($arguments['language']) || ! isset($arguments['code'])) { echo '[DEBUG] Invalid function arguments' . PHP_EOL; @@ -50,7 +50,7 @@ public static function handlers(): array } return (new CodeRunner())->runCode($arguments['language'], $arguments['code']); }; - $fixer = function (FunctionCall $functionCall) { + $fixer = function (ToolCall $functionCall) { if ($functionCall->getName() !== 'run_code' && $functionCall->getOriginalArguments() && ! $functionCall->getArguments()) { if (in_array($functionCall->getName(), ['python', 'shell', 'php'])) { $functionCall->setArguments([ @@ -65,16 +65,16 @@ public static function handlers(): array return [$handler, $fixer]; } - public static function toFunctionCallDefinition(): FunctionCallDefinition + public static function toFunctionCallDefinition(): ToolDefinition { - return new FunctionCallDefinition(name: 'run_code', description: 'Executes code and returns the value printed on STDOUT.', parameters: new FunctionCallParameters([ - new FunctionCallParameter(name: 'language', description: 'The programming language, PHP version is 8.2, Python version is 3.11', enum: [ + return new ToolDefinition(name: 'run_code', description: 'Executes code and returns the value printed on STDOUT.', parameters: new ToolParameters([ + new ToolParameter(name: 'language', description: 'The programming language, PHP version is 8.2, Python version is 3.11', enum: [ 'php', 'python', 'shell' ]), - new FunctionCallParameter(name: 'code', description: 'The code which needs to be executed.',), - ]), functionHandlers: CodeRunner::handlers(),); + new ToolParameter(name: 'code', description: 'The code which needs to be executed.',), + ]), toolHandler: CodeRunner::handlers(),); } public function runCode(string $language, string $code) diff --git a/src/Logger.php b/src/Logger.php index 976c1d9..3ac8238 100644 --- a/src/Logger.php +++ b/src/Logger.php @@ -1,61 +1,70 @@ log('EMERGENCY', $message, $context); } - public function alert(Stringable|string $message, array $context = []): void + public function alert(string|Stringable $message, array $context = []): void { $this->log('ALERT', $message, $context); } - public function critical(Stringable|string $message, array $context = []): void + public function critical(string|Stringable $message, array $context = []): void { $this->log('CRITICAL', $message, $context); } - public function error(Stringable|string $message, array $context = []): void + public function error(string|Stringable $message, array $context = []): void { $this->log('ERROR', $message, $context); } - public function warning(Stringable|string $message, array $context = []): void + public function warning(string|Stringable $message, array $context = []): void { $this->log('WARNING', $message, $context); } - public function notice(Stringable|string $message, array $context = []): void + public function notice(string|Stringable $message, array $context = []): void { $this->log('NOTICE', $message, $context); } - public function info(Stringable|string $message, array $context = []): void + public function info(string|Stringable $message, array $context = []): void { $this->log('INFO', $message, $context); } - public function debug(Stringable|string $message, array $context = []): void + public function debug(string|Stringable $message, array $context = []): void { $this->log('DEBUG', $message, $context); } - public function log($level, Stringable|string $message, array $context = []): void + public function log($level, string|Stringable $message, array $context = []): void { $message = (string)$message; - $message = sprintf('[%s] %s', $level, $message); + $datetime = date('Y-m-d H:i:s'); + $message = sprintf('[%s] %s %s', $level, $datetime, $message); if ($context) { $message .= sprintf(' %s', json_encode($context, JSON_UNESCAPED_UNICODE)); } echo $message . PHP_EOL; } -} \ No newline at end of file +} diff --git a/src/Memory/AbstractMemory.php b/src/Memory/AbstractMemory.php index 0e2f6e9..a9c68bd 100644 --- a/src/Memory/AbstractMemory.php +++ b/src/Memory/AbstractMemory.php @@ -12,31 +12,10 @@ namespace Hyperf\Odin\Memory; -use Stringable; - -abstract class AbstractMemory +abstract class AbstractMemory implements MemoryInterface { protected array $conversations = []; - abstract public function addHumanMessage( - string|Stringable $input, - string|Stringable|null $conversationId, - string $prefix = 'User: ' - ): static; - - abstract public function addAIMessage( - string|Stringable $output, - string|Stringable|null $conversationId, - string $prefix = 'AI: ' - ): static; - - abstract public function addMessage(string|Stringable $message, string|Stringable|null $conversationId): static; - - abstract public function buildPrompt( - string|Stringable $input, - string|Stringable|null $conversationId - ): string|Stringable; - public function count(): int { return count($this->conversations); diff --git a/src/Memory/MemoryInterface.php b/src/Memory/MemoryInterface.php new file mode 100644 index 0000000..9acd714 --- /dev/null +++ b/src/Memory/MemoryInterface.php @@ -0,0 +1,23 @@ +conversations[$conversationId] ?? null; - if (! $conversation) { - return $input; - } - $history = implode("\n", $conversation); - return <<systemMessages[$conversationId] = $message; + return $this; } - public function addHumanMessage( - string|Stringable $input, - string|Stringable|null $conversationId, - string $prefix = 'User: ' - ): static + public function addMessages(array|MessageInterface $messages, string|Stringable $conversationId): static { - return $this->addMessage($prefix . $input, $conversationId); - } + if (! is_string($conversationId) && ! ($conversationId instanceof Stringable)) { + throw new InvalidArgumentException('Conversation ID must be a string, an instance of Stringable, or null.'); + } - public function addAIMessage( - string|Stringable $output, - string|Stringable|null $conversationId, - string $prefix = 'AI: ' - ): static - { - return $this->addMessage($prefix . $output, $conversationId); - } + if (! is_array($messages)) { + $messages = [$messages]; + } - public function addMessage(string|Stringable $message, string|Stringable|null $conversationId): static - { - if ($conversationId) { + foreach ($messages as $message) { + if (! $message instanceof MessageInterface) { + throw new InvalidArgumentException('Messages must be an array of MessageInterface instances.'); + } + } + + foreach ($messages as $message) { $this->conversations[$conversationId][] = $message; if (count($this->conversations[$conversationId]) > $this->maxRecord) { array_shift($this->conversations[$conversationId]); } } + return $this; } + + public function getConversations(string $conversationId): array + { + $messages = $this->conversations[$conversationId] ?? []; + $systemMessage = $this->systemMessages[$conversationId] ?? null; + return $systemMessage ? array_merge([$systemMessage], $messages) : $messages; + } } diff --git a/src/Message/AbstractMessage.php b/src/Message/AbstractMessage.php index a957ab8..635a098 100644 --- a/src/Message/AbstractMessage.php +++ b/src/Message/AbstractMessage.php @@ -22,14 +22,30 @@ abstract class AbstractMessage implements MessageInterface, Stringable protected array $context = []; - public function __construct(string $content) + public function __construct(string $content, array $context = []) { $this->content = $content; + $this->context = $context; } public function __toString(): string { - return $this->getContent(); + // Replace the variables in content according to the key in context, for example {name} matches $context['name'] + $content = $this->content; + foreach ($this->context as $key => $value) { + $content = str_replace('{' . $key . '}', $value, $content); + } + return $content; + } + + public function formatContent(array $context): string + { + $context = array_merge($this->context, $context); + $content = $this->content; + foreach ($context as $key => $value) { + $content = str_replace('{' . $key . '}', $value, $content); + } + return $content; } public function toArray(): array @@ -83,4 +99,9 @@ public function setContext(string $key, $value): mixed $this->context[$key] = $value; return $value; } + + public function hasContext(string $key): bool + { + return isset($this->context[$key]); + } } diff --git a/src/Message/AssistantMessage.php b/src/Message/AssistantMessage.php index c699e2c..f8b2900 100644 --- a/src/Message/AssistantMessage.php +++ b/src/Message/AssistantMessage.php @@ -12,41 +12,55 @@ namespace Hyperf\Odin\Message; -use Hyperf\Odin\Apis\OpenAI\Response\FunctionCall; +use Hyperf\Odin\Apis\OpenAI\Response\ToolCall; class AssistantMessage extends AbstractMessage { protected Role $role = Role::Assistant; - protected ?FunctionCall $functionCall; + /** + * @var ToolCall[] + */ + protected array $toolCalls = []; - public function __construct(string $content, ?FunctionCall $functionCall = null) + public function __construct(string $content, array $toolsCall = []) { parent::__construct($content); - $this->functionCall = $functionCall; + $this->toolCalls = $toolsCall; } public static function fromArray(array $message): static { - return new static($message['content'] ?? '', FunctionCall::fromArray($message['function_call'] ?? [])); + return new static($message['content'] ?? '', ToolCall::fromArray($message['tool_calls'] ?? [])); } public function toArray(): array { - return [ + $toolCalls = []; + foreach ($this->toolCalls as $toolCall) { + $toolCalls[] = $toolCall->toArray(); + } + $result = [ 'role' => $this->role->value, - 'function_call' => $this->functionCall->toArray(), + 'content' => $this->content, ]; + $toolCalls && $result['tool_calls'] = $toolCalls; + return $result; } - public function getFunctionCall(): ?FunctionCall + public function hasToolCalls(): bool { - return $this->functionCall; + return ! empty($this->toolCalls); } - public function setFunctionCall(FunctionCall $functionCall): static + public function getToolCalls(): array { - $this->functionCall = $functionCall; + return $this->toolCalls; + } + + public function setToolCalls(array $toolCalls): static + { + $this->toolCalls = $toolCalls; return $this; } } diff --git a/src/Message/FunctionMessage.php b/src/Message/FunctionMessage.php new file mode 100644 index 0000000..b2cef13 --- /dev/null +++ b/src/Message/FunctionMessage.php @@ -0,0 +1,18 @@ +systemMessage], $this->previousMessages, [$this->lastUserMessage]); + } +} diff --git a/src/Message/Role.php b/src/Message/Role.php index b981e33..b6ffcdc 100644 --- a/src/Message/Role.php +++ b/src/Message/Role.php @@ -8,5 +8,6 @@ enum Role: string case System = 'system'; case User = 'user'; case Assistant = 'assistant'; + case Function = 'function'; } \ No newline at end of file diff --git a/src/Model/AzureOpenAIModel.php b/src/Model/AzureOpenAIModel.php new file mode 100644 index 0000000..8e2bfcf --- /dev/null +++ b/src/Model/AzureOpenAIModel.php @@ -0,0 +1,36 @@ +getAzureOpenAIClient(); + return $client->chat($messages, $this->model, $temperature, $maxTokens, $stop, $tools); + } + + public function getAzureOpenAIClient(): AzureOpenAIClient + { + $openAI = new AzureOpenAI(); + $config = new AzureOpenAIConfig($this->config); + return $openAI->getClient($config, $this->model); + } +} \ No newline at end of file diff --git a/src/Model/ModelInterface.php b/src/Model/ModelInterface.php new file mode 100644 index 0000000..d5bafc6 --- /dev/null +++ b/src/Model/ModelInterface.php @@ -0,0 +1,16 @@ +defaultModel = $config->get('odin.llm.default', 'gpt-3.5-turbo'); $models = $config->get('odin.llm.models', []); foreach ($models as $model => $item) { - if (! isset($item['name'], $item['api_type'])) { + if (! $model || ! isset($item['implementation'])) { continue; } - $this->models[$model] = new Model($item['name'], $item['api_type']); + $implementation = $item['implementation']; + $modelObject = new $implementation($model, $item['config'] ?? []); + if (! $modelObject instanceof ModelInterface) { + throw new InvalidArgumentException(sprintf('Model %s must be an instance of %s.', $model, ModelInterface::class)); + } + $this->models[$model] = $modelObject; } } - public function getModel(string $model): Model + public function getDefaultModel(): ModelInterface + { + return $this->getModel($this->defaultModel); + } + + public function getModel(string $model): ModelInterface { if ($model === '') { $model = $this->defaultModel; diff --git a/src/Observer.php b/src/Observer.php new file mode 100644 index 0000000..4ea2078 --- /dev/null +++ b/src/Observer.php @@ -0,0 +1,41 @@ +logger->info($message, $context); + } + + public function debug(Stringable|string $message, array $context = []): void + { + if (! $this->isDebug()) { + return; + } + $this->logger->debug($message, $context); + } + + public function isDebug(): bool + { + return $this->debug; + } + + public function setDebug(bool $debug): static + { + $this->debug = $debug; + return $this; + } + +} \ No newline at end of file diff --git a/src/Prompt/CodeInterpreter.prompt b/src/Prompt/CodeInterpreter.prompt index df98fe6..2ced213 100644 --- a/src/Prompt/CodeInterpreter.prompt +++ b/src/Prompt/CodeInterpreter.prompt @@ -3,7 +3,18 @@ First, write a plan, Always recap the plan between each code block. If missing some package, you can install it by execute shell command or modify the code, use `pip3` in Python, use `composer` in PHP. Never use (!) when running commands. You should fix any error by execute shell command or modify the code. -When using Python, you should always use the 'print' function for the output +When using Python, you should always use the 'print' function for the final result of the code. +Example: +This is not good: +```python +content = 'Hello' +content +``` +This is good: +```python +content = 'Hello' +print(content) +``` If you want to use some API, use Public and Free APIs first, do not ask the user to input api key or secrets. Response should according to the user's language. When you know the final answer of the goal, you should print it in the end with the format `Final Answer: `. diff --git a/src/Prompt/OpenAIToolsAgentPrompt.php b/src/Prompt/OpenAIToolsAgentPrompt.php new file mode 100644 index 0000000..810fb0a --- /dev/null +++ b/src/Prompt/OpenAIToolsAgentPrompt.php @@ -0,0 +1,67 @@ +systemPrompt = $systemPrompt; + } + if (! is_null($userPrompt)) { + $this->userPrompt = $userPrompt; + } + if (! is_null($placeholders)) { + $this->placeholders = $placeholders; + } + $this->systemPrompt .= "\n" . $this->placeholders; + } + + public function toArray(): array + { + return [ + 'system' => new SystemMessage($this->systemPrompt), + 'user' => new UserMessage($this->userPrompt), + ]; + } + + public function getSystemPrompt(string $agentScratchpad = ''): SystemMessage + { + return new SystemMessage(str_replace('{agent_scratchpad}', $agentScratchpad, $this->systemPrompt)); + } + + public function getUserPrompt(string $input): UserMessage + { + return new UserMessage(str_replace('{input}', $input, $this->userPrompt)); + } + +} diff --git a/src/Prompt/PromptInterface.php b/src/Prompt/PromptInterface.php new file mode 100644 index 0000000..1aa67d5 --- /dev/null +++ b/src/Prompt/PromptInterface.php @@ -0,0 +1,17 @@ +name, $this->description, ToolParameters::fromArray($this->parameters) ?? null, [ + $this, + 'invoke' + ]); + } + +} \ No newline at end of file diff --git a/src/Tools/TavilySearchResults.php b/src/Tools/TavilySearchResults.php new file mode 100644 index 0000000..9d3112a --- /dev/null +++ b/src/Tools/TavilySearchResults.php @@ -0,0 +1,124 @@ + [ + 'type' => 'string', + 'description' => 'The search query to use, the min query length is 5 characters, use Simplified Chinese as much as possible.', + 'required' => true, + ], + ]; + + public function __construct( + public TavilySearchApiWrapper $apiWrapper, + public string $searchDepth = 'basic', + public int $maxResults = 5, + public bool $useAnswerDirectly = false, + ) { + } + + public function setMaxResults(int $maxResults): static + { + $this->maxResults = $maxResults; + return $this; + } + + public function getMaxResults(): int + { + return $this->maxResults; + } + + public function getSearchDepth(): string + { + return $this->searchDepth; + } + + public function isUseAnswerDirectly(): bool + { + return $this->useAnswerDirectly; + } + + public function setUseAnswerDirectly(bool $useAnswerDirectly): static + { + $this->useAnswerDirectly = $useAnswerDirectly; + return $this; + } + + public function setSearchDepth(string $searchDepth): static + { + if (! in_array($searchDepth, ['basic', 'advanced'])) { + throw new InvalidArgumentException('Invalid search depth. Must be one of: basic, advanced'); + } + $this->searchDepth = $searchDepth; + return $this; + } + + public function invoke(string $query, bool $throwException = false): array + { + try { + if ($this->isValidQuery($query)) { + $result = $this->apiWrapper->results($query, $this->getMaxResults(), $this->getSearchDepth(), $this->isUseAnswerDirectly()); + if ($this->isUseAnswerDirectly()) { + $answer = $result['answer'] ?? ''; + return ['answer' => $answer]; + } + return $this->cleanResults($result); + } + return []; + } catch (Throwable $e) { + if ($throwException) { + throw $e; + } + return []; + } + } + + protected function isValidQuery(string $query): bool + { + // Query is too short. Min query length is 5 characters. + if (strlen($query) < 5) { + throw new InvalidArgumentException('Query is too short. Min query length is 5 characters.'); + } + return true; + } + + protected function cleanResults(array $result): array + { + $filteredResult = []; + foreach ($result['results'] ?? [] as $item) { + $filteredResult[] = [ + 'title' => $item['title'], + 'content' => $item['content'], + 'url' => $item['url'], + ]; + } + return $filteredResult; + } +} diff --git a/src/Tools/ToolInterface.php b/src/Tools/ToolInterface.php new file mode 100644 index 0000000..da6e259 --- /dev/null +++ b/src/Tools/ToolInterface.php @@ -0,0 +1,12 @@ +client = $client; + $apiKey = $config->get('odin.tavily.api_key'); + $this->apiKeys = explode(',', $apiKey); + } + + public function results( + string $query, + int $maxResults = 5, + string $searchDepth = 'basic', + $includeAnswer = false + ): array { + return $this->rawResults($query, $maxResults, $searchDepth, includeAnswer: $includeAnswer); + } + + protected function rawResults( + string $query, + int $maxResults = 5, + string $searchDepth = 'basic', + array $includeDomains = [], + array $excludeDomains = [], + bool $includeAnswer = false, + bool $includeRawContent = false, + bool $includeImages = false + ): array { + $uri = self::API_URL . '/search'; + $randApiKey = $this->apiKeys[array_rand($this->apiKeys)]; + $response = $this->client->post($uri, [ + 'json' => [ + 'api_key' => $randApiKey, + 'query' => $query, + 'max_results' => $maxResults, + 'search_depth' => $searchDepth, + "include_domains" => $includeDomains, + "exclude_domains" => $excludeDomains, + "include_answer" => $includeAnswer, + "include_raw_content" => $includeRawContent, + "include_images" => $includeImages, + ], + 'verify' => false, + ]); + if ($response->getStatusCode() !== 200) { + throw new \RuntimeException('Failed to fetch results from Tavily Search API with status code ' . $response->getStatusCode()); + } + return json_decode($response->getBody()->getContents(), true); + } + +} \ No newline at end of file