Skip to content

Commit 0cfb7f2

Browse files
valtzutryvin
andcommitted
chore: gemini function calling
Co-authored-by: Vin Souza <[email protected]>
1 parent 5ee30df commit 0cfb7f2

File tree

10 files changed

+455
-8
lines changed

10 files changed

+455
-8
lines changed

examples/google/toolcall.php

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
<?php
2+
3+
use PhpLlm\LlmChain\Chain\Chain;
4+
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
5+
use PhpLlm\LlmChain\Chain\Toolbox\Tool\Clock;
6+
use PhpLlm\LlmChain\Chain\Toolbox\Toolbox;
7+
use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
8+
use PhpLlm\LlmChain\Platform\Bridge\Google\PlatformFactory;
9+
use PhpLlm\LlmChain\Platform\Message\Message;
10+
use PhpLlm\LlmChain\Platform\Message\MessageBag;
11+
use Symfony\Component\Dotenv\Dotenv;
12+
13+
require_once dirname(__DIR__, 2).'/vendor/autoload.php';
14+
(new Dotenv())->loadEnv(dirname(__DIR__, 2).'/.env');
15+
16+
if (empty($_ENV['GOOGLE_API_KEY'])) {
17+
echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL;
18+
exit(1);
19+
}
20+
21+
$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']);
22+
$llm = new Gemini(Gemini::GEMINI_2_FLASH);
23+
24+
$toolbox = Toolbox::create(new Clock());
25+
$processor = new ChainProcessor($toolbox);
26+
$chain = new Chain($platform, $llm, [$processor], [$processor]);
27+
28+
$messages = new MessageBag(Message::ofUser('What time is it?'));
29+
$response = $chain->call($messages);
30+
31+
echo $response->getContent().\PHP_EOL;

