Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AzureOpenAI and Skylark support stream ChatCompletion #2

Merged
merged 9 commits into from
Feb 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions config/autoload/odin.php
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
use Hyperf\Odin\Model\ChatglmModel;
use Hyperf\Odin\Model\OllamaModel;
use Hyperf\Odin\Model\OpenAIModel;
use Hyperf\Odin\Model\SkylarkModel;
use Hyperf\Odin\Model\DoubaoModel;
use function Hyperf\Support\env;
use function Hyperf\Support\value;

@@ -137,7 +137,7 @@
],
],
'skylark:character-4k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_CHARACTER_4K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_CHARACTER_4K_AK', env('SKYLARK_PRO_AK')),
@@ -148,7 +148,7 @@
],
],
'skylark:turbo-8k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_TURBO_8K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_TURBO_8K_AK', env('SKYLARK_PRO_AK')),
@@ -159,7 +159,7 @@
],
],
'skylark:32k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_32K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_32K_AK', env('SKYLARK_PRO_AK')),
@@ -170,7 +170,7 @@
],
],
'skylark:4k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_4K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_4K_AK', env('SKYLARK_PRO_AK')),
@@ -181,7 +181,7 @@
],
],
'skylark:lite-8k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_LITE_8K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_LITE_8K_AK', env('SKYLARK_PRO_AK')),
12 changes: 6 additions & 6 deletions publish/odin.php
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
use Hyperf\Odin\Model\ChatglmModel;
use Hyperf\Odin\Model\OllamaModel;
use Hyperf\Odin\Model\OpenAIModel;
use Hyperf\Odin\Model\SkylarkModel;
use Hyperf\Odin\Model\DoubaoModel;
use function Hyperf\Support\env;
use function Hyperf\Support\value;

