Skip to content

Google Gemini tool support #331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/google/toolcall.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<?php

use PhpLlm\LlmChain\Chain\Chain;
use PhpLlm\LlmChain\Chain\Toolbox\ChainProcessor;
use PhpLlm\LlmChain\Chain\Toolbox\Tool\Clock;
use PhpLlm\LlmChain\Chain\Toolbox\Toolbox;
use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
use PhpLlm\LlmChain\Platform\Bridge\Google\PlatformFactory;
use PhpLlm\LlmChain\Platform\Message\Message;
use PhpLlm\LlmChain\Platform\Message\MessageBag;
use Symfony\Component\Dotenv\Dotenv;

require_once dirname(__DIR__, 2).'/vendor/autoload.php';
(new Dotenv())->loadEnv(dirname(__DIR__, 2).'/.env');

if (empty($_ENV['GOOGLE_API_KEY'])) {
echo 'Please set the GOOGLE_API_KEY environment variable.'.\PHP_EOL;
exit(1);
}

$platform = PlatformFactory::create($_ENV['GOOGLE_API_KEY']);
$llm = new Gemini(Gemini::GEMINI_2_FLASH);

$toolbox = Toolbox::create(new Clock());
$processor = new ChainProcessor($toolbox);
$chain = new Chain($platform, $llm, [$processor], [$processor]);

$messages = new MessageBag(Message::ofUser('What time is it?'));
$response = $chain->call($messages);

echo $response->getContent().\PHP_EOL;
21 changes: 18 additions & 3 deletions src/Platform/Bridge/Google/Contract/AssistantMessageNormalizer.php
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,23 @@ protected function supportsModel(Model $model): bool
*/
public function normalize(mixed $data, ?string $format = null, array $context = []): array
{
return [
['text' => $data->content],
];
$normalized = [];

if (isset($data->content)) {
$normalized['text'] = $data->content;
}

if (isset($data->toolCalls[0])) {
$normalized['functionCall'] = [
'id' => $data->toolCalls[0]->id,
'name' => $data->toolCalls[0]->name,
];

if ($data->toolCalls[0]->arguments) {
$normalized['functionCall']['args'] = $data->toolCalls[0]->arguments;
}
}

return [$normalized];
}
}
56 changes: 56 additions & 0 deletions src/Platform/Bridge/Google/Contract/ToolCallMessageNormalizer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
<?php

declare(strict_types=1);

namespace PhpLlm\LlmChain\Platform\Bridge\Google\Contract;

use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
use PhpLlm\LlmChain\Platform\Contract\Normalizer\ModelContractNormalizer;
use PhpLlm\LlmChain\Platform\Message\ToolCallMessage;
use PhpLlm\LlmChain\Platform\Model;
use Symfony\Component\Serializer\Normalizer\NormalizerAwareInterface;
use Symfony\Component\Serializer\Normalizer\NormalizerAwareTrait;

/**
* @author Valtteri R <[email protected]>
*/
final class ToolCallMessageNormalizer extends ModelContractNormalizer implements NormalizerAwareInterface
{
use NormalizerAwareTrait;

protected function supportedDataClass(): string
{
return ToolCallMessage::class;
}

protected function supportsModel(Model $model): bool
{
return $model instanceof Gemini;
}

/**
* @param ToolCallMessage $data
*
* @return array{
* functionResponse: array{
* id: string,
* name: string,
* response: array<int|string, mixed>
* }
* }[]
*/
public function normalize(mixed $data, ?string $format = null, array $context = []): array
{
$responseContent = json_validate($data->content) ? json_decode($data->content, true) : $data->content;

return [[
'functionResponse' => array_filter([
'id' => $data->toolCall->id,
'name' => $data->toolCall->name,
'response' => \is_array($responseContent) ? $responseContent : [
'rawResponse' => $responseContent, // Gemini expects the response to be an object, but not everyone uses objects as their responses.
],
]),
]];
}
}
54 changes: 54 additions & 0 deletions src/Platform/Bridge/Google/Contract/ToolNormalizer.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
<?php

namespace PhpLlm\LlmChain\Platform\Bridge\Google\Contract;