src/Platform/Bridge/Google/Contract/AssistantMessageNormalizer.php

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,17 @@ protected function supportsModel(Model $model): bool
3636
public function normalize(mixed $data, ?string $format = null, array $context = []): array
3737
{
3838
return [
39-
['text' => $data->content],
39+
array_filter(
40+
[
41+
'text' => $data->content,
42+
'functionCall' => ($data->toolCalls[0] ?? null) ? [
43+
'id' => $data->toolCalls[0]->id,
44+
'name' => $data->toolCalls[0]->name,
45+
'args' => $data->toolCalls[0]->arguments ?: new \ArrayObject(),
46+
] : null,
47+
],
48+
static fn ($content) => null !== $content,
49+
),
4050
];
4151
}
4252
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace PhpLlm\LlmChain\Platform\Bridge\Google\Contract;
6+
7+
use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
8+
use PhpLlm\LlmChain\Platform\Contract\Normalizer\ModelContractNormalizer;
9+
use PhpLlm\LlmChain\Platform\Message\ToolCallMessage;
10+
use PhpLlm\LlmChain\Platform\Model;
11+
use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface;
12+
use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait;
13+
14+
final class ToolCallMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface
15+
{
16+
use NormalizerAwareTrait;
17+
18+
protected function supportedDataClass(): string
19+
{
20+
return ToolCallMessage::class;
21+
}
22+
23+
protected function supportsModel(Model $model): bool
24+
{
25+
return $model instanceof Gemini;
26+
}
27+
28+
/**
29+
* @param ToolCallMessage $data
30+
*
31+
* @return array{
32+
* functionResponse: array{
33+
* id: string,
34+
* name: string,
35+
* response: array
36+
* }
37+
* }[]
38+
*/
39+
public function normalize(mixed $data, ?string $format = null, array $context = []): array
40+
{
41+
$responseContent = json_validate($data->content) ? json_decode($data->content, true) : $data->content;
42+
43+
return [[
44+
'functionResponse' => array_filter([
45+
'id' => $data->toolCall->id,
46+
'name' => $data->toolCall->name,
47+
'response' => \is_array($responseContent) ? $responseContent : [
48+
'rawResponse' => $responseContent, // Gemini expects the response to be an object, but not everyone uses objects as their responses.
49+
],
50+
]),
51+
]];
52+
}
53+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
<?php
2+
3+
namespace PhpLlm\LlmChain\Platform\Bridge\Google\Contract;
4+
5+
use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
6+
use PhpLlm\LlmChain\Platform\Contract\JsonSchema\Factory;
7+
use PhpLlm\LlmChain\Platform\Contract\Normalizer\ModelContractNormalizer;
8+
use PhpLlm\LlmChain\Platform\Model;
9+
use PhpLlm\LlmChain\Platform\Tool\Tool;
10+
11+
/**
12+
* @phpstan-import-type JsonSchema from Factory
13+
*/
14+
final class ToolNormalizer extends ModelContractNormalizer
15+
{
16+
protected function supportedDataClass(): string
17+
{
18+
return Tool::class;
19+
}
20+
21+
protected function supportsModel(Model $model): bool
22+
{
23+
return $model instanceof Gemini;
24+
}
25+
26+
/**
27+
* @param Tool $data
28+
*
29+
* @return array{
30+
* functionDeclarations: array{
31+
* name: string,
32+
* description: string,
33+
* parameters: array
34+
* }[]
35+
* }
36+
*/
37+
public function normalize(mixed $data, ?string $format = null, array $context = []): array
38+
{
39+
$parameters = $data->parameters;
40+
unset($parameters['additionalProperties']);
41+
42+
return [
43+
'functionDeclarations' => [
44+
[
45+
'description' => $data->description,
46+
'name' => $data->name,
47+
'parameters' => $parameters,
48+
],
49+
],
50+
];
51+
}
52+
}

src/Platform/Bridge/Google/Gemini.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ public function __construct(string $name = self::GEMINI_2_PRO, array $options =
2727
Capability::INPUT_MESSAGES,
2828
Capability::INPUT_IMAGE,
2929
Capability::OUTPUT_STREAMING,
30+
Capability::TOOL_CALLING,
3031
];
3132

3233
parent::__construct($name, $capabilities, $options);

src/Platform/Bridge/Google/ModelHandler.php

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
use PhpLlm\LlmChain\Platform\Exception\RuntimeException;
88
use PhpLlm\LlmChain\Platform\Model;
99
use PhpLlm\LlmChain\Platform\ModelClientInterface;
10+
use PhpLlm\LlmChain\Platform\Response\Choice;
11+
use PhpLlm\LlmChain\Platform\Response\ChoiceResponse;
1012
use PhpLlm\LlmChain\Platform\Response\ResponseInterface as LlmResponse;
1113
use PhpLlm\LlmChain\Platform\Response\StreamResponse;
1214
use PhpLlm\LlmChain\Platform\Response\TextResponse;
15+
use PhpLlm\LlmChain\Platform\Response\ToolCall;
16+
use PhpLlm\LlmChain\Platform\Response\ToolCallResponse;
1317
use PhpLlm\LlmChain\Platform\ResponseConverterInterface;
1418
use Symfony\Component\HttpClient\EventSourceHttpClient;
1519
use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface;
@@ -52,6 +56,12 @@ public function request(Model $model, array|string $payload, array $options = []
5256

5357
$generationConfig = ['generationConfig' => $options];
5458
unset($generationConfig['generationConfig']['stream']);
59+
unset($generationConfig['generationConfig']['tools']);
60+
61+
if (isset($options['tools'])) {
62+
$generationConfig['tools'] = $options['tools'];
63+
unset($options['tools']);
64+
}
5565

5666
return $this->httpClient->request('POST', $url, [
5767
'headers' => [
@@ -76,11 +86,22 @@ public function convert(ResponseInterface $response, array $options = []): LlmRe
7686

7787
$data = $response->toArray();
7888

79-
if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
89+
if (!isset($data['candidates'][0]['content']['parts'][0])) {
8090
throw new RuntimeException('Response does not contain any content');
8191
}
8292

83-
return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']);
93+
/** @var Choice[] $choices */
94+
$choices = array_map($this->convertChoice(...), $data['candidates']);
95+
96+
if (1 !== \count($choices)) {
97+
return new ChoiceResponse(...$choices);
98+
}
99+
100+
if ($choices[0]->hasToolCall()) {
101+
return new ToolCallResponse(...$choices[0]->getToolCalls());
102+
}
103+
104+
return new TextResponse($choices[0]->getContent());
84105
}
85106

86107
private function convertStream(ResponseInterface $response): \Generator
@@ -114,12 +135,70 @@ private function convertStream(ResponseInterface $response): \Generator
114135
throw new RuntimeException('Failed to decode JSON response', 0, $e);
115136
}
116137

117-
if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
138+
/** @var Choice[] $choices */
139+
$choices = array_map($this->convertChoice(...), $data['candidates'] ?? []);
140+
141+
if (!$choices) {
118142
continue;
119143
}
120144

121-
yield $data['candidates'][0]['content']['parts'][0]['text'];
145+
if (1 !== \count($choices)) {
146+
yield new ChoiceResponse(...$choices);
147+
continue;
148+
}
149+
150+
if ($choices[0]->hasToolCall()) {
151+
yield new ToolCallResponse(...$choices[0]->getToolCalls());
152+
}
153+
154+
if ($choices[0]->hasContent()) {
155+
yield $choices[0]->getContent();
156+
}
122157
}
123158
}
124159
}
160+
161+
/**
162+
* @param array{
163+
* finishReason?: string,
164+
* content: array{
165+
* parts: array{
166+
* functionCall?: array{
167+
* id: string,
168+
* name: string,
169+
* args: mixed[]
170+
* },
171+
* text?: string
172+
* }[]
173+
* }
174+
* } $choice
175+
*/
176+
private function convertChoice(array $choice): Choice
177+
{
178+
$stopReason = $choice['finishReason'] ?? null;
179+
180+
$contentPart = $choice['content']['parts'][0] ?? [];
181+
182+
if (isset($contentPart['functionCall'])) {
183+
return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]);
184+
}
185+
186+
if (isset($contentPart['text'])) {
187+
return new Choice($contentPart['text']);
188+
}
189+
190+
throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $stopReason));
191+
}
192+
193+
/**
194+
* @param array{
195+
* id: string,
196+
* name: string,
197+
* args: mixed[]
198+
* } $toolCall
199+
*/
200+
private function convertToolCall(array $toolCall): ToolCall
201+
{
202+
return new ToolCall($toolCall['id'] ?? '', $toolCall['name'], $toolCall['args']);
203+
}
125204
}

