From 4cddb515844c10a62db41a682d623cf21c49cd36 Mon Sep 17 00:00:00 2001 From: Konstantin Krestnikov Date: Mon, 11 Nov 2024 20:33:51 +0300 Subject: [PATCH] feat: add pydantic schema support for few_shot_examples --- .../utils/function_calling.py | 4 +- .../tests/unit_tests/test_gigachat.py | 43 ++++++++++++++++++- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/libs/gigachat/langchain_gigachat/utils/function_calling.py b/libs/gigachat/langchain_gigachat/utils/function_calling.py index d8d4842..6915454 100644 --- a/libs/gigachat/langchain_gigachat/utils/function_calling.py +++ b/libs/gigachat/langchain_gigachat/utils/function_calling.py @@ -365,8 +365,8 @@ def convert_pydantic_to_gigachat_function( "Incorrect function or tool description. Description is required." ) - if few_shot_examples is None and hasattr(model, 'few_shot_examples'): - few_shot_examples_attr = getattr(model, 'few_shot_examples') + if few_shot_examples is None and hasattr(model, "few_shot_examples"): + few_shot_examples_attr = getattr(model, "few_shot_examples") if inspect.isfunction(few_shot_examples_attr): few_shot_examples = few_shot_examples_attr() diff --git a/libs/gigachat/tests/unit_tests/test_gigachat.py b/libs/gigachat/tests/unit_tests/test_gigachat.py index b56c178..204eab4 100644 --- a/libs/gigachat/tests/unit_tests/test_gigachat.py +++ b/libs/gigachat/tests/unit_tests/test_gigachat.py @@ -31,7 +31,7 @@ _convert_dict_to_message, _convert_message_to_dict, ) -from langchain_gigachat.tools.giga_tool import giga_tool +from langchain_gigachat.tools.giga_tool import FewShotExamples, giga_tool from tests.unit_tests.stubs import FakeAsyncCallbackHandler, FakeCallbackHandler @@ -334,3 +334,44 @@ def test_gigachat_bind_gigatool() -> None: "required": ["status", "message"], "type": "object", } + + +class SomeResult(BaseModel): + """My desc""" + + @staticmethod + def few_shot_examples() -> FewShotExamples: + return [ + { + "request": "request example", + "params": {"is_valid": 1, "description": "correct message"}, + } + ] + + value: int = Field(description="some value") + description: str = Field(description="some descriptin") + + +def test_structured_output() -> None: + llm = GigaChat().with_structured_output(SomeResult) + assert llm.steps[0].kwargs["function_call"] == {"name": "SomeResult"} # type: ignore[attr-defined] + assert llm.steps[0].kwargs["tools"][0]["function"] == { # type: ignore[attr-defined] + "name": "SomeResult", + "description": "My desc", + "parameters": { + "description": "My desc", + "properties": { + "value": {"description": "some value", "type": "integer"}, + "description": {"description": "some descriptin", "type": "string"}, + }, + "required": ["value", "description"], + "type": "object", + }, + "return_parameters": None, + "few_shot_examples": [ + { + "request": "request example", + "params": {"is_valid": 1, "description": "correct message"}, + } + ], + }