use PhpLlm\LlmChain\Platform\Bridge\Google\Gemini;
use PhpLlm\LlmChain\Platform\Contract\JsonSchema\Factory;
use PhpLlm\LlmChain\Platform\Contract\Normalizer\ModelContractNormalizer;
use PhpLlm\LlmChain\Platform\Model;
use PhpLlm\LlmChain\Platform\Tool\Tool;

/**
* @author Valtteri R <[email protected]>
*
* @phpstan-import-type JsonSchema from Factory
*/
final class ToolNormalizer extends ModelContractNormalizer
{
protected function supportedDataClass(): string
{
return Tool::class;
}

protected function supportsModel(Model $model): bool
{
return $model instanceof Gemini;
}

/**
* @param Tool $data
*
* @return array{
* functionDeclarations: array{
* name: string,
* description: string,
* parameters: JsonSchema|array{type: 'object'}
* }[]
* }
*/
public function normalize(mixed $data, ?string $format = null, array $context = []): array
{
$parameters = $data->parameters;
unset($parameters['additionalProperties']);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why/if this is needed, but it was included in the original implementation by @tryvin


return [
'functionDeclarations' => [
[
'description' => $data->description,
'name' => $data->name,
'parameters' => $parameters,
],
],
];
}
}
1 change: 1 addition & 0 deletions src/Platform/Bridge/Google/Gemini.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public function __construct(string $name = self::GEMINI_2_PRO, array $options =
Capability::INPUT_MESSAGES,
Capability::INPUT_IMAGE,
Capability::OUTPUT_STREAMING,
Capability::TOOL_CALLING,
];

parent::__construct($name, $capabilities, $options);
Expand Down
85 changes: 81 additions & 4 deletions src/Platform/Bridge/Google/ModelHandler.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
use PhpLlm\LlmChain\Platform\Exception\RuntimeException;
use PhpLlm\LlmChain\Platform\Model;
use PhpLlm\LlmChain\Platform\ModelClientInterface;
use PhpLlm\LlmChain\Platform\Response\Choice;
use PhpLlm\LlmChain\Platform\Response\ChoiceResponse;
use PhpLlm\LlmChain\Platform\Response\ResponseInterface as LlmResponse;
use PhpLlm\LlmChain\Platform\Response\StreamResponse;
use PhpLlm\LlmChain\Platform\Response\TextResponse;
use PhpLlm\LlmChain\Platform\Response\ToolCall;
use PhpLlm\LlmChain\Platform\Response\ToolCallResponse;
use PhpLlm\LlmChain\Platform\ResponseConverterInterface;
use Symfony\Component\HttpClient\EventSourceHttpClient;
use Symfony\Contracts\HttpClient\Exception\ClientExceptionInterface;
Expand Down Expand Up @@ -52,6 +56,12 @@ public function request(Model $model, array|string $payload, array $options = []

$generationConfig = ['generationConfig' => $options];
unset($generationConfig['generationConfig']['stream']);
unset($generationConfig['generationConfig']['tools']);

if (isset($options['tools'])) {
$generationConfig['tools'] = $options['tools'];
unset($options['tools']);
}

return $this->httpClient->request('POST', $url, [
'headers' => [
Expand All @@ -76,11 +86,22 @@ public function convert(ResponseInterface $response, array $options = []): LlmRe

$data = $response->toArray();

if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
if (!isset($data['candidates'][0]['content']['parts'][0])) {
throw new RuntimeException('Response does not contain any content');
}

return new TextResponse($data['candidates'][0]['content']['parts'][0]['text']);
/** @var Choice[] $choices */
$choices = array_map($this->convertChoice(...), $data['candidates']);

if (1 !== \count($choices)) {
return new ChoiceResponse(...$choices);
}

if ($choices[0]->hasToolCall()) {
return new ToolCallResponse(...$choices[0]->getToolCalls());
}

return new TextResponse($choices[0]->getContent());
}

private function convertStream(ResponseInterface $response): \Generator
Expand Down Expand Up @@ -114,12 +135,68 @@ private function convertStream(ResponseInterface $response): \Generator
throw new RuntimeException('Failed to decode JSON response', 0, $e);
}

if (!isset($data['candidates'][0]['content']['parts'][0]['text'])) {
/** @var Choice[] $choices */
$choices = array_map($this->convertChoice(...), $data['candidates'] ?? []);

if (!$choices) {
continue;
}

yield $data['candidates'][0]['content']['parts'][0]['text'];
if (1 !== \count($choices)) {
yield new ChoiceResponse(...$choices);
continue;
}

if ($choices[0]->hasToolCall()) {
yield new ToolCallResponse(...$choices[0]->getToolCalls());
}

if ($choices[0]->hasContent()) {
yield $choices[0]->getContent();
}
}
}
}

