From 6c8dad12643dc19e80f7d1943125984ae845c6d8 Mon Sep 17 00:00:00 2001 From: Laurent Picard Date: Tue, 24 Sep 2024 20:29:53 +0200 Subject: [PATCH] fix: preserve quality and optimize transfer of prompt images (#570) * fix: preserve quality and optimize transfer of prompt images * Move numpy-images to their own test case. Change-Id: Ie6b02c7647487c1df9d4e70e9b8eed70dc8b8fe3 * Format with black Change-Id: I04550a89eed9bb21c0a8f6f9b6ab76b8b0f41270 --------- Co-authored-by: Mark Daoust --- google/generativeai/types/content_types.py | 68 ++++++++++------------ tests/test_content.py | 15 ++++- 2 files changed, 43 insertions(+), 40 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index b925967c8..531999f55 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -19,6 +19,7 @@ import io import inspect import mimetypes +import pathlib import typing from typing import Any, Callable, Union from typing_extensions import TypedDict @@ -30,7 +31,7 @@ if typing.TYPE_CHECKING: import PIL.Image - import PIL.PngImagePlugin + import PIL.ImageFile import IPython.display IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image) @@ -38,7 +39,7 @@ IMAGE_TYPES = () try: import PIL.Image - import PIL.PngImagePlugin + import PIL.ImageFile IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,) except ImportError: @@ -72,46 +73,39 @@ ] -def pil_to_blob(img): - # When you load an image with PIL you get a subclass of PIL.Image - # The subclass knows what file type it was loaded from it has a `.format` class attribute - # and the `get_format_mimetype` method. Convert these back to the same file type. - # - # The base image class doesn't know its file type, it just knows its mode. - # RGBA converts to PNG easily, P[allet] converts to GIF, RGB to GIF. - # But for anything else I'm not going to bother mapping it out (for now) let's just convert to RGB and send it. - # - # References: - # - file formats: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html - # - image modes: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes - - bytesio = io.BytesIO() - - get_mime = getattr(img, "get_format_mimetype", None) - if get_mime is not None: - # If the image is created from a file, convert back to the same file type. - img.save(bytesio, format=img.format) - mime_type = img.get_format_mimetype() - elif img.mode == "RGBA": - img.save(bytesio, format="PNG") - mime_type = "image/png" - elif img.mode == "P": - img.save(bytesio, format="GIF") - mime_type = "image/gif" - else: - if img.mode != "RGB": - img = img.convert("RGB") - img.save(bytesio, format="JPEG") - mime_type = "image/jpeg" - bytesio.seek(0) - data = bytesio.read() - return protos.Blob(mime_type=mime_type, data=data) +def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob: + # If the image is a local file, return a file-based blob without any modification. + # Otherwise, return a lossless WebP blob (same quality with optimized size). + def file_blob(image: PIL.Image.Image) -> protos.Blob | None: + if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None: + return None + filename = str(image.filename) + if not pathlib.Path(filename).is_file(): + return None + + mime_type = image.get_format_mimetype() + image_bytes = pathlib.Path(filename).read_bytes() + + return protos.Blob(mime_type=mime_type, data=image_bytes) + + def webp_blob(image: PIL.Image.Image) -> protos.Blob: + # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp + image_io = io.BytesIO() + image.save(image_io, format="webp", lossless=True) + image_io.seek(0) + + mime_type = "image/webp" + image_bytes = image_io.read() + + return protos.Blob(mime_type=mime_type, data=image_bytes) + + return file_blob(image) or webp_blob(image) def image_to_blob(image) -> protos.Blob: if PIL is not None: if isinstance(image, PIL.Image.Image): - return pil_to_blob(image) + return _pil_to_blob(image) if IPython is not None: if isinstance(image, IPython.display.Image): diff --git a/tests/test_content.py b/tests/test_content.py index dc62e997b..52e78f349 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -83,9 +83,20 @@ class HasEnum: class UnitTests(parameterized.TestCase): + @parameterized.named_parameters( - ["PIL", PIL.Image.open(TEST_PNG_PATH)], ["RGBA", PIL.Image.fromarray(np.zeros([6, 6, 4], dtype=np.uint8))], + ["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))], + ["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")], + ) + def test_numpy_to_blob(self, image): + blob = content_types.image_to_blob(image) + self.assertIsInstance(blob, protos.Blob) + self.assertEqual(blob.mime_type, "image/webp") + self.assertStartsWith(blob.data, b"RIFF \x00\x00\x00WEBPVP8L") + + @parameterized.named_parameters( + ["PIL", PIL.Image.open(TEST_PNG_PATH)], ["IPython", IPython.display.Image(filename=TEST_PNG_PATH)], ) def test_png_to_blob(self, image): @@ -96,7 +107,6 @@ def test_png_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_JPG_PATH)], - ["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))], ["IPython", IPython.display.Image(filename=TEST_JPG_PATH)], ) def test_jpg_to_blob(self, image): @@ -107,7 +117,6 @@ def test_jpg_to_blob(self, image): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_GIF_PATH)], - ["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")], ["IPython", IPython.display.Image(filename=TEST_GIF_PATH)], ) def test_gif_to_blob(self, image):