diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 19341b625..5b143d768 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -48,10 +48,6 @@ from google.generativeai.client import configure -from google.generativeai.discuss import chat -from google.generativeai.discuss import chat_async -from google.generativeai.discuss import count_message_tokens - from google.generativeai.embedding import embed_content from google.generativeai.embedding import embed_content_async @@ -77,19 +73,13 @@ from google.generativeai.operations import list_operations from google.generativeai.operations import get_operation -from google.generativeai.text import generate_text -from google.generativeai.text import generate_embeddings -from google.generativeai.text import count_text_tokens - from google.generativeai.types import GenerationConfig __version__ = version.__version__ -del discuss del embedding del files del generative_models -del text del models del client del operations diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 4dd93feaf..83bf5f679 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -283,7 +283,7 @@ def generate_answer( answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. temperature: Controls the randomness of the output. - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead. request_options: Options for the request. Returns: @@ -337,7 +337,7 @@ async def generate_answer_async( answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. temperature: Controls the randomness of the output. - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `glm.GenerativeServiceClient` instead. Returns: A `types.Answer` containing the model's text answer response. diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 7e2193890..01d0a003b 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -108,9 +108,6 @@ async def create_file(self, *args, **kwargs): class _ClientManager: client_config: dict[str, Any] = dataclasses.field(default_factory=dict) default_metadata: Sequence[tuple[str, str]] = () - - discuss_client: glm.DiscussServiceClient | None = None - discuss_async_client: glm.DiscussServiceAsyncClient | None = None clients: dict[str, Any] = dataclasses.field(default_factory=dict) def configure( @@ -119,7 +116,7 @@ def configure( api_key: str | None = None, credentials: ga_credentials.Credentials | dict | None = None, # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'. - # See `_transport_registry` in `DiscussServiceClientMeta`. + # See _transport_registry in the google.ai.generativelanguage package. # Since the transport classes align with the client classes it wouldn't make # sense to accept a `Transport` object here even though the client classes can. # We could accept a dict since all the `Transport` classes take the same args, @@ -281,7 +278,6 @@ def configure( api_key: str | None = None, credentials: ga_credentials.Credentials | dict | None = None, # The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'. - # See `_transport_registry` in `DiscussServiceClientMeta`. # Since the transport classes align with the client classes it wouldn't make # sense to accept a `Transport` object here even though the client classes can. # We could accept a dict since all the `Transport` classes take the same args, @@ -326,14 +322,6 @@ def get_default_cache_client() -> glm.CacheServiceClient: return _client_manager.get_default_client("cache") -def get_default_discuss_client() -> glm.DiscussServiceClient: - return _client_manager.get_default_client("discuss") - - -def get_default_discuss_async_client() -> glm.DiscussServiceAsyncClient: - return _client_manager.get_default_client("discuss_async") - - def get_default_file_client() -> glm.FilesServiceClient: return _client_manager.get_default_client("file") @@ -350,10 +338,6 @@ def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient: return _client_manager.get_default_client("generative_async") -def get_default_text_client() -> glm.TextServiceClient: - return _client_manager.get_default_client("text") - - def get_default_operations_client() -> operations_v1.OperationsClient: return _client_manager.get_default_client("operations") diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py deleted file mode 100644 index 448347b41..000000000 --- a/google/generativeai/discuss.py +++ /dev/null @@ -1,599 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import dataclasses -import sys -import textwrap - -from typing import Iterable, List - -import google.ai.generativelanguage as glm - -from google.generativeai.client import get_default_discuss_client -from google.generativeai.client import get_default_discuss_async_client -from google.generativeai import string_utils -from google.generativeai import protos -from google.generativeai.types import discuss_types -from google.generativeai.types import helper_types -from google.generativeai.types import model_types -from google.generativeai.types import palm_safety_types - - -def _make_message(content: discuss_types.MessageOptions) -> protos.Message: - """Creates a `protos.Message` object from the provided content.""" - if isinstance(content, protos.Message): - return content - if isinstance(content, str): - return protos.Message(content=content) - else: - return protos.Message(content) - - -def _make_messages( - messages: discuss_types.MessagesOptions, -) -> List[protos.Message]: - """ - Creates a list of `protos.Message` objects from the provided messages. - - This function takes a variety of message content inputs, such as strings, dictionaries, - or `protos.Message` objects, and creates a list of `protos.Message` objects. It ensures that - the authors of the messages alternate appropriately. If authors are not provided, - default authors are assigned based on their position in the list. - - Args: - messages: The messages to convert. - - Returns: - A list of `protos.Message` objects with alternating authors. - """ - if isinstance(messages, (str, dict, protos.Message)): - messages = [_make_message(messages)] - else: - messages = [_make_message(message) for message in messages] - - even_authors = set(msg.author for msg in messages[::2] if msg.author) - if not even_authors: - even_author = "0" - elif len(even_authors) == 1: - even_author = even_authors.pop() - else: - raise discuss_types.AuthorError( - "Invalid sequence: Authors in the discussion must alternate strictly." - ) - - odd_authors = set(msg.author for msg in messages[1::2] if msg.author) - if not odd_authors: - odd_author = "1" - elif len(odd_authors) == 1: - odd_author = odd_authors.pop() - else: - raise discuss_types.AuthorError( - "Invalid sequence: Authors in the discussion must alternate strictly." - ) - - if all(msg.author for msg in messages): - return messages - - authors = [even_author, odd_author] - for i, msg in enumerate(messages): - msg.author = authors[i % 2] - - return messages - - -def _make_example(item: discuss_types.ExampleOptions) -> protos.Example: - """Creates a `protos.Example` object from the provided item.""" - if isinstance(item, protos.Example): - return item - - if isinstance(item, dict): - item = item.copy() - item["input"] = _make_message(item["input"]) - item["output"] = _make_message(item["output"]) - return protos.Example(item) - - if isinstance(item, Iterable): - input, output = list(item) - return protos.Example(input=_make_message(input), output=_make_message(output)) - - # try anyway - return protos.Example(item) - - -def _make_examples_from_flat( - examples: List[discuss_types.MessageOptions], -) -> List[protos.Example]: - """ - Creates a list of `protos.Example` objects from a list of message options. - - This function takes a list of `discuss_types.MessageOptions` and pairs them into - `protos.Example` objects. The input examples must be in pairs to create valid examples. - - Args: - examples: The list of `discuss_types.MessageOptions`. - - Returns: - A list of `protos.Example objects` created by pairing up the provided messages. - - Raises: - ValueError: If the provided list of examples is not of even length. - """ - if len(examples) % 2 != 0: - raise ValueError( - textwrap.dedent( - f"""\ - Invalid input: You must pass either `Primer` objects, pairs of messages, or an even number of messages. - Currently, {len(examples)} messages were provided, which is an odd number.""" - ) - ) - result = [] - pair = [] - for n, item in enumerate(examples): - msg = _make_message(item) - pair.append(msg) - if n % 2 == 0: - continue - primer = protos.Example( - input=pair[0], - output=pair[1], - ) - result.append(primer) - pair = [] - return result - - -def _make_examples( - examples: discuss_types.ExamplesOptions, -) -> List[protos.Example]: - """ - Creates a list of `protos.Example` objects from the provided examples. - - This function takes various types of example content inputs and creates a list - of `protos.Example` objects. It handles the conversion of different input types and ensures - the appropriate structure for creating valid examples. - - Args: - examples: The examples to convert. - - Returns: - A list of `protos.Example` objects created from the provided examples. - """ - if isinstance(examples, protos.Example): - return [examples] - - if isinstance(examples, dict): - return [_make_example(examples)] - - examples = list(examples) - - if not examples: - return examples - - first = examples[0] - - if isinstance(first, dict): - if "content" in first: - # These are `Messages` - return _make_examples_from_flat(examples) - else: - if not ("input" in first and "output" in first): - raise TypeError( - "Invalid dictionary format: To create an `Example` instance, the dictionary must contain both `input` and `output` keys." - ) - else: - if isinstance(first, discuss_types.MESSAGE_OPTIONS): - return _make_examples_from_flat(examples) - - result = [] - for item in examples: - result.append(_make_example(item)) - return result - - -def _make_message_prompt_dict( - prompt: discuss_types.MessagePromptOptions = None, - *, - context: str | None = None, - examples: discuss_types.ExamplesOptions | None = None, - messages: discuss_types.MessagesOptions | None = None, -) -> protos.MessagePrompt: - """ - Creates a `protos.MessagePrompt` object from the provided prompt components. - - This function constructs a `protos.MessagePrompt` object using the provided `context`, `examples`, - or `messages`. It ensures the proper structure and handling of the input components. - - Either pass a `prompt` or it's component `context`, `examples`, `messages`. - - Args: - prompt: The complete prompt components. - context: The context for the prompt. - examples: The examples for the prompt. - messages: The messages for the prompt. - - Returns: - A `protos.MessagePrompt` object created from the provided prompt components. - """ - if prompt is None: - prompt = dict( - context=context, - examples=examples, - messages=messages, - ) - else: - flat_prompt = (context is not None) or (examples is not None) or (messages is not None) - if flat_prompt: - raise ValueError( - "Invalid configuration: Either `prompt` or its fields `(context, examples, messages)` should be set, but not both simultaneously." - ) - if isinstance(prompt, protos.MessagePrompt): - return prompt - elif isinstance(prompt, dict): # Always check dict before Iterable. - pass - else: - prompt = {"messages": prompt} - - keys = set(prompt.keys()) - if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS): - raise KeyError( - f"Invalid prompt dictionary: Extra entries found that are not recognized: {keys - discuss_types.MESSAGE_PROMPT_KEYS}. Please check the keys." - ) - - examples = prompt.get("examples", None) - if examples is not None: - prompt["examples"] = _make_examples(examples) - messages = prompt.get("messages", None) - if messages is not None: - prompt["messages"] = _make_messages(messages) - - prompt = {k: v for k, v in prompt.items() if v is not None} - return prompt - - -def _make_message_prompt( - prompt: discuss_types.MessagePromptOptions = None, - *, - context: str | None = None, - examples: discuss_types.ExamplesOptions | None = None, - messages: discuss_types.MessagesOptions | None = None, -) -> protos.MessagePrompt: - """Creates a `protos.MessagePrompt` object from the provided prompt components.""" - prompt = _make_message_prompt_dict( - prompt=prompt, context=context, examples=examples, messages=messages - ) - return protos.MessagePrompt(prompt) - - -def _make_generate_message_request( - *, - model: model_types.AnyModelNameOptions | None, - context: str | None = None, - examples: discuss_types.ExamplesOptions | None = None, - messages: discuss_types.MessagesOptions | None = None, - temperature: float | None = None, - candidate_count: int | None = None, - top_p: float | None = None, - top_k: float | None = None, - prompt: discuss_types.MessagePromptOptions | None = None, -) -> protos.GenerateMessageRequest: - """Creates a `protos.GenerateMessageRequest` object for generating messages.""" - model = model_types.make_model_name(model) - - prompt = _make_message_prompt( - prompt=prompt, context=context, examples=examples, messages=messages - ) - - return protos.GenerateMessageRequest( - model=model, - prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, - candidate_count=candidate_count, - ) - - -DEFAULT_DISCUSS_MODEL = "models/chat-bison-001" - - -def chat( - *, - model: model_types.AnyModelNameOptions | None = "models/chat-bison-001", - context: str | None = None, - examples: discuss_types.ExamplesOptions | None = None, - messages: discuss_types.MessagesOptions | None = None, - temperature: float | None = None, - candidate_count: int | None = None, - top_p: float | None = None, - top_k: float | None = None, - prompt: discuss_types.MessagePromptOptions | None = None, - client: glm.DiscussServiceClient | None = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> discuss_types.ChatResponse: - """Calls the API to initiate a chat with a model using provided parameters - - Args: - model: Which model to call, as a string or a `types.Model`. - context: Text that should be provided to the model first, to ground the response. - - If not empty, this `context` will be given to the model first before the - `examples` and `messages`. - - This field can be a description of your prompt to the model to help provide - context and guide the responses. - - Examples: - - * "Translate the phrase from English to French." - * "Given a statement, classify the sentiment as happy, sad or neutral." - - Anything included in this field will take precedence over history in `messages` - if the total input size exceeds the model's `Model.input_token_limit`. - examples: Examples of what the model should generate. - - This includes both the user input and the response that the model should - emulate. - - These `examples` are treated identically to conversation messages except - that they take precedence over the history in `messages`: - If the total input size exceeds the model's `input_token_limit` the input - will be truncated. Items will be dropped from `messages` before `examples` - messages: A snapshot of the conversation history sorted chronologically. - - Turns alternate between two authors. - - If the total input size exceeds the model's `input_token_limit` the input - will be truncated: The oldest items will be dropped from `messages`. - temperature: Controls the randomness of the output. Must be positive. - - Typical values are in the range: `[0.0,1.0]`. Higher values produce a - more random and varied response. A temperature of zero will be deterministic. - candidate_count: The **maximum** number of generated response messages to return. - - This value must be between `[1, 8]`, inclusive. If unset, this - will default to `1`. - - Note: Only unique candidates are returned. Higher temperatures are more - likely to produce unique candidates. Setting `temperature=0.0` will always - return 1 candidate regardless of the `candidate_count`. - top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and - top-k sampling. - - `top_k` sets the maximum number of tokens to sample from on each step. - top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and - top-k sampling. - - `top_p` configures the nucleus sampling. It sets the maximum cumulative - probability of tokens to sample from. - - For example, if the sorted probabilities are - `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample - as `[0.625, 0.25, 0.125, 0, 0, 0]`. - - Typical values are in the `[0.9, 1.0]` range. - prompt: You may pass a `types.MessagePromptOptions` **instead** of a - setting `context`/`examples`/`messages`, but not both. - client: If you're not relying on the default client, you pass a - `glm.DiscussServiceClient` instead. - request_options: Options for the request. - - Returns: - A `types.ChatResponse` containing the model's reply. - """ - request = _make_generate_message_request( - model=model, - context=context, - examples=examples, - messages=messages, - temperature=temperature, - candidate_count=candidate_count, - top_p=top_p, - top_k=top_k, - prompt=prompt, - ) - - return _generate_response(client=client, request=request, request_options=request_options) - - -@string_utils.set_doc(chat.__doc__) -async def chat_async( - *, - model: model_types.AnyModelNameOptions | None = "models/chat-bison-001", - context: str | None = None, - examples: discuss_types.ExamplesOptions | None = None, - messages: discuss_types.MessagesOptions | None = None, - temperature: float | None = None, - candidate_count: int | None = None, - top_p: float | None = None, - top_k: float | None = None, - prompt: discuss_types.MessagePromptOptions | None = None, - client: glm.DiscussServiceAsyncClient | None = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> discuss_types.ChatResponse: - """Calls the API asynchronously to initiate a chat with a model using provided parameters""" - request = _make_generate_message_request( - model=model, - context=context, - examples=examples, - messages=messages, - temperature=temperature, - candidate_count=candidate_count, - top_p=top_p, - top_k=top_k, - prompt=prompt, - ) - - return await _generate_response_async( - client=client, request=request, request_options=request_options - ) - - -if (sys.version_info.major, sys.version_info.minor) >= (3, 10): - DATACLASS_KWARGS = {"kw_only": True} -else: - DATACLASS_KWARGS = {} - - -@string_utils.prettyprint -@string_utils.set_doc(discuss_types.ChatResponse.__doc__) -@dataclasses.dataclass(**DATACLASS_KWARGS, init=False) -class ChatResponse(discuss_types.ChatResponse): - _client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False) - - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - @property - @string_utils.set_doc(discuss_types.ChatResponse.last.__doc__) - def last(self) -> str | None: - if self.messages[-1]: - return self.messages[-1]["content"] - else: - return None - - @last.setter - def last(self, message: discuss_types.MessageOptions): - message = _make_message(message) - message = type(message).to_dict(message) - self.messages[-1] = message - - @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__) - def reply( - self, - message: discuss_types.MessageOptions, - request_options: helper_types.RequestOptionsType | None = None, - ) -> discuss_types.ChatResponse: - if isinstance(self._client, glm.DiscussServiceAsyncClient): - raise TypeError( - "Invalid operation: The 'reply' method cannot be called on an asynchronous client. Please use the 'reply_async' method instead." - ) - if self.last is None: - raise ValueError( - f"Invalid operation: No candidates returned from the model's last response. " - f"Please inspect the '.filters' attribute to understand why responses were filtered out. Current filters: {self.filters}" - ) - - request = self.to_dict() - request.pop("candidates") - request.pop("filters", None) - request["messages"] = list(request["messages"]) - request["messages"].append(_make_message(message)) - request = _make_generate_message_request(**request) - return _generate_response( - request=request, client=self._client, request_options=request_options - ) - - @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__) - async def reply_async( - self, message: discuss_types.MessageOptions - ) -> discuss_types.ChatResponse: - if isinstance(self._client, glm.DiscussServiceClient): - raise TypeError( - "Invalid method call: `reply_async` is not supported on a non-async client. Please use the `reply` method instead." - ) - request = self.to_dict() - request.pop("candidates") - request.pop("filters", None) - request["messages"] = list(request["messages"]) - request["messages"].append(_make_message(message)) - request = _make_generate_message_request(**request) - return await _generate_response_async(request=request, client=self._client) - - -def _build_chat_response( - request: protos.GenerateMessageRequest, - response: protos.GenerateMessageResponse, - client: glm.DiscussServiceClient | protos.DiscussServiceAsyncClient, -) -> ChatResponse: - request = type(request).to_dict(request) - prompt = request.pop("prompt") - request["examples"] = prompt["examples"] - request["context"] = prompt["context"] - request["messages"] = prompt["messages"] - - response = type(response).to_dict(response) - response.pop("messages") - - response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) - - if response["candidates"]: - last = response["candidates"][0] - else: - last = None - request["messages"].append(last) - request.setdefault("temperature", None) - request.setdefault("candidate_count", None) - - return ChatResponse(_client=client, **response, **request) # pytype: disable=missing-parameter - - -def _generate_response( - request: protos.GenerateMessageRequest, - client: glm.DiscussServiceClient | None = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> ChatResponse: - if request_options is None: - request_options = {} - - if client is None: - client = get_default_discuss_client() - - response = client.generate_message(request, **request_options) - - return _build_chat_response(request, response, client) - - -async def _generate_response_async( - request: protos.GenerateMessageRequest, - client: glm.DiscussServiceAsyncClient | None = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> ChatResponse: - if request_options is None: - request_options = {} - - if client is None: - client = get_default_discuss_async_client() - - response = await client.generate_message(request, **request_options) - - return _build_chat_response(request, response, client) - - -def count_message_tokens( - *, - prompt: discuss_types.MessagePromptOptions = None, - context: str | None = None, - examples: discuss_types.ExamplesOptions | None = None, - messages: discuss_types.MessagesOptions | None = None, - model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL, - client: glm.DiscussServiceAsyncClient | None = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> discuss_types.TokenCount: - """Calls the API to calculate the number of tokens used in the prompt.""" - - model = model_types.make_model_name(model) - prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages) - - if request_options is None: - request_options = {} - - if client is None: - client = get_default_discuss_client() - - result = client.count_message_tokens(model=model, prompt=prompt, **request_options) - - return type(result).to_dict(result) diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 616fa07bf..15645c792 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -24,8 +24,8 @@ from google.generativeai.client import get_default_generative_async_client from google.generativeai.types import helper_types -from google.generativeai.types import text_types from google.generativeai.types import model_types +from google.generativeai.types import text_types from google.generativeai.types import content_types DEFAULT_EMB_MODEL = "models/embedding-001" diff --git a/google/generativeai/text.py b/google/generativeai/text.py deleted file mode 100644 index 2a6267661..000000000 --- a/google/generativeai/text.py +++ /dev/null @@ -1,347 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import dataclasses -from collections.abc import Iterable, Sequence -import itertools -from typing import Any, Iterable, overload, TypeVar - -import google.ai.generativelanguage as glm - -from google.generativeai import protos - -from google.generativeai.client import get_default_text_client -from google.generativeai import string_utils -from google.generativeai.types import helper_types -from google.generativeai.types import text_types -from google.generativeai.types import model_types -from google.generativeai import models -from google.generativeai.types import palm_safety_types - -DEFAULT_TEXT_MODEL = "models/text-bison-001" -EMBEDDING_MAX_BATCH_SIZE = 100 - -try: - # python 3.12+ - _batched = itertools.batched # type: ignore -except AttributeError: - T = TypeVar("T") - - def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: - if n < 1: - raise ValueError(f"Batch size `n` must be >1, got: {n}") - batch = [] - for item in iterable: - batch.append(item) - if len(batch) == n: - yield batch - batch = [] - - if batch: - yield batch - - -def _make_text_prompt(prompt: str | dict[str, str]) -> protos.TextPrompt: - """ - Creates a `protos.TextPrompt` object based on the provided prompt input. - - Args: - prompt: The prompt input, either a string or a dictionary. - - Returns: - protos.TextPrompt: A TextPrompt object containing the prompt text. - - Raises: - TypeError: If the provided prompt is neither a string nor a dictionary. - """ - if isinstance(prompt, str): - return protos.TextPrompt(text=prompt) - elif isinstance(prompt, dict): - return protos.TextPrompt(prompt) - else: - raise TypeError( - "Invalid argument type: Expected a string or dictionary for the text prompt." - ) - - -def _make_generate_text_request( - *, - model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL, - prompt: str | None = None, - temperature: float | None = None, - candidate_count: int | None = None, - max_output_tokens: int | None = None, - top_p: int | None = None, - top_k: int | None = None, - safety_settings: palm_safety_types.SafetySettingOptions | None = None, - stop_sequences: str | Iterable[str] | None = None, -) -> protos.GenerateTextRequest: - """ - Creates a `protos.GenerateTextRequest` object based on the provided parameters. - - This function generates a `protos.GenerateTextRequest` object with the specified - parameters. It prepares the input parameters and creates a request that can be - used for generating text using the chosen model. - - Args: - model: The model to use for text generation. - prompt: The prompt for text generation. Defaults to None. - temperature: The temperature for randomness in generation. Defaults to None. - candidate_count: The number of candidates to consider. Defaults to None. - max_output_tokens: The maximum number of output tokens. Defaults to None. - top_p: The nucleus sampling probability threshold. Defaults to None. - top_k: The top-k sampling parameter. Defaults to None. - safety_settings: Safety settings for generated text. Defaults to None. - stop_sequences: Stop sequences to halt text generation. Can be a string - or iterable of strings. Defaults to None. - - Returns: - `protos.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. - """ - model = model_types.make_model_name(model) - prompt = _make_text_prompt(prompt=prompt) - safety_settings = palm_safety_types.normalize_safety_settings(safety_settings) - if isinstance(stop_sequences, str): - stop_sequences = [stop_sequences] - if stop_sequences: - stop_sequences = list(stop_sequences) - - return protos.GenerateTextRequest( - model=model, - prompt=prompt, - temperature=temperature, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - top_p=top_p, - top_k=top_k, - safety_settings=safety_settings, - stop_sequences=stop_sequences, - ) - - -def generate_text( - *, - model: model_types.AnyModelNameOptions = DEFAULT_TEXT_MODEL, - prompt: str, - temperature: float | None = None, - candidate_count: int | None = None, - max_output_tokens: int | None = None, - top_p: float | None = None, - top_k: float | None = None, - safety_settings: palm_safety_types.SafetySettingOptions | None = None, - stop_sequences: str | Iterable[str] | None = None, - client: glm.TextServiceClient | None = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> text_types.Completion: - """Calls the API to generate text based on the provided prompt. - - Args: - model: Which model to call, as a string or a `types.Model`. - prompt: Free-form input text given to the model. Given a prompt, the model will - generate text that completes the input text. - temperature: Controls the randomness of the output. Must be positive. - Typical values are in the range: `[0.0,1.0]`. Higher values produce a - more random and varied response. A temperature of zero will be deterministic. - candidate_count: The **maximum** number of generated response messages to return. - This value must be between `[1, 8]`, inclusive. If unset, this - will default to `1`. - - Note: Only unique candidates are returned. Higher temperatures are more - likely to produce unique candidates. Setting `temperature=0.0` will always - return 1 candidate regardless of the `candidate_count`. - max_output_tokens: Maximum number of tokens to include in a candidate. Must be greater - than zero. If unset, will default to 64. - top_k: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling. - `top_k` sets the maximum number of tokens to sample from on each step. - top_p: The API uses combined [nucleus](https://arxiv.org/abs/1904.09751) and top-k sampling. - `top_p` configures the nucleus sampling. It sets the maximum cumulative - probability of tokens to sample from. - For example, if the sorted probabilities are - `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample - as `[0.625, 0.25, 0.125, 0, 0, 0]`. - safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content. - These will be enforced on the `prompt` and - `candidates`. There should not be more than one - setting for each `types.SafetyCategory` type. The API will block any prompts and - responses that fail to meet the thresholds set by these settings. This list - overrides the default settings for each `SafetyCategory` specified in the - safety_settings. If there is no `types.SafetySetting` for a given - `SafetyCategory` provided in the list, the API will use the default safety - setting for that category. - stop_sequences: A set of up to 5 character sequences that will stop output generation. - If specified, the API will stop at the first appearance of a stop - sequence. The stop sequence will not be included as part of the response. - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. - request_options: Options for the request. - - Returns: - A `types.Completion` containing the model's text completion response. - """ - request = _make_generate_text_request( - model=model, - prompt=prompt, - temperature=temperature, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - top_p=top_p, - top_k=top_k, - safety_settings=safety_settings, - stop_sequences=stop_sequences, - ) - - return _generate_response(client=client, request=request, request_options=request_options) - - -@string_utils.prettyprint -@dataclasses.dataclass(init=False) -class Completion(text_types.Completion): - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - self.result = None - if self.candidates: - self.result = self.candidates[0]["output"] - - -def _generate_response( - request: protos.GenerateTextRequest, - client: glm.TextServiceClient = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> Completion: - """ - Generates a response using the provided `protos.GenerateTextRequest` and client. - - Args: - request: The text generation request. - client: The client to use for text generation. Defaults to None, in which - case the default text client is used. - request_options: Options for the request. - - Returns: - `Completion`: A `Completion` object with the generated text and response information. - """ - if request_options is None: - request_options = {} - - if client is None: - client = get_default_text_client() - - response = client.generate_text(request, **request_options) - response = type(response).to_dict(response) - - response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) - response["safety_feedback"] = palm_safety_types.convert_safety_feedback_to_enums( - response["safety_feedback"] - ) - response["candidates"] = palm_safety_types.convert_candidate_enums(response["candidates"]) - - return Completion(_client=client, **response) - - -def count_text_tokens( - model: model_types.AnyModelNameOptions, - prompt: str, - client: glm.TextServiceClient | None = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> text_types.TokenCount: - """Calls the API to count the number of tokens in the text prompt.""" - - base_model = models.get_base_model_name(model) - - if request_options is None: - request_options = {} - - if client is None: - client = get_default_text_client() - - result = client.count_text_tokens( - protos.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), - **request_options, - ) - - return type(result).to_dict(result) - - -@overload -def generate_embeddings( - model: model_types.BaseModelNameOptions, - text: str, - client: glm.TextServiceClient = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> text_types.EmbeddingDict: ... - - -@overload -def generate_embeddings( - model: model_types.BaseModelNameOptions, - text: Sequence[str], - client: glm.TextServiceClient = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> text_types.BatchEmbeddingDict: ... - - -def generate_embeddings( - model: model_types.BaseModelNameOptions, - text: str | Sequence[str], - client: glm.TextServiceClient = None, - request_options: helper_types.RequestOptionsType | None = None, -) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: - """Calls the API to create an embedding for the text passed in. - - Args: - model: Which model to call, as a string or a `types.Model`. - - text: Free-form input text given to the model. Given a string, the model will - generate an embedding based on the input text. - - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. - - request_options: Options for the request. - - Returns: - Dictionary containing the embedding (list of float values) for the input text. - """ - model = model_types.make_model_name(model) - - if request_options is None: - request_options = {} - - if client is None: - client = get_default_text_client() - - if isinstance(text, str): - embedding_request = protos.EmbedTextRequest(model=model, text=text) - embedding_response = client.embed_text( - embedding_request, - **request_options, - ) - embedding_dict = type(embedding_response).to_dict(embedding_response) - embedding_dict["embedding"] = embedding_dict["embedding"]["value"] - else: - result = {"embedding": []} - for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE): - # TODO(markdaoust): This could use an option for returning an iterator or wait-bar. - embedding_request = protos.BatchEmbedTextRequest(model=model, texts=batch) - embedding_response = client.batch_embed_text( - embedding_request, - **request_options, - ) - embedding_dict = type(embedding_response).to_dict(embedding_response) - result["embedding"].extend(e["value"] for e in embedding_dict["embeddings"]) - return result - - return embedding_dict diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py index 0acfb1397..1e7853746 100644 --- a/google/generativeai/types/__init__.py +++ b/google/generativeai/types/__init__.py @@ -16,18 +16,14 @@ from google.generativeai.types.citation_types import * from google.generativeai.types.content_types import * -from google.generativeai.types.discuss_types import * from google.generativeai.types.file_types import * from google.generativeai.types.generation_types import * from google.generativeai.types.helper_types import * from google.generativeai.types.model_types import * from google.generativeai.types.permission_types import * from google.generativeai.types.safety_types import * -from google.generativeai.types.text_types import * -del discuss_types del model_types -del text_types del citation_types del safety_types diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py deleted file mode 100644 index 05ad262f3..000000000 --- a/google/generativeai/types/discuss_types.py +++ /dev/null @@ -1,208 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Type definitions for the discuss service.""" - -import abc -import dataclasses -from typing import Any, Dict, Union, Iterable, Optional, Tuple, List -from typing_extensions import TypedDict - -from google.generativeai import protos -from google.generativeai import string_utils - -from google.generativeai.types import palm_safety_types -from google.generativeai.types import citation_types - - -__all__ = [ - "MessageDict", - "MessageOptions", - "MessagesOptions", - "ExampleDict", - "ExampleOptions", - "ExamplesOptions", - "MessagePromptDict", - "MessagePromptOptions", - "ResponseDict", - "ChatResponse", - "AuthorError", -] - - -class TokenCount(TypedDict): - token_count: int - - -class MessageDict(TypedDict): - """A dict representation of a `protos.Message`.""" - - author: str - content: str - citation_metadata: Optional[citation_types.CitationMetadataDict] - - -MessageOptions = Union[str, MessageDict, protos.Message] -MESSAGE_OPTIONS = (str, dict, protos.Message) - -MessagesOptions = Union[ - MessageOptions, - Iterable[MessageOptions], -] -MESSAGES_OPTIONS = (MESSAGE_OPTIONS, Iterable) - - -class ExampleDict(TypedDict): - """A dict representation of a `protos.Example`.""" - - input: MessageOptions - output: MessageOptions - - -ExampleOptions = Union[ - Tuple[MessageOptions, MessageOptions], - Iterable[MessageOptions], - ExampleDict, - protos.Example, -] -EXAMPLE_OPTIONS = (protos.Example, dict, Iterable) -ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]] - - -class MessagePromptDict(TypedDict, total=False): - """A dict representation of a `protos.MessagePrompt`.""" - - context: str - examples: ExamplesOptions - messages: MessagesOptions - - -MessagePromptOptions = Union[ - str, - protos.Message, - Iterable[Union[str, protos.Message]], - MessagePromptDict, - protos.MessagePrompt, -] -MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"} - - -class ResponseDict(TypedDict): - """A dict representation of a `protos.GenerateMessageResponse`.""" - - messages: List[MessageDict] - candidates: List[MessageDict] - - -@string_utils.prettyprint -@dataclasses.dataclass(init=False) -class ChatResponse(abc.ABC): - """A chat response from the model. - - * Use `response.last` (settable) for easy access to the text of the last response. - (`messages[-1]['content']`) - * Use `response.messages` to access the message history (including `.last`). - * Use `response.candidates` to access all the responses generated by the model. - - Other attributes are just saved from the arguments to `genai.chat`, so you - can easily continue a conversation: - - ``` - import google.generativeai as genai - - genai.configure(api_key=os.environ['GEMINI_API_KEY']) - - response = genai.chat(messages=["Hello."]) - print(response.last) # 'Hello! What can I help you with?' - response.reply("Can you tell me a joke?") - ``` - - See `genai.chat` for more details. - - Attributes: - candidates: A list of candidate responses from the model. - - The top candidate is appended to the `messages` field. - - This list will contain a *maximum* of `candidate_count` candidates. - It may contain fewer (duplicates are dropped), it will contain at least one. - - Note: The `temperature` field affects the variability of the responses. Low - temperatures will return few candidates. Setting `temperature=0` is deterministic, - so it will only ever return one candidate. - filters: This indicates which `types.SafetyCategory`(s) blocked a - candidate from this response, the lowest `types.HarmProbability` - that triggered a block, and the `types.HarmThreshold` setting for that category. - This indicates the smallest change to the `types.SafetySettings` that would be - necessary to unblock at least 1 response. - - The blocking is configured by the `types.SafetySettings` in the request (or the - default `types.SafetySettings` of the API). - messages: Contains all the `messages` that were passed when the model was called, - plus the top `candidate` message. - model: The model name. - context: Text that should be provided to the model first, to ground the response. - examples: Examples of what the model should generate. - messages: A snapshot of the conversation history sorted chronologically. - temperature: Controls the randomness of the output. Must be positive. - candidate_count: The **maximum** number of generated response messages to return. - top_k: The maximum number of tokens to consider when sampling. - top_p: The maximum cumulative probability of tokens to consider when sampling. - - """ - - model: str - context: str - examples: List[ExampleDict] - messages: List[Optional[MessageDict]] - temperature: Optional[float] - candidate_count: Optional[int] - candidates: List[MessageDict] - filters: List[palm_safety_types.ContentFilterDict] - top_p: Optional[float] = None - top_k: Optional[float] = None - - @property - @abc.abstractmethod - def last(self) -> Optional[str]: - """A settable property that provides simple access to the last response string - - A shortcut for `response.messages[0]['content']`. - """ - pass - - def to_dict(self) -> Dict[str, Any]: - result = { - "model": self.model, - "context": self.context, - "examples": self.examples, - "messages": self.messages, - "temperature": self.temperature, - "candidate_count": self.candidate_count, - "top_p": self.top_p, - "top_k": self.top_k, - "candidates": self.candidates, - } - return result - - @abc.abstractmethod - def reply(self, message: MessageOptions) -> "ChatResponse": - "Add a message to the conversation, and get the model's response." - pass - - -class AuthorError(Exception): - """Raised by the `chat` (or `reply`) functions when the author list can't be normalized.""" - - pass diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py index 61804fcaa..e84a7e715 100644 --- a/google/generativeai/types/text_types.py +++ b/google/generativeai/types/text_types.py @@ -21,55 +21,12 @@ from typing_extensions import TypedDict from google.generativeai import string_utils -from google.generativeai.types import palm_safety_types from google.generativeai.types import citation_types -__all__ = ["Completion"] - - -class TokenCount(TypedDict): - token_count: int - - class EmbeddingDict(TypedDict): embedding: list[float] class BatchEmbeddingDict(TypedDict): embedding: list[list[float]] - - -class TextCompletion(TypedDict, total=False): - output: str - safety_ratings: List[palm_safety_types.SafetyRatingDict | None] - citation_metadata: citation_types.CitationMetadataDict | None - - -@string_utils.prettyprint -@dataclasses.dataclass(init=False) -class Completion(abc.ABC): - """The result returned by `generativeai.generate_text`. - - Use `GenerateTextResponse.candidates` to access all the completions generated by the model. - - Attributes: - candidates: A list of candidate text completions generated by the model. - result: The output of the first candidate, - filters: Indicates the reasons why content may have been blocked. - See `types.BlockedReason`. - safety_feedback: Indicates which safety settings blocked content in this result. - """ - - candidates: List[TextCompletion] - result: str | None - filters: List[palm_safety_types.ContentFilterDict | None] - safety_feedback: List[palm_safety_types.SafetyFeedbackDict | None] - - def to_dict(self) -> Dict[str, Any]: - result = { - "candidates": self.candidates, - "filters": self.filters, - "safety_feedback": self.safety_feedback, - } - return result diff --git a/tests/test_client.py b/tests/test_client.py index 0cc3e05eb..9162c3d75 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -58,11 +58,17 @@ def test_api_key_and_client_options(self): self.assertEqual(actual_client_opts.api_endpoint, "web.site") @parameterized.parameters( - client.get_default_discuss_client, - client.get_default_text_client, - client.get_default_discuss_async_client, + client.get_default_cache_client, + client.get_default_file_client, + client.get_default_file_async_client, + client.get_default_generative_client, + client.get_default_generative_async_client, client.get_default_model_client, client.get_default_operations_client, + client.get_default_retriever_client, + client.get_default_retriever_async_client, + client.get_default_permission_client, + client.get_default_permission_async_client, ) @mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "AIzA_env"}) def test_configureless_client_with_key(self, factory_fn): @@ -76,7 +82,7 @@ class DummyClient: def __init__(self, *args, **kwargs): pass - def generate_text(self, metadata=None): + def generate_content(self, metadata=None): self.metadata = metadata not_a_function = 7 @@ -92,26 +98,26 @@ def static(): def classm(cls): cls.called_classm = True - @mock.patch.object(glm, "TextServiceClient", DummyClient) + @mock.patch.object(glm, "GenerativeServiceClient", DummyClient) def test_default_metadata(self): # The metadata wrapper injects this argument. metadata = [("hello", "world")] client.configure(default_metadata=metadata) - text_client = client.get_default_text_client() - text_client.generate_text() + generative_client = client.get_default_generative_client() + generative_client.generate_content() - self.assertEqual(metadata, text_client.metadata) + self.assertEqual(metadata, generative_client.metadata) - self.assertEqual(text_client.not_a_function, ClientTests.DummyClient.not_a_function) + self.assertEqual(generative_client.not_a_function, ClientTests.DummyClient.not_a_function) # Since these don't have a metadata arg, they'll fail if the wrapper is applied. - text_client._hidden() - self.assertTrue(text_client.called_hidden) + generative_client._hidden() + self.assertTrue(generative_client.called_hidden) - text_client.static() + generative_client.static() - text_client.classm() + generative_client.classm() self.assertTrue(ClientTests.DummyClient.called_classm) def test_same_config(self): diff --git a/tests/test_discuss.py b/tests/test_discuss.py deleted file mode 100644 index 4e54cf754..000000000 --- a/tests/test_discuss.py +++ /dev/null @@ -1,386 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy - -import unittest.mock - -from google.generativeai import protos - -from google.generativeai import discuss -from google.generativeai import client -import google.generativeai as genai -from google.generativeai.types import palm_safety_types - -from absl.testing import absltest -from absl.testing import parameterized - -# TODO: replace returns with 'assert' statements - - -class UnitTests(parameterized.TestCase): - def setUp(self): - self.client = unittest.mock.MagicMock() - - client._client_manager.clients["discuss"] = self.client - - self.observed_request = None - - self.mock_response = protos.GenerateMessageResponse( - candidates=[ - protos.Message(content="a", author="1"), - protos.Message(content="b", author="1"), - protos.Message(content="c", author="1"), - ], - ) - - def fake_generate_message( - request: protos.GenerateMessageRequest, - **kwargs, - ) -> protos.GenerateMessageResponse: - self.observed_request = request - response = copy.copy(self.mock_response) - response.messages = request.prompt.messages - return response - - self.client.generate_message = fake_generate_message - - @parameterized.named_parameters( - ["string", "Hello", ""], - ["dict", {"content": "Hello"}, ""], - ["dict_author", {"content": "Hello", "author": "me"}, "me"], - ["proto", protos.Message(content="Hello"), ""], - ["proto_author", protos.Message(content="Hello", author="me"), "me"], - ) - def test_make_message(self, message, author): - x = discuss._make_message(message) - self.assertIsInstance(x, protos.Message) - self.assertEqual("Hello", x.content) - self.assertEqual(author, x.author) - - @parameterized.named_parameters( - ["string", "Hello", ["Hello"]], - ["dict", {"content": "Hello"}, ["Hello"]], - ["proto", protos.Message(content="Hello"), ["Hello"]], - [ - "list", - ["hello0", {"content": "hello1"}, protos.Message(content="hello2")], - ["hello0", "hello1", "hello2"], - ], - ) - def test_make_messages(self, messages, expected_contents): - messages = discuss._make_messages(messages) - for expected, message in zip(expected_contents, messages): - self.assertEqual(expected, message.content) - - @parameterized.named_parameters( - ["tuple", ("hello", {"content": "goodbye"})], - ["iterable", iter(["hello", "goodbye"])], - ["dict", {"input": "hello", "output": "goodbye"}], - [ - "proto", - protos.Example( - input=protos.Message(content="hello"), - output=protos.Message(content="goodbye"), - ), - ], - ) - def test_make_example(self, example): - x = discuss._make_example(example) - self.assertIsInstance(x, protos.Example) - self.assertEqual("hello", x.input.content) - self.assertEqual("goodbye", x.output.content) - return - - @parameterized.named_parameters( - [ - "messages", - [ - "Hi", - {"content": "Hello!"}, - "what's your name?", - protos.Message(content="Dave, what's yours"), - ], - ], - [ - "examples", - [ - ("Hi", "Hello!"), - { - "input": "what's your name?", - "output": {"content": "Dave, what's yours"}, - }, - ], - ], - ) - def test_make_examples(self, examples): - examples = discuss._make_examples(examples) - self.assertLen(examples, 2) - self.assertEqual(examples[0].input.content, "Hi") - self.assertEqual(examples[0].output.content, "Hello!") - self.assertEqual(examples[1].input.content, "what's your name?") - self.assertEqual(examples[1].output.content, "Dave, what's yours") - - return - - def test_make_examples_from_example(self): - ex_dict = {"input": "hello", "output": "meow!"} - example = discuss._make_example(ex_dict) - examples1 = discuss._make_examples(ex_dict) - examples2 = discuss._make_examples(discuss._make_example(ex_dict)) - - self.assertEqual(example, examples1[0]) - self.assertEqual(example, examples2[0]) - - @parameterized.named_parameters( - ["str", "hello"], - ["message", protos.Message(content="hello")], - ["messages", ["hello"]], - ["dict", {"messages": "hello"}], - ["dict2", {"messages": ["hello"]}], - ["proto", protos.MessagePrompt(messages=[protos.Message(content="hello")])], - ) - def test_make_message_prompt_from_messages(self, prompt): - x = discuss._make_message_prompt(prompt) - self.assertIsInstance(x, protos.MessagePrompt) - self.assertEqual(x.messages[0].content, "hello") - return - - @parameterized.named_parameters( - [ - "dict", - [ - { - "context": "you are a cat", - "examples": ["are you hungry?", "meow!"], - "messages": "hello", - } - ], - {}, - ], - [ - "kwargs", - [], - { - "context": "you are a cat", - "examples": ["are you hungry?", "meow!"], - "messages": "hello", - }, - ], - [ - "proto", - [ - protos.MessagePrompt( - context="you are a cat", - examples=[ - protos.Example( - input=protos.Message(content="are you hungry?"), - output=protos.Message(content="meow!"), - ) - ], - messages=[protos.Message(content="hello")], - ) - ], - {}, - ], - ) - def test_make_message_prompt_from_prompt(self, args, kwargs): - x = discuss._make_message_prompt(*args, **kwargs) - self.assertIsInstance(x, protos.MessagePrompt) - self.assertEqual(x.context, "you are a cat") - self.assertEqual(x.examples[0].input.content, "are you hungry?") - self.assertEqual(x.examples[0].output.content, "meow!") - self.assertEqual(x.messages[0].content, "hello") - - def test_make_generate_message_request_nested( - self, - ): - request0 = discuss._make_generate_message_request( - **{ - "model": "models/Dave", - "context": "you are a cat", - "examples": ["hello", "meow", "are you hungry?", "meow!"], - "messages": "Please catch that mouse.", - "temperature": 0.2, - "candidate_count": 7, - } - ) - request1 = discuss._make_generate_message_request( - **{ - "model": "models/Dave", - "prompt": { - "context": "you are a cat", - "examples": ["hello", "meow", "are you hungry?", "meow!"], - "messages": "Please catch that mouse.", - }, - "temperature": 0.2, - "candidate_count": 7, - } - ) - - self.assertIsInstance(request0, protos.GenerateMessageRequest) - self.assertIsInstance(request1, protos.GenerateMessageRequest) - self.assertEqual(request0, request1) - - @parameterized.parameters( - {"prompt": {}, "context": "You are a cat."}, - { - "prompt": {"context": "You are a cat."}, - "examples": ["hello", "meow"], - }, - {"prompt": {"examples": ["hello", "meow"]}, "messages": "hello"}, - ) - def test_make_generate_message_request_flat_prompt_conflict( - self, - context=None, - examples=None, - messages=None, - prompt=None, - ): - with self.assertRaises(ValueError): - x = discuss._make_generate_message_request( - model="test", - context=context, - examples=examples, - messages=messages, - prompt=prompt, - ) - - @parameterized.parameters( - {"kwargs": {"context": "You are a cat."}}, - {"kwargs": {"messages": "hello"}}, - {"kwargs": {"examples": [["a", "b"], ["c", "d"]]}}, - { - "kwargs": { - "messages": ["hello"], - "examples": [["a", "b"], ["c", "d"]], - } - }, - ) - def test_reply(self, kwargs): - response = genai.chat(**kwargs) - first_messages = response.messages - - self.assertEqual("a", response.last) - self.assertEqual( - [ - {"author": "1", "content": "a"}, - {"author": "1", "content": "b"}, - {"author": "1", "content": "c"}, - ], - response.candidates, - ) - - response = response.reply("again") - - def test_receive_and_reply_with_filters(self): - self.mock_response = mock_response = protos.GenerateMessageResponse( - candidates=[protos.Message(content="a", author="1")], - filters=[ - protos.ContentFilter( - reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe" - ), - protos.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), - ], - ) - response = discuss.chat(messages="do filters work?") - - filters = response.filters - self.assertLen(filters, 2) - self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) - self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) - self.assertEqual(filters[0]["message"], "unsafe") - - self.mock_response = protos.GenerateMessageResponse( - candidates=[protos.Message(content="a", author="1")], - filters=[ - protos.ContentFilter( - reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED - ) - ], - ) - - response = response.reply("Does reply work?") - filters = response.filters - self.assertLen(filters, 1) - self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) - self.assertEqual( - filters[0]["reason"], - palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED, - ) - - def test_chat_citations(self): - self.mock_response = mock_response = protos.GenerateMessageResponse( - candidates=[ - { - "content": "Hello google!", - "author": "1", - "citation_metadata": { - "citation_sources": [ - { - "start_index": 6, - "end_index": 12, - "uri": "https://google.com", - } - ] - }, - } - ], - ) - - response = discuss.chat(messages="Do citations work?") - - self.assertEqual( - response.candidates[0]["citation_metadata"]["citation_sources"][0]["start_index"], - 6, - ) - - response = response.reply("What about a second time?") - - self.assertEqual( - response.candidates[0]["citation_metadata"]["citation_sources"][0]["start_index"], - 6, - ) - self.assertLen(response.messages, 4) - - def test_set_last(self): - response = discuss.chat(messages="Can you overwrite `.last`?") - response.last = "yes" - response = response.reply("glad to hear it!") - response.last = "Me too!" - self.assertEqual( - [msg["content"] for msg in response.messages], - [ - "Can you overwrite `.last`?", - "yes", - "glad to hear it!", - "Me too!", - ], - ) - - def test_generate_message_called_with_request_options(self): - self.client.generate_message = unittest.mock.MagicMock() - request = unittest.mock.ANY - request_options = {"timeout": 120} - - try: - genai.chat(**{"context": "You are a cat."}, request_options=request_options) - except AttributeError: - pass - - self.client.generate_message.assert_called_once_with(request, **request_options) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py deleted file mode 100644 index d35d03525..000000000 --- a/tests/test_discuss_async.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -from typing import Any -import unittest - -from google.generativeai import protos - -from google.generativeai import discuss -from absl.testing import absltest -from absl.testing import parameterized - - -class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): - async def test_chat_async(self): - client = unittest.mock.AsyncMock() - - observed_request = None - - async def fake_generate_message( - request: protos.GenerateMessageRequest, - **kwargs, - ) -> protos.GenerateMessageResponse: - nonlocal observed_request - observed_request = request - return protos.GenerateMessageResponse( - candidates=[ - protos.Message( - author="1", - content="Why did the chicken cross the road?", - ) - ] - ) - - client.generate_message = fake_generate_message - - observed_response = await discuss.chat_async( - model="models/bard", - context="Example Prompt", - examples=[["Example from human", "Example response from AI"]], - messages=["Tell me a joke"], - temperature=0.75, - candidate_count=1, - client=client, - ) - - self.assertEqual( - observed_request, - protos.GenerateMessageRequest( - model="models/bard", - prompt=protos.MessagePrompt( - context="Example Prompt", - examples=[ - protos.Example( - input=protos.Message(content="Example from human"), - output=protos.Message(content="Example response from AI"), - ) - ], - messages=[protos.Message(author="0", content="Tell me a joke")], - ), - temperature=0.75, - candidate_count=1, - ), - ) - self.assertEqual( - observed_response.candidates, - [{"author": "1", "content": "Why did the chicken cross the road?"}], - ) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/test_text.py b/tests/test_text.py deleted file mode 100644 index 795c3dfcd..000000000 --- a/tests/test_text.py +++ /dev/null @@ -1,542 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import copy -import math -from typing import Any -import unittest -import unittest.mock as mock - -from google.generativeai import protos - -from google.generativeai import text as text_service -from google.generativeai import client -from google.generativeai.types import palm_safety_types -from google.generativeai.types import model_types -from absl.testing import absltest -from absl.testing import parameterized - - -class UnitTests(parameterized.TestCase): - def setUp(self): - self.client = unittest.mock.MagicMock() - - client._client_manager.clients["text"] = self.client - client._client_manager.clients["model"] = self.client - - self.observed_requests = [] - - self.responses = {} - - def add_client_method(f): - name = f.__name__ - setattr(self.client, name, f) - return f - - @add_client_method - def generate_text( - request: protos.GenerateTextRequest, - **kwargs, - ) -> protos.GenerateTextResponse: - self.observed_requests.append(request) - return self.responses["generate_text"] - - @add_client_method - def embed_text( - request: protos.EmbedTextRequest, - **kwargs, - ) -> protos.EmbedTextResponse: - self.observed_requests.append(request) - return self.responses["embed_text"] - - @add_client_method - def batch_embed_text( - request: protos.EmbedTextRequest, - **kwargs, - ) -> protos.EmbedTextResponse: - self.observed_requests.append(request) - - return protos.BatchEmbedTextResponse( - embeddings=[protos.Embedding(value=[1, 2, 3])] * len(request.texts) - ) - - @add_client_method - def count_text_tokens( - request: protos.CountTextTokensRequest, - **kwargs, - ) -> protos.CountTextTokensResponse: - self.observed_requests.append(request) - return self.responses["count_text_tokens"] - - @add_client_method - def get_tuned_model(name) -> protos.TunedModel: - request = protos.GetTunedModelRequest(name=name) - self.observed_requests.append(request) - response = copy.copy(self.responses["get_tuned_model"]) - return response - - @parameterized.named_parameters( - [ - dict(testcase_name="string", prompt="Hello how are"), - ] - ) - def test_make_prompt(self, prompt): - x = text_service._make_text_prompt(prompt) - self.assertIsInstance(x, protos.TextPrompt) - self.assertEqual("Hello how are", x.text) - - @parameterized.named_parameters( - [ - dict(testcase_name="string", prompt="What are you"), - ] - ) - def test_make_generate_text_request(self, prompt): - x = text_service._make_generate_text_request(model="models/chat-bison-001", prompt=prompt) - self.assertEqual("models/chat-bison-001", x.model) - self.assertIsInstance(x, protos.GenerateTextRequest) - - @parameterized.named_parameters( - [ - dict( - testcase_name="basic_model", - model="models/chat-lamda-001", - text="What are you?", - ) - ] - ) - def test_generate_embeddings(self, model, text): - self.responses["embed_text"] = protos.EmbedTextResponse( - embedding=protos.Embedding(value=[1, 2, 3]) - ) - - emb = text_service.generate_embeddings(model=model, text=text) - - self.assertIsInstance(emb, dict) - self.assertEqual( - self.observed_requests[-1], protos.EmbedTextRequest(model=model, text=text) - ) - self.assertIsInstance(emb["embedding"][0], float) - - @parameterized.named_parameters( - [ - dict( - testcase_name="small-2", - model="models/chat-lamda-001", - text=["Who are you?", "Who am I?"], - ), - dict( - testcase_name="even-batch", - model="models/chat-lamda-001", - text=["Who are you?"] * 100, - ), - dict( - testcase_name="even-batch-plus-one", - model="models/chat-lamda-001", - text=["Who are you?"] * 101, - ), - dict( - testcase_name="odd-batch", - model="models/chat-lamda-001", - text=["Who are you?"] * 237, - ), - ] - ) - def test_generate_embeddings_batch(self, model, text): - emb = text_service.generate_embeddings(model=model, text=text) - - self.assertIsInstance(emb, dict) - - # Check first and last requests. - self.assertEqual(self.observed_requests[-1].model, model) - self.assertEqual(self.observed_requests[-1].texts[-1], text[-1]) - self.assertEqual(self.observed_requests[0].texts[0], text[0]) - - # Check that the list has the right length. - self.assertIsInstance(emb["embedding"][0], list) - self.assertLen(emb["embedding"], len(text)) - - # Check that the right number of requests were sent. - self.assertLen( - self.observed_requests, - math.ceil(len(text) / text_service.EMBEDDING_MAX_BATCH_SIZE), - ) - - @parameterized.named_parameters( - [ - dict(testcase_name="basic", prompt="Why did the chicken cross the"), - dict( - testcase_name="temperature", - prompt="Why did the chicken cross the", - temperature=0.75, - ), - dict( - testcase_name="stop_list", - prompt="Why did the chicken cross the", - stop_sequences=["a", "b", "c"], - ), - dict( - testcase_name="count", - prompt="Why did the chicken cross the", - candidate_count=2, - ), - ] - ) - def test_generate_response(self, *, prompt, **kwargs): - self.responses["generate_text"] = protos.GenerateTextResponse( - candidates=[ - protos.TextCompletion(output=" road?"), - protos.TextCompletion(output=" bridge?"), - protos.TextCompletion(output=" river?"), - ] - ) - - complete = text_service.generate_text(prompt=prompt, **kwargs) - - self.assertEqual( - self.observed_requests[-1], - protos.GenerateTextRequest( - model="models/text-bison-001", prompt=protos.TextPrompt(text=prompt), **kwargs - ), - ) - - self.assertIsInstance(complete.result, str) - - self.assertEqual( - complete.candidates, - [ - {"output": " road?", "safety_ratings": []}, - {"output": " bridge?", "safety_ratings": []}, - {"output": " river?", "safety_ratings": []}, - ], - ) - - def test_stop_string(self): - self.responses["generate_text"] = protos.GenerateTextResponse( - candidates=[ - protos.TextCompletion(output="Hello world?"), - protos.TextCompletion(output="Hell!"), - protos.TextCompletion(output="I'm going to stop"), - ] - ) - complete = text_service.generate_text(prompt="Hello", stop_sequences="stop") - - self.assertEqual( - self.observed_requests[-1], - protos.GenerateTextRequest( - model="models/text-bison-001", - prompt=protos.TextPrompt(text="Hello"), - stop_sequences=["stop"], - ), - ) - # Just make sure it made it into the request object. - self.assertEqual(self.observed_requests[-1].stop_sequences, ["stop"]) - - @parameterized.named_parameters( - [ - dict( - testcase_name="basic", - safety_settings=[ - { - "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, - }, - { - "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - }, - ], - ), - dict( - testcase_name="strings", - safety_settings=[ - { - "category": "medical", - "threshold": "block_none", - }, - { - "category": "violent", - "threshold": "low", - }, - ], - ), - dict( - testcase_name="flat", - safety_settings={"medical": "block_none", "sex": "low"}, - ), - dict( - testcase_name="mixed", - safety_settings={ - "medical": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1, - }, - ), - ] - ) - def test_safety_settings(self, safety_settings): - self.responses["generate_text"] = protos.GenerateTextResponse( - candidates=[ - protos.TextCompletion(output="No"), - ] - ) - # This test really just checks that the safety_settings get converted to a proto. - result = text_service.generate_text( - prompt="Say something wicked.", safety_settings=safety_settings - ) - - self.assertEqual( - self.observed_requests[-1].safety_settings[0].category, - palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - ) - - def test_filters(self): - self.responses["generate_text"] = protos.GenerateTextResponse( - candidates=[{"output": "hello"}], - filters=[ - { - "reason": palm_safety_types.BlockedReason.SAFETY, - "message": "not safe", - } - ], - ) - - response = text_service.generate_text(prompt="do filters work?") - self.assertIsInstance(response.filters[0]["reason"], palm_safety_types.BlockedReason) - self.assertEqual(response.filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) - - def test_safety_feedback(self): - self.responses["generate_text"] = protos.GenerateTextResponse( - candidates=[{"output": "hello"}], - safety_feedback=[ - { - "rating": { - "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": palm_safety_types.HarmProbability.HIGH, - }, - "setting": { - "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, - }, - } - ], - ) - - response = text_service.generate_text(prompt="does safety feedback work?") - self.assertIsInstance( - response.safety_feedback[0]["rating"]["probability"], - palm_safety_types.HarmProbability, - ) - self.assertEqual( - response.safety_feedback[0]["rating"]["probability"], - palm_safety_types.HarmProbability.HIGH, - ) - - self.assertIsInstance( - response.safety_feedback[0]["setting"]["category"], - protos.HarmCategory, - ) - self.assertEqual( - response.safety_feedback[0]["setting"]["category"], - palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - ) - - def test_candidate_safety_feedback(self): - self.responses["generate_text"] = protos.GenerateTextResponse( - candidates=[ - { - "output": "hello", - "safety_ratings": [ - { - "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": palm_safety_types.HarmProbability.HIGH, - }, - { - "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "probability": palm_safety_types.HarmProbability.LOW, - }, - ], - } - ] - ) - - result = text_service.generate_text(prompt="Write a story from the ER.") - self.assertIsInstance( - result.candidates[0]["safety_ratings"][0]["category"], - protos.HarmCategory, - ) - self.assertEqual( - result.candidates[0]["safety_ratings"][0]["category"], - palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - ) - - self.assertIsInstance( - result.candidates[0]["safety_ratings"][0]["probability"], - palm_safety_types.HarmProbability, - ) - self.assertEqual( - result.candidates[0]["safety_ratings"][0]["probability"], - palm_safety_types.HarmProbability.HIGH, - ) - - def test_candidate_citations(self): - self.responses["generate_text"] = protos.GenerateTextResponse( - candidates=[ - { - "output": "Hello Google!", - "citation_metadata": { - "citation_sources": [ - { - "start_index": 6, - "end_index": 12, - "uri": "https://google.com", - } - ] - }, - } - ] - ) - result = text_service.generate_text(prompt="Hi my name is Google") - self.assertEqual( - result.candidates[0]["citation_metadata"]["citation_sources"][0]["start_index"], - 6, - ) - - @parameterized.named_parameters( - [ - dict(testcase_name="base-name", model="models/text-bison-001"), - dict(testcase_name="tuned-name", model="tunedModels/bipedal-pangolin-001"), - dict( - testcase_name="model", - model=model_types.Model( - name="models/text-bison-001", - base_model_id="text-bison-001", - version="001", - display_name="🦬", - description="🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬🦬", - input_token_limit=8000, - output_token_limit=4000, - supported_generation_methods=["GenerateText"], - ), - ), - dict( - testcase_name="tuned_model", - model=model_types.TunedModel( - name="tunedModels/bipedal-pangolin-001", - base_model="models/text-bison-001", - ), - ), - dict( - testcase_name="protos.model", - model=protos.Model( - name="models/text-bison-001", - ), - ), - dict( - testcase_name="protos.tuned_model", - model=protos.TunedModel( - name="tunedModels/bipedal-pangolin-001", - base_model="models/text-bison-001", - ), - ), - dict( - testcase_name="protos.tuned_model_nested", - model=protos.TunedModel( - name="tunedModels/bipedal-pangolin-002", - tuned_model_source={ - "tuned_model": "tunedModels/bipedal-pangolin-002", - "base_model": "models/text-bison-001", - }, - ), - ), - ] - ) - def test_count_message_tokens(self, model): - self.responses["get_tuned_model"] = protos.TunedModel( - name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001" - ) - self.responses["count_text_tokens"] = protos.CountTextTokensResponse(token_count=7) - - response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.") - self.assertEqual({"token_count": 7}, response) - - should_look_up_model = isinstance(model, str) and model.startswith("tunedModels/") - if should_look_up_model: - self.assertLen(self.observed_requests, 2) - self.assertEqual( - self.observed_requests[0], - protos.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), - ) - - def test_count_text_tokens_called_with_request_options(self): - self.client.count_text_tokens = unittest.mock.MagicMock() - request = unittest.mock.ANY - request_options = {"timeout": 120} - - try: - result = text_service.count_text_tokens( - model="models/", - prompt="", - request_options=request_options, - ) - except AttributeError: - pass - - self.client.count_text_tokens.assert_called_once_with(request, **request_options) - - def test_batch_embed_text_called_with_request_options(self): - self.client.batch_embed_text = unittest.mock.MagicMock() - request = unittest.mock.ANY - request_options = {"timeout": 120} - - try: - result = text_service.generate_embeddings( - model="models/", - text=["first", "second"], - request_options=request_options, - ) - except AttributeError: - pass - - self.client.batch_embed_text.assert_called_once_with(request, **request_options) - - def test_embed_text_called_with_request_options(self): - self.client.embed_text = unittest.mock.MagicMock() - request = unittest.mock.ANY - request_options = {"timeout": 120} - - try: - result = text_service.generate_embeddings( - model="models/", - text="", - request_options=request_options, - ) - except AttributeError: - pass - - self.client.embed_text.assert_called_once_with(request, **request_options) - - def test_generate_text_called_with_request_options(self): - self.client.generate_text = unittest.mock.MagicMock() - request = unittest.mock.ANY - request_options = {"timeout": 120} - - try: - result = text_service.generate_text(prompt="", request_options=request_options) - except AttributeError: - pass - - self.client.generate_text.assert_called_once_with(request, **request_options) - - -if __name__ == "__main__": - absltest.main()