Skip to content

Commit

Permalink
Update test cases and _make_search_grounding
Browse files Browse the repository at this point in the history
  • Loading branch information
shilpakancharla committed Sep 19, 2024
1 parent cc45552 commit c5cebf2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 46 deletions.
41 changes: 12 additions & 29 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,39 +677,22 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
if isinstance(gsr, protos.GoogleSearchRetrieval):
return gsr
elif isinstance(gsr, Iterable) and not isinstance(gsr, Mapping):
# Handle list of protos.Tool(...) and list of protos.GoogleSearchRetrieval
return gsr
elif isinstance(gsr, Mapping):
if "mode" in gsr["dynamic_retrieval_config"]:
print(to_mode(gsr["dynamic_retrieval_config"]["mode"]))
# Create proto object from dictionary
gsr = {
"google_search_retrieval": {
"dynamic_retrieval_config": {
"mode": to_mode(gsr["dynamic_retrieval_config"]["mode"]),
"dynamic_threshold": gsr["dynamic_retrieval_config"]["dynamic_threshold"],
}
}
}
print(gsr)
elif "mode" in gsr.keys():
# Create proto object from dictionary
gsr = {
"google_search_retrieval": {
"dynamic_retrieval_config": {
"mode": to_mode(gsr["mode"]),
"dynamic_threshold": gsr["dynamic_threshold"],
}
}
}
return gsr
drc = gsr.get("dynamic_retrieval_config", None)
if drc is not None:
mode = drc.get("mode", None)
if mode is not None:
mode = to_mode(mode)
gsr = gsr.copy()
gsr["dynamic_retrieval_config"]["mode"] = mode
return protos.GoogleSearchRetrieval(gsr)
else:
raise TypeError(
"Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
f"However, received an object of type: {type(gsr)}.\n"
f"Object Value: {gsr}"
)



class Tool:
Expand Down Expand Up @@ -741,13 +724,13 @@ def __init__(

if google_search_retrieval:
if isinstance(google_search_retrieval, str):
google_search_retrieval = {
self._google_search_retrieval = {
"google_search_retrieval": {
"dynamic_retrieval_config": {"mode": to_mode(google_search_retrieval)}
}
}
else:
_make_google_search_retrieval(google_search_retrieval)
self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval)

self._proto = protos.Tool(
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
Expand All @@ -763,7 +746,7 @@ def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDec

@property
def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
return self._proto.google_search_retrieval
return self._google_search_retrieval

@property
def code_execution(self) -> protos.CodeExecution:
Expand Down
27 changes: 10 additions & 17 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,17 +426,16 @@ def no_args():
["empty_dictionary_list", [{"code_execution": {}}]],
)
def test_code_execution(self, tools):
if isinstance(tools, Iterable):
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
else:
t = content_types._make_tool(tools) # Pass code execution into tools
self.assertIsInstance(t.code_execution, protos.CodeExecution)
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)

@parameterized.named_parameters(
["string", "google_search_retrieval"],
["empty_dictionary", {"google_search_retrieval": {}}],
["empty_dictionary_with_dynamic_retrieval_config", {"google_search_retrieval": {"dynamic_retrieval_config": {}}}],
[
"empty_dictionary_with_dynamic_retrieval_config",
{"google_search_retrieval": {"dynamic_retrieval_config": {}}},
],
[
"dictionary_with_mode_integer",
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}},
Expand All @@ -453,10 +452,6 @@ def test_code_execution(self, tools):
}
},
],
[
"dictionary_without_dynamic_retrieval_config",
{"google_search_retrieval": {"mode": "unspecified", "dynamic_threshold": 0.5}},
],
[
"proto_object",
protos.GoogleSearchRetrieval(
Expand Down Expand Up @@ -499,12 +494,10 @@ def test_code_execution(self, tools):
],
)
def test_search_grounding(self, tools):
if isinstance(tools, Iterable):
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)
else:
t = content_types._make_tool(tools) # Pass google_search_retrieval into tools
self.assertIsInstance(t.google_search_retrieval, protos.GoogleSearchRetrieval)
if self._testMethodName == "test_search_grounding_empty_dictionary":
pass
t = content_types._make_tools(tools)
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)

def test_two_fun_is_one_tool(self):
def a():
Expand Down

0 comments on commit c5cebf2

Please sign in to comment.