@@ -138,7 +138,7 @@
],
],
'skylark:character-4k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_CHARACTER_4K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_CHARACTER_4K_AK', env('SKYLARK_PRO_AK')),
@@ -149,7 +149,7 @@
],
],
'skylark:turbo-8k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_TURBO_8K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_TURBO_8K_AK', env('SKYLARK_PRO_AK')),
@@ -160,7 +160,7 @@
],
],
'skylark:32k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_32K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_32K_AK', env('SKYLARK_PRO_AK')),
@@ -171,7 +171,7 @@
],
],
'skylark:4k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_4K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_4K_AK', env('SKYLARK_PRO_AK')),
@@ -182,7 +182,7 @@
],
],
'skylark:lite-8k' => [
'implementation' => SkylarkModel::class,
'implementation' => DoubaoModel::class,
'config' => [
'host' => env('SKYLARK_PRO_LITE_8K_HOST', env('SKYLARK_PRO_HOST')),
'ak' => env('SKYLARK_PRO_LITE_8K_AK', env('SKYLARK_PRO_AK')),
6 changes: 3 additions & 3 deletions src/Agent/ToolsAgent.php
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
use Hyperf\Odin\Message\FunctionMessage;
use Hyperf\Odin\Message\ToolMessage;
use Hyperf\Odin\Model\ModelInterface;
use Hyperf\Odin\Model\SkylarkModel;
use Hyperf\Odin\Model\DoubaoModel;
use Hyperf\Odin\Observer;
use Hyperf\Odin\Prompt\PromptInterface;
use Hyperf\Odin\Tool\ToolInterface;
@@ -254,7 +254,7 @@ protected function chat(

protected function response(ChatCompletionResponse $response): ChatCompletionResponse
{
if ($this->model instanceof SkylarkModel) {
if ($this->model instanceof DoubaoModel) {
$choices = $response->getChoices();
// 取 <|Answer|>: 后面的内容作为回答
foreach ($choices as $key => $choice) {
@@ -279,7 +279,7 @@ protected function response(ChatCompletionResponse $response): ChatCompletionRes

protected function transferMessages(array $messages): array
{
if ($this->model instanceof SkylarkModel) {
if ($this->model instanceof DoubaoModel) {
// 把里面的 ToolMessage 转为 FunctionMessage
foreach ($messages as $key => $message) {
if ($message instanceof ToolMessage) {
4 changes: 1 addition & 3 deletions src/Api/AzureOpenAI/AzureOpenAIConfig.php
Original file line number Diff line number Diff line change
@@ -16,9 +16,7 @@ class AzureOpenAIConfig
{
public function __construct(
protected array $config = [],
) {

}
) {}

public function getApiKey(): ?string
{
27 changes: 12 additions & 15 deletions src/Api/AzureOpenAI/Client.php
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@
use Hyperf\Odin\Exception\NotImplementedException;
use Hyperf\Odin\Message\MessageInterface;
use Hyperf\Odin\Tool\ToolInterface;
use InvalidArgumentException;
use Psr\Log\LoggerInterface;

class Client implements ClientInterface
@@ -33,12 +32,13 @@ class Client implements ClientInterface
*/
protected array $clients = [];

protected ?LoggerInterface $logger;
protected ?LoggerInterface $logger = null;

protected bool $debug = false;

protected string $model;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

理论上应该想办法去掉这个 $this->model 才对

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对,其实这个 model 没啥用,里面的 $clients 其实也只会有一个值


public function __construct(AzureOpenAIConfig $config, LoggerInterface $logger, string $model)
public function __construct(AzureOpenAIConfig $config, ?LoggerInterface $logger, string $model)
{
$this->logger = $logger;
$this->model = $model;
@@ -54,7 +54,7 @@ public function chat(
array $tools = [],
bool $stream = false,
): ChatCompletionResponse {
$deploymentPath = $this->buildDeploymentPath($model);
$deploymentPath = $this->buildDeploymentPath();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥所有 $model 传参都去掉了?这样会让这个类在 DI 单例的情况下只能支持一种 Model

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为下面的那个

protected function buildDeploymentPath(): string
{
    return 'openai/deployments/' . $this->config->getDeploymentName();
}

中的 $this->config->getDeploymentName() 已经不提供入参了,语法检测会提示让我去掉

对于 AzureOpenAI 来说,传入一种配置,其实也就只能访问一个模型,因为 url 里面会有模型名,这里传入不同的 model 其实是无影响的,AzureOpenAI 的接口没有使用该参数

$messagesArr = [];
foreach ($messages as $message) {
if ($message instanceof MessageInterface) {
@@ -65,6 +65,7 @@ public function chat(
'messages' => $messagesArr,
'model' => $model,
'temperature' => $temperature,
'stream' => $stream,
];
if ($maxTokens) {
$json['max_tokens'] = $maxTokens;
@@ -91,10 +92,9 @@ public function chat(
$this->debug && $this->logger?->debug(sprintf("Send Messages: %s\nTools: %s", json_encode($messagesArr, JSON_UNESCAPED_UNICODE), json_encode($tools, JSON_UNESCAPED_UNICODE)));
$response = $this->getClient($model)->post($deploymentPath . '/chat/completions', [
'query' => [
'api-version' => $this->config->getApiVersion($model),
'api-version' => $this->config->getApiVersion(),
],
'json' => $json,
'verify' => false,
]);
$chatCompletionResponse = new ChatCompletionResponse($response);
$this->debug && $this->logger?->debug('Receive: ' . $chatCompletionResponse);
@@ -107,10 +107,10 @@ public function completions(
float $temperature = 0.9,
int $maxTokens = 200
): TextCompletionResponse {
$deploymentPath = $this->buildDeploymentPath($model);
$deploymentPath = $this->buildDeploymentPath();
$response = $this->getClient($model)->post($deploymentPath . '/completions', [
'query' => [
'api-version' => $this->config->getApiVersion($model),
'api-version' => $this->config->getApiVersion(),
],
'json' => [
'prompt' => $prompt,
@@ -133,14 +133,14 @@ public function embedding(
string $model = 'text-embedding-ada-002',
?string $user = null
): ListResponse {
$deploymentPath = $this->buildDeploymentPath($model);
$deploymentPath = $this->buildDeploymentPath();
$json = [
'input' => $input,
];
$user && $json['user'] = $user;
$response = $this->getClient($model)->post($deploymentPath . '/embeddings', [
'query' => [
'api-version' => $this->config->getApiVersion($model),
'api-version' => $this->config->getApiVersion(),
],
'json' => $json,
'verify' => false,
@@ -161,9 +161,6 @@ public function setDebug(bool $debug): static

protected function initConfig(AzureOpenAIConfig $config): static
{
if (! $config instanceof AzureOpenAIConfig) {
throw new InvalidArgumentException('AzureOpenAIConfig is required');
}
$this->config = $config;
$headers = [
'api-key' => $config->getApiKey(),
@@ -182,8 +179,8 @@ protected function getClient(string $model): ?GuzzleClient
return $this->clients[$model];
}

protected function buildDeploymentPath(string $model = 'gpt-3.5-turbo'): string
protected function buildDeploymentPath(): string
{
return 'openai/deployments/' . $this->config->getDeploymentName($model);
return 'openai/deployments/' . $this->config->getDeploymentName();
}
}
117 changes: 117 additions & 0 deletions src/Api/Doubao/Client.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
<?php

declare(strict_types=1);
/**
* This file is part of Hyperf.
*
* @link https://www.hyperf.io
* @document https://hyperf.wiki
* @contact [email protected]
* @license https://github.com/hyperf/hyperf/blob/master/LICENSE
*/

namespace Hyperf\Odin\Api\Doubao;

use GuzzleHttp\Client as GuzzleClient;
use Hyperf\Odin\Api\ClientInterface;
use Hyperf\Odin\Api\OpenAI\Request\ToolDefinition;
use Hyperf\Odin\Api\OpenAI\Response\ChatCompletionResponse;
use Hyperf\Odin\Message\MessageInterface;
use Hyperf\Odin\Tool\ToolInterface;
use Psr\Log\LoggerInterface;

class Client implements ClientInterface
{
protected GuzzleClient $client;

protected DoubaoConfig $config;

protected ?LoggerInterface $logger;

protected bool $debug = false;

public function __construct(DoubaoConfig $config, ?LoggerInterface $logger = null)
{
$this->logger = $logger;
$this->initConfig($config);
}

public function chat(
array $messages,
string $model,
float $temperature = 0.9,
int $maxTokens = 4096,
array $stop = [],
array $tools = [],
bool $stream = false,
): ChatCompletionResponse {
$messagesArr = [];
foreach ($messages as $message) {
if ($message instanceof MessageInterface) {
$messagesArr[] = $message->toArray();
}
}
$json = [
'stream' => $stream,
'model' => $model,
'messages' => $messagesArr,
'temperature' => $temperature,
];
if ($maxTokens) {
$json['max_tokens'] = $maxTokens;
}
if (! empty($tools)) {
$toolsArray = [];
foreach ($tools as $tool) {
if ($tool instanceof ToolInterface) {
$toolsArray[] = $tool->toToolDefinition()->toArray();
} elseif ($tool instanceof ToolDefinition) {
$toolsArray[] = $tool->toArray();
} else {
$toolsArray[] = $tool;
}
}
if (! empty($toolsArray)) {
$json['tools'] = $toolsArray;
}
}
if ($stop) {
$json['stop'] = $stop;
}
$this->debug && $this->logger?->debug(sprintf("Send Messages: %s\nTools: %s", json_encode($messagesArr, JSON_UNESCAPED_UNICODE), json_encode($tools, JSON_UNESCAPED_UNICODE)));
$response = $this->client->post('/api/v3/chat/completions', [
'json' => $json,
]);
$chatCompletionResponse = new ChatCompletionResponse($response);
$this->debug && $this->logger?->debug('Receive: ' . $chatCompletionResponse);
return $chatCompletionResponse;
}

public function isDebug(): bool
{
return $this->debug;
}

public function setDebug(bool $debug): static
{
$this->debug = $debug;
return $this;
}

protected function initConfig(DoubaoConfig $config): static
{
$headers = [
'Content-Type' => 'application/json',
'User-Agent' => 'Hyperf-Odin/1.0',
];
if ($config->getApiKey()) {
$headers['Authorization'] = 'Bearer ' . $config->getApiKey();
}
$this->client = new GuzzleClient([
'base_uri' => $config->getBaseUrl(),
'headers' => $headers,
]);
$this->config = $config;
return $this;
}
}
Loading