src/Platform/Bridge/Google/PlatformFactory.php

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\AssistantMessageNormalizer;
88
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\MessageBagNormalizer;
9+
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer;
10+
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolNormalizer;
911
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\UserMessageNormalizer;
1012
use PhpLlm\LlmChain\Platform\Contract;
1113
use PhpLlm\LlmChain\Platform\Platform;
@@ -28,6 +30,8 @@ public static function create(
2830
return new Platform([$responseHandler], [$responseHandler], Contract::create(
2931
new AssistantMessageNormalizer(),
3032
new MessageBagNormalizer(),
33+
new ToolNormalizer(),
34+
new ToolCallMessageNormalizer(),
3135
new UserMessageNormalizer(),
3236
));
3337
}

tests/Platform/Bridge/Google/Contract/AssistantMessageNormalizerTest.php

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
use PhpLlm\LlmChain\Platform\Contract;
1010
use PhpLlm\LlmChain\Platform\Message\AssistantMessage;
1111
use PhpLlm\LlmChain\Platform\Model;
12+
use PhpLlm\LlmChain\Platform\Response\ToolCall;
1213
use PHPUnit\Framework\Attributes\CoversClass;
14+
use PHPUnit\Framework\Attributes\DataProvider;
1315
use PHPUnit\Framework\Attributes\Small;
1416
use PHPUnit\Framework\Attributes\Test;
1517
use PHPUnit\Framework\Attributes\UsesClass;
@@ -20,6 +22,7 @@
2022
#[UsesClass(Gemini::class)]
2123
#[UsesClass(AssistantMessage::class)]
2224
#[UsesClass(Model::class)]
25+
#[UsesClass(ToolCall::class)]
2326
final class AssistantMessageNormalizerTest extends TestCase
2427
{
2528
#[Test]
@@ -41,14 +44,29 @@ public function getSupportedTypes(): void
4144
self::assertSame([AssistantMessage::class => true], $normalizer->getSupportedTypes(null));
4245
}
4346

47+
#[DataProvider('normalizeDataProvider')]
4448
#[Test]
45-
public function normalize(): void
49+
public function normalize(AssistantMessage $message, array $expectedOutput): void
4650
{
4751
$normalizer = new AssistantMessageNormalizer();
48-
$message = new AssistantMessage('Great to meet you. What would you like to know?');
4952

5053
$normalized = $normalizer->normalize($message);
5154

52-
self::assertSame([['text' => 'Great to meet you. What would you like to know?']], $normalized);
55+
self::assertSame($expectedOutput, $normalized);
56+
}
57+
58+
/**
59+
* @return iterable<AssistantMessage, array{text?: string, functionCall?: array{id: string, name: string, args: mixed}}[]>
60+
*/
61+
public static function normalizeDataProvider(): iterable
62+
{
63+
yield 'assistant message' => [
64+
new AssistantMessage('Great to meet you. What would you like to know?'),
65+
[['text' => 'Great to meet you. What would you like to know?']],
66+
];
67+
yield 'function call' => [
68+
new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1', ['arg1' => '123'])]),
69+
[['functionCall' => ['id' => 'id1', 'name' => 'name1', 'args' => ['arg1' => '123']]]],
70+
];
5371
}
5472
}

0 commit comments

Comments
 (0)