Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image feedback #608

Merged
merged 12 commits into from
Nov 1, 2024
Merged
95 changes: 85 additions & 10 deletions google/generativeai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import os
import contextlib
import inspect
import collections
import dataclasses
import pathlib
from typing import Any, cast
from collections.abc import Sequence
from collections.abc import Sequence, Mapping
import httplib2
from io import IOBase

Expand All @@ -23,6 +24,11 @@
import googleapiclient.http
import googleapiclient.discovery

from google.protobuf import struct_pb2

from proto.marshal.collections import maps
from proto.marshal.collections import repeated

try:
from google.generativeai import version

Expand Down Expand Up @@ -130,6 +136,70 @@ async def create_file(self, *args, **kwargs):
)


# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
def to_value(value) -> struct_pb2.Value:
"""Return a protobuf Value object representing this value."""
if isinstance(value, struct_pb2.Value):
return value
if value is None:
return struct_pb2.Value(null_value=0)
if isinstance(value, bool):
return struct_pb2.Value(bool_value=value)
if isinstance(value, (int, float)):
return struct_pb2.Value(number_value=float(value))
if isinstance(value, str):
return struct_pb2.Value(string_value=value)
if isinstance(value, collections.abc.Sequence):
return struct_pb2.Value(list_value=to_list_value(value))
if isinstance(value, collections.abc.Mapping):
return struct_pb2.Value(struct_value=to_mapping_value(value))
raise ValueError("Unable to coerce value: %r" % value)


def to_list_value(value) -> struct_pb2.ListValue:
# We got a proto, or else something we sent originally.
# Preserve the instance we have.
if isinstance(value, struct_pb2.ListValue):
return value
if isinstance(value, repeated.RepeatedComposite):
return struct_pb2.ListValue(values=[v for v in value.pb])

# We got a list (or something list-like); convert it.
return struct_pb2.ListValue(values=[to_value(v) for v in value])


def to_mapping_value(value) -> struct_pb2.Struct:
# We got a proto, or else something we sent originally.
# Preserve the instance we have.
if isinstance(value, struct_pb2.Struct):
return value
if isinstance(value, maps.MapComposite):
return struct_pb2.Struct(
fields={k: v for k, v in value.pb.items()},
)

# We got a dict (or something dict-like); convert it.
return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()})


class PredictionServiceClient(glm.PredictionServiceClient):
def predict(self, model=None, instances=None, parameters=None):
pr = protos.PredictRequest.pb()
request = pr(
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
)
return super().predict(request)


class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient):
async def predict(self, model=None, instances=None, parameters=None):
pr = protos.PredictRequest.pb()
request = pr(
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
)
return await super().predict(request)


@dataclasses.dataclass
class _ClientManager:
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
Expand Down Expand Up @@ -220,15 +290,20 @@ def configure(
self.clients = {}

def make_client(self, name):
if name == "file":
cls = FileServiceClient
elif name == "file_async":
cls = FileServiceAsyncClient
elif name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")
local_clients = {
"file": FileServiceClient,
"file_async": FileServiceAsyncClient,
"prediction": PredictionServiceClient,
"prediction_async": PredictionServiceAsyncClient,
}
cls = local_clients.get(name, None)

if cls is None:
if name.endswith("_async"):
name = name.split("_")[0]
cls = getattr(glm, name.title() + "ServiceAsyncClient")
else:
cls = getattr(glm, name.title() + "ServiceClient")

# Attempt to configure using defaults.
if not self.client_config:
Expand Down
98 changes: 4 additions & 94 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,16 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
import io
import inspect
import mimetypes
import pathlib
import typing
from typing import Any, Callable, Union
from typing_extensions import TypedDict

import pydantic

from google.generativeai.types import file_types
from google.generativeai.types.image_types import _image_types
from google.generativeai import protos

if typing.TYPE_CHECKING:
import PIL.Image
import PIL.ImageFile
import IPython.display

IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
ImageType = PIL.Image.Image | IPython.display.Image
else:
IMAGE_TYPES = ()
try:
import PIL.Image
import PIL.ImageFile

IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
except ImportError:
PIL = None

try:
import IPython.display

IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
except ImportError:
IPython = None

ImageType = Union["PIL.Image.Image", "IPython.display.Image"]


__all__ = [
"BlobDict",
Expand Down Expand Up @@ -97,62 +68,6 @@ def to_mode(x: ModeOptions) -> Mode:
return _MODE[x]


def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
# If the image is a local file, return a file-based blob without any modification.
# Otherwise, return a lossless WebP blob (same quality with optimized size).
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
return None
filename = str(image.filename)
if not pathlib.Path(filename).is_file():
return None

mime_type = image.get_format_mimetype()
image_bytes = pathlib.Path(filename).read_bytes()

return protos.Blob(mime_type=mime_type, data=image_bytes)

def webp_blob(image: PIL.Image.Image) -> protos.Blob:
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
image_io = io.BytesIO()
image.save(image_io, format="webp", lossless=True)
image_io.seek(0)

mime_type = "image/webp"
image_bytes = image_io.read()

return protos.Blob(mime_type=mime_type, data=image_bytes)

return file_blob(image) or webp_blob(image)


def image_to_blob(image: ImageType) -> protos.Blob:
if PIL is not None:
if isinstance(image, PIL.Image.Image):
return _pil_to_blob(image)

if IPython is not None:
if isinstance(image, IPython.display.Image):
name = image.filename
if name is None:
raise ValueError(
"Conversion failed. The `IPython.display.Image` can only be converted if "
"it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
)
mime_type, _ = mimetypes.guess_type(name)
if mime_type is None:
mime_type = "image/unknown"

return protos.Blob(mime_type=mime_type, data=image.data)

raise TypeError(
"Image conversion failed. The input was expected to be of type `Image` "
"(either `PIL.Image.Image` or `IPython.display.Image`).\n"
f"However, received an object of type: {type(image)}.\n"
f"Object Value: {image}"
)


class BlobDict(TypedDict):
mime_type: str
data: bytes
Expand Down Expand Up @@ -189,12 +104,7 @@ def is_blob_dict(d):
return "mime_type" in d and "data" in d


if typing.TYPE_CHECKING:
BlobType = Union[
protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
] # Any for the images
else:
BlobType = Union[protos.Blob, BlobDict, Any]
BlobType = Union[protos.Blob, BlobDict, _image_types.ImageType] # Any for the images


def to_blob(blob: BlobType) -> protos.Blob:
Expand All @@ -203,8 +113,8 @@ def to_blob(blob: BlobType) -> protos.Blob:

if isinstance(blob, protos.Blob):
return blob
elif isinstance(blob, IMAGE_TYPES):
return image_to_blob(blob)
elif isinstance(blob, _image_types.IMAGE_TYPES):
return _image_types.image_to_blob(blob)
else:
if isinstance(blob, Mapping):
raise KeyError(
Expand Down
1 change: 1 addition & 0 deletions google/generativeai/types/image_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from google.generativeai.types.image_types._image_types import *
Loading
Loading