/**
* @param array{
* finishReason?: string,
* content: array{
* parts: array{
* functionCall?: array{
* id: string,
* name: string,
* args: mixed[]
* },
* text?: string
* }[]
* }
* } $choice
*/
private function convertChoice(array $choice): Choice
{
$contentPart = $choice['content']['parts'][0] ?? [];

if (isset($contentPart['functionCall'])) {
return new Choice(toolCalls: [$this->convertToolCall($contentPart['functionCall'])]);
}

if (isset($contentPart['text'])) {
return new Choice($contentPart['text']);
}

throw new RuntimeException(\sprintf('Unsupported finish reason "%s".', $choice['finishReason']));
}

/**
* @param array{
* id: string,
* name: string,
* args: mixed[]
* } $toolCall
*/
private function convertToolCall(array $toolCall): ToolCall
{
return new ToolCall($toolCall['id'] ?? '', $toolCall['name'], $toolCall['args']);
}
}
4 changes: 4 additions & 0 deletions src/Platform/Bridge/Google/PlatformFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\AssistantMessageNormalizer;
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\MessageBagNormalizer;
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolCallMessageNormalizer;
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\ToolNormalizer;
use PhpLlm\LlmChain\Platform\Bridge\Google\Contract\UserMessageNormalizer;
use PhpLlm\LlmChain\Platform\Contract;
use PhpLlm\LlmChain\Platform\Platform;
Expand All @@ -28,6 +30,8 @@ public static function create(
return new Platform([$responseHandler], [$responseHandler], Contract::create(
new AssistantMessageNormalizer(),
new MessageBagNormalizer(),
new ToolNormalizer(),
new ToolCallMessageNormalizer(),
new UserMessageNormalizer(),
));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
use PhpLlm\LlmChain\Platform\Contract;
use PhpLlm\LlmChain\Platform\Message\AssistantMessage;
use PhpLlm\LlmChain\Platform\Model;
use PhpLlm\LlmChain\Platform\Response\ToolCall;
use PHPUnit\Framework\Attributes\CoversClass;
use PHPUnit\Framework\Attributes\DataProvider;
use PHPUnit\Framework\Attributes\Small;
use PHPUnit\Framework\Attributes\Test;
use PHPUnit\Framework\Attributes\UsesClass;
Expand All @@ -20,6 +22,7 @@
#[UsesClass(Gemini::class)]
#[UsesClass(AssistantMessage::class)]
#[UsesClass(Model::class)]
#[UsesClass(ToolCall::class)]
final class AssistantMessageNormalizerTest extends TestCase
{
#[Test]
Expand All @@ -41,14 +44,33 @@ public function getSupportedTypes(): void
self::assertSame([AssistantMessage::class => true], $normalizer->getSupportedTypes(null));
}

#[DataProvider('normalizeDataProvider')]
#[Test]
public function normalize(): void
public function normalize(AssistantMessage $message, array $expectedOutput): void
{
$normalizer = new AssistantMessageNormalizer();
$message = new AssistantMessage('Great to meet you. What would you like to know?');

$normalized = $normalizer->normalize($message);

self::assertSame([['text' => 'Great to meet you. What would you like to know?']], $normalized);
self::assertSame($expectedOutput, $normalized);
}

/**
* @return iterable<string, array{AssistantMessage, array{text?: string, functionCall?: array{id: string, name: string, args?: mixed}}[]}>
*/
public static function normalizeDataProvider(): iterable
{
yield 'assistant message' => [
new AssistantMessage('Great to meet you. What would you like to know?'),
[['text' => 'Great to meet you. What would you like to know?']],
];
yield 'function call' => [
new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1', ['arg1' => '123'])]),
[['functionCall' => ['id' => 'id1', 'name' => 'name1', 'args' => ['arg1' => '123']]]],
];
yield 'function call without parameters' => [
new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1')]),
[['functionCall' => ['id' => 'id1', 'name' => 'name1']]],
];
}
}
Loading
Loading