From 12c25ff9fd4397658e632b32b8f1f4b1cb42bf9f Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Mon, 16 Sep 2024 12:17:35 -0700 Subject: [PATCH 01/15] Updated tests and current progress on adding search grounding. --- google/generativeai/types/content_types.py | 45 +++++++++++++++++-- tests/test_content.py | 50 ++++++++++++++++++++++ 2 files changed, 92 insertions(+), 3 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index b925967c8..7dd91e22b 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -71,6 +71,27 @@ "FunctionLibraryType", ] +Mode = protos.DynamicRetrievalConfig.Mode + +ModeOptions = Union[str, str, Mode] + +_MODE: dict[ModeOptions, Mode] = { + Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED, + 0: Mode.MODE_UNSPECIFIED, + "mode_unspecified": Mode.MODE_UNSPECIFIED, + "unspecified": Mode.MODE_UNSPECIFIED, + Mode.DYNAMIC: Mode.DYNAMIC, + 1: Mode.DYNAMIC, + "mode_dynamic": Mode.DYNAMIC, + "dynamic": Mode.DYNAMIC, +} + + +def to_mode(x: ModeOptions) -> Mode: + if isinstance(x, str): + x = x.lower() + return _MODE[x] + def pil_to_blob(img): # When you load an image with PIL you get a subclass of PIL.Image @@ -656,6 +677,7 @@ class Tool: def __init__( self, function_declarations: Iterable[FunctionDeclarationType] | None = None, + google_search_retrieval: protos.GoogleSearchRetrieval | None = None, code_execution: protos.CodeExecution | None = None, ): # The main path doesn't use this but is seems useful. @@ -676,6 +698,7 @@ def __init__( self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations], + google_search_retrieval=google_search_retrieval, code_execution=code_execution, ) @@ -723,9 +746,17 @@ def _make_tool(tool: ToolType) -> Tool: code_execution = tool.code_execution else: code_execution = None - return Tool(function_declarations=tool.function_declarations, code_execution=code_execution) + return Tool( + function_declarations=tool.function_declarations, + google_search_retrieval=tool.google_search_retrieval, + code_execution=code_execution, + ) elif isinstance(tool, dict): - if "function_declarations" in tool or "code_execution" in tool: + if ( + "function_declarations" in tool + or "google_search_retrieval" in tool + or "code_execution" in tool + ): return Tool(**tool) else: fd = tool @@ -733,10 +764,18 @@ def _make_tool(tool: ToolType) -> Tool: elif isinstance(tool, str): if tool.lower() == "code_execution": return Tool(code_execution=protos.CodeExecution()) + # Check to see if one of the mode enums matches + elif to_mode(tool) == Mode.MODE_UNSPECIFIED or to_mode(tool) == Mode.DYNAMIC: + mode = to_mode(tool) + return Tool(google_search_retrieval=protos.GoogleSearchRetrieval(mode=mode)) else: - raise ValueError("The only string that can be passed as a tool is 'code_execution'.") + raise ValueError( + "The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval." + ) elif isinstance(tool, protos.CodeExecution): return Tool(code_execution=tool) + elif isinstance(tool, protos.GoogleSearchRetrieval): + return Tool(google_search_retrieval=tool) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: diff --git a/tests/test_content.py b/tests/test_content.py index dc62e997b..4cbc12866 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -433,6 +433,56 @@ def test_code_execution(self, tools): t = content_types._make_tool(tools) # Pass code execution into tools self.assertIsInstance(t.code_execution, protos.CodeExecution) + @parameterized.named_parameters( + ["string", "unspecified"], + [ + "dictionary", + {"google_search_retrieval": {"mode": "unspecified", "dynamic_threshold": 0.5}}, + ], + ["tuple", ("unspecified", 0.5)], + [ + "proto_object", + protos.GoogleSearchRetrieval( + protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5) + ), + ], + [ + "proto_passed_in", + protos.Tool( + google_search_retrieval=protos.GoogleSearchRetrieval( + protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5) + ) + ), + ], + [ + "proto_object_list", + [ + protos.GoogleSearchRetrieval( + protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5) + ) + ], + ], + [ + "proto_passed_in_list", + [ + protos.Tool( + google_search_retrieval=protos.GoogleSearchRetrieval( + protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) + ) + ) + ], + ], + ) + def test_search_grounding(self, tools): + if isinstance(tools, Iterable): + t = content_types._make_tools(tools) + self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval) + else: + t = content_types._make_tool(tools) # Pass code execution into tools + self.assertIsInstance(t.google_search_retrieval, protos.GoogleSearchRetrieval) + def test_two_fun_is_one_tool(self): def a(): pass From 250fb41f75dc91505a4d5d7947dab9d89180bc4a Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Mon, 16 Sep 2024 14:27:05 -0700 Subject: [PATCH 02/15] Update google/generativeai/types/content_types.py Co-authored-by: Mark Daoust --- google/generativeai/types/content_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 7dd91e22b..2aca3a6bf 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -73,7 +73,7 @@ Mode = protos.DynamicRetrievalConfig.Mode -ModeOptions = Union[str, str, Mode] +ModeOptions = Union[int, str, Mode] _MODE: dict[ModeOptions, Mode] = { Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED, From 8ed4a25e45632adcc1882fbb676b637ffa9f6ee3 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Mon, 16 Sep 2024 15:19:20 -0700 Subject: [PATCH 03/15] Update tests/test_content.py Co-authored-by: Mark Daoust --- tests/test_content.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_content.py b/tests/test_content.py index 4cbc12866..db501fd14 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -434,7 +434,7 @@ def test_code_execution(self, tools): self.assertIsInstance(t.code_execution, protos.CodeExecution) @parameterized.named_parameters( - ["string", "unspecified"], + ["string", "google_search_retrieval"], [ "dictionary", {"google_search_retrieval": {"mode": "unspecified", "dynamic_threshold": 0.5}}, From fa1651d03408cf21cb3c9ddae738dab0885e0222 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 19 Sep 2024 12:26:51 -0700 Subject: [PATCH 04/15] Update search grounding --- tests/test_content.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/test_content.py b/tests/test_content.py index 4cbc12866..611bd14b7 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -434,23 +434,34 @@ def test_code_execution(self, tools): self.assertIsInstance(t.code_execution, protos.CodeExecution) @parameterized.named_parameters( - ["string", "unspecified"], + ["string", "google_search_retrieval"], [ - "dictionary", + "dictionary_with_dynamic_retrieval_config", + { + "google_search_retrieval": { + "dynamic_retrieval_config": {"mode": "unspecified", "dynamic_threshold": 0.5} + } + }, + ], + [ + "dictionary_without_dynamic_retrieval_config", {"google_search_retrieval": {"mode": "unspecified", "dynamic_threshold": 0.5}}, ], - ["tuple", ("unspecified", 0.5)], [ "proto_object", protos.GoogleSearchRetrieval( - protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5) + dynamic_retrieval_config=protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) ), ], [ "proto_passed_in", protos.Tool( google_search_retrieval=protos.GoogleSearchRetrieval( - protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5) + dynamic_retrieval_config=protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) ) ), ], @@ -458,7 +469,9 @@ def test_code_execution(self, tools): "proto_object_list", [ protos.GoogleSearchRetrieval( - protos.DynamicRetrievalConfig(mode="MODE_UNSPECIFIED", dynamic_threshold=0.5) + dynamic_retrieval_config=protos.DynamicRetrievalConfig( + mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 + ) ) ], ], @@ -467,7 +480,7 @@ def test_code_execution(self, tools): [ protos.Tool( google_search_retrieval=protos.GoogleSearchRetrieval( - protos.DynamicRetrievalConfig( + dynamic_retrieval_config=protos.DynamicRetrievalConfig( mode="MODE_UNSPECIFIED", dynamic_threshold=0.5 ) ) From 84a5f29539517cdb479100c470d3a3dfcef632f2 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 19 Sep 2024 12:27:42 -0700 Subject: [PATCH 05/15] update content_types --- google/generativeai/types/content_types.py | 72 ++++++++++++++++++---- 1 file changed, 59 insertions(+), 13 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 7dd91e22b..cc8e79066 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -73,17 +73,17 @@ Mode = protos.DynamicRetrievalConfig.Mode -ModeOptions = Union[str, str, Mode] +ModeOptions = Union[int, str, Mode] _MODE: dict[ModeOptions, Mode] = { Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED, 0: Mode.MODE_UNSPECIFIED, "mode_unspecified": Mode.MODE_UNSPECIFIED, "unspecified": Mode.MODE_UNSPECIFIED, - Mode.DYNAMIC: Mode.DYNAMIC, - 1: Mode.DYNAMIC, - "mode_dynamic": Mode.DYNAMIC, - "dynamic": Mode.DYNAMIC, + Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC, + 1: Mode.MODE_DYNAMIC, + "mode_dynamic": Mode.MODE_DYNAMIC, + "dynamic": Mode.MODE_DYNAMIC, } @@ -670,14 +670,43 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F return fd.to_proto() +GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, dict[str, float]] + +def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): + if isinstance(gsr, protos.GoogleSearchRetrieval): + return gsr + elif isinstance(gsr, Iterable) and not isinstance(gsr, Mapping): + # Handle list of protos.Tool(...) and list of protos.GoogleSearchRetrieval + return gsr + elif isinstance(gsr, Mapping): + if "mode" in gsr["dynamic_retrieval_config"]: + print(to_mode(gsr["dynamic_retrieval_config"]["mode"])) + # Create proto object from dictionary + gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]), + "dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"]}}} + print(gsr) + elif "mode" in gsr.keys(): + # Create proto object from dictionary + gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["mode"]), + "dynamic_threshold": gsr["dynamic_threshold"]}}} + return gsr + else: + raise TypeError( + "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n" + f"However, received an object of type: {type(gsr)}.\n" + f"Object Value: {gsr}" + ) + class Tool: - """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects, + protos.CodeExecution object, and protos.GoogleSearchRetrieval object.""" def __init__( self, + *, function_declarations: Iterable[FunctionDeclarationType] | None = None, - google_search_retrieval: protos.GoogleSearchRetrieval | None = None, + google_search_retrieval: Union[protos.GoogleSearchRetrieval, str] | None = None, code_execution: protos.CodeExecution | None = None, ): # The main path doesn't use this but is seems useful. @@ -695,6 +724,12 @@ def __init__( # Consistent fields self._function_declarations = [] self._index = {} + + if google_search_retrieval: + if isinstance(google_search_retrieval, str): + google_search_retrieval = {"google_search_retrieval" : {"dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)}}} + else: + _make_google_search_retrieval(google_search_retrieval) self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations], @@ -702,10 +737,16 @@ def __init__( code_execution=code_execution, ) + print(self._proto.google_search_retrieval) + @property def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations + @property + def google_search_retrieval(self) -> protos.GoogleSearchRetrieval: + return self._proto.google_search_retrieval + @property def code_execution(self) -> protos.CodeExecution: return self._proto.code_execution @@ -734,7 +775,7 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] @@ -746,9 +787,15 @@ def _make_tool(tool: ToolType) -> Tool: code_execution = tool.code_execution else: code_execution = None + + if "google_search_retrieval" in tool: + google_search_retrieval = tool.google_search_retrieval + else: + google_search_retrieval = None + return Tool( function_declarations=tool.function_declarations, - google_search_retrieval=tool.google_search_retrieval, + google_search_retrieval=google_search_retrieval, code_execution=code_execution, ) elif isinstance(tool, dict): @@ -765,9 +812,8 @@ def _make_tool(tool: ToolType) -> Tool: if tool.lower() == "code_execution": return Tool(code_execution=protos.CodeExecution()) # Check to see if one of the mode enums matches - elif to_mode(tool) == Mode.MODE_UNSPECIFIED or to_mode(tool) == Mode.DYNAMIC: - mode = to_mode(tool) - return Tool(google_search_retrieval=protos.GoogleSearchRetrieval(mode=mode)) + elif tool.lower() == "google_search_retrieval": + return Tool(google_search_retrieval=protos.GoogleSearchRetrieval()) else: raise ValueError( "The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval." @@ -831,7 +877,7 @@ def to_proto(self): def _make_tools(tools: ToolsType) -> list[Tool]: if isinstance(tools, str): - if tools.lower() == "code_execution": + if tools.lower() == "code_execution" or tools.lower() == "google_search_retrieval": return [_make_tool(tools)] else: raise ValueError("The only string that can be passed as a tool is 'code_execution'.") From 91fc30c02a13ac385ab86382f4ea11ba460c684e Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 19 Sep 2024 12:37:07 -0700 Subject: [PATCH 06/15] Update and add aditional test cases --- google/generativeai/types/content_types.py | 32 +++++++++++++++++----- tests/test_content.py | 12 +++++++- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index cc8e79066..fecd43dfb 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -670,8 +670,10 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F return fd.to_proto() + GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, dict[str, float]] + def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): if isinstance(gsr, protos.GoogleSearchRetrieval): return gsr @@ -682,13 +684,25 @@ def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): if "mode" in gsr["dynamic_retrieval_config"]: print(to_mode(gsr["dynamic_retrieval_config"]["mode"])) # Create proto object from dictionary - gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]), - "dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"]}}} + gsr = { + "google_search_retrieval": { + "dynamic_retrieval_config": { + "mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]), + "dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"], + } + } + } print(gsr) elif "mode" in gsr.keys(): # Create proto object from dictionary - gsr = {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": to_mode(gsr["mode"]), - "dynamic_threshold": gsr["dynamic_threshold"]}}} + gsr = { + "google_search_retrieval": { + "dynamic_retrieval_config": { + "mode": to_mode(gsr["mode"]), + "dynamic_threshold": gsr["dynamic_threshold"], + } + } + } return gsr else: raise TypeError( @@ -724,10 +738,14 @@ def __init__( # Consistent fields self._function_declarations = [] self._index = {} - + if google_search_retrieval: if isinstance(google_search_retrieval, str): - google_search_retrieval = {"google_search_retrieval" : {"dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)}}} + google_search_retrieval = { + "google_search_retrieval": { + "dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)} + } + } else: _make_google_search_retrieval(google_search_retrieval) @@ -792,7 +810,7 @@ def _make_tool(tool: ToolType) -> Tool: google_search_retrieval = tool.google_search_retrieval else: google_search_retrieval = None - + return Tool( function_declarations=tool.function_declarations, google_search_retrieval=google_search_retrieval, diff --git a/tests/test_content.py b/tests/test_content.py index 611bd14b7..86f306196 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -435,6 +435,16 @@ def test_code_execution(self, tools): @parameterized.named_parameters( ["string", "google_search_retrieval"], + ["empty_dictionary", {"google_search_retrieval": {}}], + ["empty_dictionary_with_dynamic_retrieval_config", {"dynamic_retrieval_config": {}}], + [ + "dictionary_with_mode_integer", + {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}}, + ], + [ + "dictionary_with_mode_string", + {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": "DYNAMIC"}}}, + ], [ "dictionary_with_dynamic_retrieval_config", { @@ -493,7 +503,7 @@ def test_search_grounding(self, tools): t = content_types._make_tools(tools) self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval) else: - t = content_types._make_tool(tools) # Pass code execution into tools + t = content_types._make_tool(tools) # Pass google_search_retrieval into tools self.assertIsInstance(t.google_search_retrieval, protos.GoogleSearchRetrieval) def test_two_fun_is_one_tool(self): From cc455523bebb24c145bfe4ddb7686b338c9a59bc Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 19 Sep 2024 12:39:06 -0700 Subject: [PATCH 07/15] update test case on empty_dictionary_with_dynamic_retrieval_config --- tests/test_content.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_content.py b/tests/test_content.py index 86f306196..3b7962cf2 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -436,7 +436,7 @@ def test_code_execution(self, tools): @parameterized.named_parameters( ["string", "google_search_retrieval"], ["empty_dictionary", {"google_search_retrieval": {}}], - ["empty_dictionary_with_dynamic_retrieval_config", {"dynamic_retrieval_config": {}}], + ["empty_dictionary_with_dynamic_retrieval_config", {"google_search_retrieval": {"dynamic_retrieval_config": {}}}], [ "dictionary_with_mode_integer", {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}}, From c5cebf257f893fbcc308904b2f117abae1648c77 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 19 Sep 2024 16:48:06 -0700 Subject: [PATCH 08/15] Update test cases and _make_search_grounding --- google/generativeai/types/content_types.py | 41 +++++++--------------- tests/test_content.py | 27 ++++++-------- 2 files changed, 22 insertions(+), 46 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index fecd43dfb..af6ebafad 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -677,39 +677,22 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): if isinstance(gsr, protos.GoogleSearchRetrieval): return gsr - elif isinstance(gsr, Iterable) and not isinstance(gsr, Mapping): - # Handle list of protos.Tool(...) and list of protos.GoogleSearchRetrieval - return gsr elif isinstance(gsr, Mapping): - if "mode" in gsr["dynamic_retrieval_config"]: - print(to_mode(gsr["dynamic_retrieval_config"]["mode"])) - # Create proto object from dictionary - gsr = { - "google_search_retrieval": { - "dynamic_retrieval_config": { - "mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]), - "dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"], - } - } - } - print(gsr) - elif "mode" in gsr.keys(): - # Create proto object from dictionary - gsr = { - "google_search_retrieval": { - "dynamic_retrieval_config": { - "mode": to_mode(gsr["mode"]), - "dynamic_threshold": gsr["dynamic_threshold"], - } - } - } - return gsr + drc = gsr.get("dynamic_retrieval_config", None) + if drc is not None: + mode = drc.get("mode", None) + if mode is not None: + mode = to_mode(mode) + gsr = gsr.copy() + gsr["dynamic_retrieval_config"]["mode"] = mode + return protos.GoogleSearchRetrieval(gsr) else: raise TypeError( "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n" f"However, received an object of type: {type(gsr)}.\n" f"Object Value: {gsr}" ) + class Tool: @@ -741,13 +724,13 @@ def __init__( if google_search_retrieval: if isinstance(google_search_retrieval, str): - google_search_retrieval = { + self._google_search_retrieval = { "google_search_retrieval": { "dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)} } } else: - _make_google_search_retrieval(google_search_retrieval) + self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval) self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations], @@ -763,7 +746,7 @@ def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDec @property def google_search_retrieval(self) -> protos.GoogleSearchRetrieval: - return self._proto.google_search_retrieval + return self._google_search_retrieval @property def code_execution(self) -> protos.CodeExecution: diff --git a/tests/test_content.py b/tests/test_content.py index 3b7962cf2..e84db9b39 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -426,17 +426,16 @@ def no_args(): ["empty_dictionary_list", [{"code_execution": {}}]], ) def test_code_execution(self, tools): - if isinstance(tools, Iterable): - t = content_types._make_tools(tools) - self.assertIsInstance(t[0].code_execution, protos.CodeExecution) - else: - t = content_types._make_tool(tools) # Pass code execution into tools - self.assertIsInstance(t.code_execution, protos.CodeExecution) + t = content_types._make_tools(tools) + self.assertIsInstance(t[0].code_execution, protos.CodeExecution) @parameterized.named_parameters( ["string", "google_search_retrieval"], ["empty_dictionary", {"google_search_retrieval": {}}], - ["empty_dictionary_with_dynamic_retrieval_config", {"google_search_retrieval": {"dynamic_retrieval_config": {}}}], + [ + "empty_dictionary_with_dynamic_retrieval_config", + {"google_search_retrieval": {"dynamic_retrieval_config": {}}}, + ], [ "dictionary_with_mode_integer", {"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}}, @@ -453,10 +452,6 @@ def test_code_execution(self, tools): } }, ], - [ - "dictionary_without_dynamic_retrieval_config", - {"google_search_retrieval": {"mode": "unspecified", "dynamic_threshold": 0.5}}, - ], [ "proto_object", protos.GoogleSearchRetrieval( @@ -499,12 +494,10 @@ def test_code_execution(self, tools): ], ) def test_search_grounding(self, tools): - if isinstance(tools, Iterable): - t = content_types._make_tools(tools) - self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval) - else: - t = content_types._make_tool(tools) # Pass google_search_retrieval into tools - self.assertIsInstance(t.google_search_retrieval, protos.GoogleSearchRetrieval) + if self._testMethodName == "test_search_grounding_empty_dictionary": + pass + t = content_types._make_tools(tools) + self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval) def test_two_fun_is_one_tool(self): def a(): From e60841e4db50ce1589f4b3f65608c88a1ba0fa93 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 19 Sep 2024 17:03:32 -0700 Subject: [PATCH 09/15] fix tests Change-Id: Ib9e19d78861da180f713e09ec93d366d5d7b5762 --- google/generativeai/types/content_types.py | 17 ++++++----------- setup.py | 2 +- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index af6ebafad..f82e26839 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -703,11 +703,11 @@ def __init__( self, *, function_declarations: Iterable[FunctionDeclarationType] | None = None, - google_search_retrieval: Union[protos.GoogleSearchRetrieval, str] | None = None, + google_search_retrieval: GoogleSearchRetrievalType | None = None, code_execution: protos.CodeExecution | None = None, ): # The main path doesn't use this but is seems useful. - if function_declarations: + if function_declarations is not None: self._function_declarations = [ _make_function_declaration(f) for f in function_declarations ] @@ -722,15 +722,10 @@ def __init__( self._function_declarations = [] self._index = {} - if google_search_retrieval: - if isinstance(google_search_retrieval, str): - self._google_search_retrieval = { - "google_search_retrieval": { - "dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)} - } - } - else: - self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval) + if google_search_retrieval is not None: + self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval) + else: + self._google_search_retrieval = None self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations], diff --git a/setup.py b/setup.py index 29841ba1d..0575dcd28 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_version(): release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-ai-generativelanguage==0.6.9", + "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz", "google-api-core", "google-api-python-client", "google-auth>=2.15.0", # 2.15 adds API key auth support From 2f6747046c6c0ad8c2f9e6b9e566238029bcc447 Mon Sep 17 00:00:00 2001 From: Shilpa Kancharla Date: Thu, 19 Sep 2024 17:08:11 -0700 Subject: [PATCH 10/15] Remove print statement --- google/generativeai/types/content_types.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index f82e26839..63adca8c6 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -733,8 +733,6 @@ def __init__( code_execution=code_execution, ) - print(self._proto.google_search_retrieval) - @property def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations From d15431b04b0b0e4062dc29cacceb639b76ebd2b7 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 20 Sep 2024 15:36:51 -0700 Subject: [PATCH 11/15] Fix tuned model tests Change-Id: I5ace9222954be7d903ebbdabab9efc663fa79174 --- google/generativeai/types/model_types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 03922a64e..79957c793 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -143,7 +143,7 @@ def idecode_time(parent: dict["str", Any], name: str): def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel: if isinstance(tuned_model, protos.TunedModel): - tuned_model = type(tuned_model).to_dict(tuned_model) # pytype: disable=attribute-error + tuned_model = type(tuned_model).to_dict(tuned_model, including_default_value_fields=False) # pytype: disable=attribute-error tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None)) base_model = tuned_model.pop("base_model", None) @@ -195,6 +195,7 @@ class TunedModel: create_time: datetime.datetime | None = None update_time: datetime.datetime | None = None tuning_task: TuningTask | None = None + reader_project_numbers: List[int] | None = None @property def permissions(self) -> permission_types.Permissions: From eff6ea5b0bedc8b69bb09ce05feff301038fc27a Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 20 Sep 2024 17:37:24 -0700 Subject: [PATCH 12/15] Fix tests Change-Id: Ifa610965c5d6c38123080a7e16416ac325418285 --- google/generativeai/types/generation_types.py | 10 +++- tests/test_generation.py | 46 +++++++++++------- tests/test_generative_models.py | 47 ++----------------- 3 files changed, 43 insertions(+), 60 deletions(-) diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index 23e7fb1d8..157a2c62a 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -306,6 +306,7 @@ def _join_code_execution_result(result_1, result_2): def _join_candidates(candidates: Iterable[protos.Candidate]): + """Joins stream chunks of a single candidate.""" candidates = tuple(candidates) index = candidates[0].index # These should all be the same. @@ -321,6 +322,7 @@ def _join_candidates(candidates: Iterable[protos.Candidate]): def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]): + """Joins stream chunks where each chunk is a list of candidate chunks.""" # Assuming that is a candidate ends, it is no longer returned in the list of # candidates and that's why candidates have an index candidates = collections.defaultdict(list) @@ -344,10 +346,16 @@ def _join_prompt_feedbacks( def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]): chunks = tuple(chunks) + if 'usage_metadata' in chunks[-1]: + usage_metadata = chunks[-1].usage_metadata + else: + usage_metadata=None + + return protos.GenerateContentResponse( candidates=_join_candidate_lists(c.candidates for c in chunks), prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks), - usage_metadata=chunks[-1].usage_metadata, + usage_metadata=usage_metadata, ) diff --git a/tests/test_generation.py b/tests/test_generation.py index 0cc3bfd07..2559cc6f2 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -1,4 +1,20 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect +import json import string import textwrap from typing_extensions import TypedDict @@ -22,6 +38,7 @@ class Person(TypedDict): class UnitTests(parameterized.TestCase): + maxDiff = None @parameterized.named_parameters( [ "protos.GenerationConfig", @@ -416,12 +433,8 @@ def test_join_prompt_feedbacks(self): ], "role": "assistant", }, - "citation_metadata": {"citation_sources": []}, "index": 0, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [], + "citation_metadata": {}, }, { "content": { @@ -429,11 +442,7 @@ def test_join_prompt_feedbacks(self): "role": "assistant", }, "index": 1, - "citation_metadata": {"citation_sources": []}, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [], + "citation_metadata": {}, }, { "content": { @@ -458,17 +467,13 @@ def test_join_prompt_feedbacks(self): }, ] }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [], }, ] def test_join_candidates(self): candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] result = generation_types._join_candidate_lists(candidate_lists) - self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result]) + self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r, including_default_value_fields=False) for r in result]) def test_join_chunks(self): chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] @@ -480,6 +485,8 @@ def test_join_chunks(self): ], ) + chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(prompt_token_count=5) + result = generation_types._join_chunks(chunks) expected = protos.GenerateContentResponse( @@ -495,10 +502,17 @@ def test_join_chunks(self): } ], }, + "usage_metadata": { + "prompt_token_count": 5 + } + }, ) - self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected)) + expected = json.dumps(type(expected).to_dict(expected, including_default_value_fields=False), indent=4) + result = json.dumps(type(result).to_dict(result, including_default_value_fields=False), indent=4) + + self.assertEqual(expected, result) def test_generate_content_response_iterator_end_to_end(self): chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 79c1ac36f..f296b460d 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -935,8 +935,7 @@ def test_repr_for_streaming_start_to_finish(self): "citation_metadata": {} } ], - "prompt_feedback": {}, - "usage_metadata": {} + "prompt_feedback": {} }), )""" ) @@ -964,8 +963,7 @@ def test_repr_for_streaming_start_to_finish(self): "citation_metadata": {} } ], - "prompt_feedback": {}, - "usage_metadata": {} + "prompt_feedback": {} }), )""" ) @@ -1056,8 +1054,7 @@ def no_throw(): "citation_metadata": {} } ], - "prompt_feedback": {}, - "usage_metadata": {} + "prompt_feedback": {} }), ), error= """ @@ -1095,43 +1092,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): response = chat.send_message("hello2", stream=True) result = repr(response) - expected = textwrap.dedent( - """\ - response: - GenerateContentResponse( - done=True, - iterator=None, - result=protos.GenerateContentResponse({ - "candidates": [ - { - "content": { - "parts": [ - { - "text": "abc" - } - ] - }, - "finish_reason": "SAFETY", - "index": 0, - "citation_metadata": {} - } - ], - "prompt_feedback": {}, - "usage_metadata": {} - }), - ), - error= content { - parts { - text: "abc" - } - } - finish_reason: SAFETY - index: 0 - citation_metadata { - } - """ - ) - self.assertEqual(expected, result) + self.assertIn("StopCandidateException", result) def test_repr_for_multi_turn_chat(self): # Multi turn chat From 4e7677760ccc6bf119c3443a6c00c2c2cc302d6f Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 20 Sep 2024 17:39:14 -0700 Subject: [PATCH 13/15] format Change-Id: Iab48a9400d53f3cbdc5ca49c73df4f6a186a867b --- google/generativeai/types/content_types.py | 1 - google/generativeai/types/generation_types.py | 5 ++-- google/generativeai/types/model_types.py | 4 +++- tests/test_generation.py | 23 ++++++++++++------- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 63adca8c6..37e46fd14 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -692,7 +692,6 @@ def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): f"However, received an object of type: {type(gsr)}.\n" f"Object Value: {gsr}" ) - class Tool: diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index 157a2c62a..56cdd16b3 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -346,11 +346,10 @@ def _join_prompt_feedbacks( def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]): chunks = tuple(chunks) - if 'usage_metadata' in chunks[-1]: + if "usage_metadata" in chunks[-1]: usage_metadata = chunks[-1].usage_metadata else: - usage_metadata=None - + usage_metadata = None return protos.GenerateContentResponse( candidates=_join_candidate_lists(c.candidates for c in chunks), diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 79957c793..d8fd57223 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str): def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel: if isinstance(tuned_model, protos.TunedModel): - tuned_model = type(tuned_model).to_dict(tuned_model, including_default_value_fields=False) # pytype: disable=attribute-error + tuned_model = type(tuned_model).to_dict( + tuned_model, including_default_value_fields=False + ) # pytype: disable=attribute-error tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None)) base_model = tuned_model.pop("base_model", None) diff --git a/tests/test_generation.py b/tests/test_generation.py index 2559cc6f2..a1461e8b5 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -39,6 +39,7 @@ class Person(TypedDict): class UnitTests(parameterized.TestCase): maxDiff = None + @parameterized.named_parameters( [ "protos.GenerationConfig", @@ -473,7 +474,10 @@ def test_join_prompt_feedbacks(self): def test_join_candidates(self): candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] result = generation_types._join_candidate_lists(candidate_lists) - self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r, including_default_value_fields=False) for r in result]) + self.assertEqual( + self.MERGED_CANDIDATES, + [type(r).to_dict(r, including_default_value_fields=False) for r in result], + ) def test_join_chunks(self): chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] @@ -485,7 +489,9 @@ def test_join_chunks(self): ], ) - chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(prompt_token_count=5) + chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata( + prompt_token_count=5 + ) result = generation_types._join_chunks(chunks) @@ -502,15 +508,16 @@ def test_join_chunks(self): } ], }, - "usage_metadata": { - "prompt_token_count": 5 - } - + "usage_metadata": {"prompt_token_count": 5}, }, ) - expected = json.dumps(type(expected).to_dict(expected, including_default_value_fields=False), indent=4) - result = json.dumps(type(result).to_dict(result, including_default_value_fields=False), indent=4) + expected = json.dumps( + type(expected).to_dict(expected, including_default_value_fields=False), indent=4 + ) + result = json.dumps( + type(result).to_dict(result, including_default_value_fields=False), indent=4 + ) self.assertEqual(expected, result) From 3f348fa38b20213ff7cbe95da11633cb0fa0dc74 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 24 Sep 2024 09:28:14 -0700 Subject: [PATCH 14/15] fix typing Change-Id: If892b20ca29d1afb82c48ae1a49bef58e0421bab --- google/generativeai/types/content_types.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 37e46fd14..39f0e138a 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -671,7 +671,16 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F return fd.to_proto() -GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, dict[str, float]] +class DynamicRetrievalConfigDict(TypedDict): + mode: protos.DynamicRetrievalConfig.mode + dynamic_threshold: float + +DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict] + +class GoogleSearchRetrievalDict(TypedDict): + dynamic_retrieval_config: DynamicRetrievalConfig + +GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict] def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): @@ -679,7 +688,7 @@ def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType): return gsr elif isinstance(gsr, Mapping): drc = gsr.get("dynamic_retrieval_config", None) - if drc is not None: + if drc is not None and isinstance(drc, Mapping): mode = drc.get("mode", None) if mode is not None: mode = to_mode(mode) From a26120bea66f1529f67a93c43a8c2b4c1e96b54e Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 24 Sep 2024 09:41:55 -0700 Subject: [PATCH 15/15] Format Change-Id: I51a51150879adb3d4b6b00323e0d8eaf4c0b2515 --- google/generativeai/types/content_types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 39f0e138a..8e3ddaa81 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -675,11 +675,14 @@ class DynamicRetrievalConfigDict(TypedDict): mode: protos.DynamicRetrievalConfig.mode dynamic_threshold: float + DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict] + class GoogleSearchRetrievalDict(TypedDict): dynamic_retrieval_config: DynamicRetrievalConfig + GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict]