Skip to content

Commit

Permalink
[card-components] address comments on Image class
Browse files Browse the repository at this point in the history
- Image.update supports multi-type object
- remove `disable_updates` from public api
  • Loading branch information
valayDave committed Dec 8, 2023
1 parent cd067ec commit 31f714b
Showing 1 changed file with 120 additions and 51 deletions.
171 changes: 120 additions & 51 deletions metaflow/plugins/cards/card_modules/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ class Table(UserComponent):
List (rows) of lists (columns). Each item can be a string or a `MetaflowCardComponent`.
headers : List[str], optional
Optional header row for the table.
disable_updates: bool, optional
A boolean value to disable realtime updates for all components within the table. Default: False
"""

REALTIME_UPDATABLE = True
Expand Down Expand Up @@ -290,12 +288,44 @@ class Image(UserComponent):
The image data in `bytes`.
label : str
Optional label for the image.
disable_updates: bool
Disable realtime updates for the image. Default: True
"""

REALTIME_UPDATABLE = True

_PIL_IMAGE_MODULE_PATH = "PIL.Image.Image"

_MATPLOTLIB_FIGURE_MODULE_PATH = "matplotlib.figure.Figure"

_PLT_MODULE = None

_PIL_MODULE = None

@classmethod
def _get_pil_module(cls):
if cls._PIL_MODULE == "NOT_PRESENT":
return None
if cls._PIL_MODULE is None:
try:
import PIL
except ImportError:
cls._PIL_MODULE = "NOT_PRESENT"
return None
cls._PIL_MODULE = PIL
return cls._PIL_MODULE

@classmethod
def _get_plt_module(cls):
if cls._PLT_MODULE == "NOT_PRESENT":
return None
if cls._PLT_MODULE is None:
try:
import matplotlib.pyplot as pyplt
except ImportError:
cls._PLT_MODULE = "NOT_PRESENT"
return None
cls._PLT_MODULE = pyplt
return cls._PLT_MODULE

@staticmethod
def render_fail_headline(msg):
return "[IMAGE_RENDER FAIL]: %s" % msg
Expand Down Expand Up @@ -346,22 +376,78 @@ def __init__(self, src=None, label=None, disable_updates: bool = True):
self.REALTIME_UPDATABLE = False
self._set_image_src(src, label=label)

def _update_image(self, img_obj, label=None):
task_to_dict = TaskToDict()
parsed_image, err_comp = None, None

# First set image for bytes/string type
if task_to_dict.object_type(img_obj) in ["bytes", "str"]:
self._set_image_src(img_obj, label=label)
return

if task_to_dict.object_type(img_obj).startswith("PIL"):
parsed_image, err_comp = self._parse_pil_image(img_obj)
elif _full_classname(img_obj) == self._MATPLOTLIB_FIGURE_MODULE_PATH:
parsed_image, err_comp = self._parse_matplotlib(img_obj)
else:
parsed_image, err_comp = None, ErrorComponent(
self.render_fail_headline(
"Invalid Type. Object %s is not supported. Supported types: %s"
% (
type(img_obj),
", ".join(
[
"str",
"bytes",
self._PIL_IMAGE_MODULE_PATH,
self._MATPLOTLIB_FIGURE_MODULE_PATH,
]
),
)
),
"",
)

if parsed_image is not None:
self._set_image_src(parsed_image, label=label)
else:
self._set_image_src(None, label=label)
self._error_comp = err_comp

@classmethod
def _pil_parsing_error(cls, error_type):
return None, ErrorComponent(
cls.render_fail_headline(
"first argument for `Image` should be of type %s"
% cls._PIL_IMAGE_MODULE_PATH
),
"Type of %s is invalid. Type of %s required"
% (error_type, cls._PIL_IMAGE_MODULE_PATH),
)

@classmethod
def _parse_pil_image(cls, pilimage):
parsed_value = None
error_component = None
import io

PIL_IMAGE_PATH = "PIL.Image.Image"
task_to_dict = TaskToDict()
if task_to_dict.object_type(pilimage) != PIL_IMAGE_PATH:
_img_type = task_to_dict.object_type(pilimage)

if not _img_type.startswith("PIL"):
return cls._pil_parsing_error(_img_type)

# Set the module as a part of the class so that
# we don't keep reloading the module everytime
pil_module = cls._get_pil_module()

if pil_module is None:
return parsed_value, ErrorComponent(
cls.render_fail_headline(
"first argument for `Image` should be of type %s" % PIL_IMAGE_PATH
),
"Type of %s is invalid. Type of %s required"
% (task_to_dict.object_type(pilimage), PIL_IMAGE_PATH),
cls.render_fail_headline("PIL cannot be imported"), ""
)
if not isinstance(pilimage, pil_module.Image.Image):
return cls._pil_parsing_error(_img_type)

img_byte_arr = io.BytesIO()
try:
pilimage.save(img_byte_arr, format="PNG")
Expand All @@ -380,16 +466,15 @@ def _parse_matplotlib(cls, plot):

parsed_value = None
error_component = None
try:
import matplotlib.pyplot as pyplt
except ImportError:
pyplt = cls._get_plt_module()
if pyplt is None:
return parsed_value, ErrorComponent(
cls.render_fail_headline("Matplotlib cannot be imported"),
"%s" % traceback.format_exc(),
)
# First check if it is a valid Matplotlib figure.
figure = None
if _full_classname(plot) == "matplotlib.figure.Figure":
if _full_classname(plot) == cls._MATPLOTLIB_FIGURE_MODULE_PATH:
figure = plot

# If it is not valid figure then check if it is matplotlib.axes.Axes or a matplotlib.axes._subplots.AxesSubplot
Expand Down Expand Up @@ -427,7 +512,9 @@ def _bytes_to_base64(bytes_arr):
return parsed_image

@classmethod
def from_pil_image(cls, pilimage, label: Optional[str] = None):
def from_pil_image(
cls, pilimage, label: Optional[str] = None, disable_updates: bool = False
):
"""
Create an `Image` from a PIL image.
Expand All @@ -441,23 +528,27 @@ def from_pil_image(cls, pilimage, label: Optional[str] = None):
try:
parsed_image, error_comp = cls._parse_pil_image(pilimage)
if parsed_image is not None:
img = cls(src=parsed_image, label=label)
img = cls(
src=parsed_image, label=label, disable_updates=disable_updates
)
else:
img = cls(src=None, label=label)
img = cls(src=None, label=label, disable_updates=disable_updates)
img._error_comp = error_comp
return img
except:
import traceback

img = cls(src=None, label=label)
img = cls(src=None, label=label, disable_updates=disable_updates)
img._error_comp = ErrorComponent(
cls.render_fail_headline("PIL Image Not Parsable"),
"%s" % traceback.format_exc(),
)
return img

@classmethod
def from_matplotlib(cls, plot, label: Optional[str] = None):
def from_matplotlib(
cls, plot, label: Optional[str] = None, disable_updates: bool = False
):
"""
Create an `Image` from a Matplotlib plot.
Expand All @@ -471,15 +562,17 @@ def from_matplotlib(cls, plot, label: Optional[str] = None):
try:
parsed_image, error_comp = cls._parse_matplotlib(plot)
if parsed_image is not None:
img = cls(src=parsed_image, label=label)
img = cls(
src=parsed_image, label=label, disable_updates=disable_updates
)
else:
img = cls(src=None, label=label)
img = cls(src=None, label=label, disable_updates=disable_updates)
img._error_comp = error_comp
return img
except:
import traceback

img = cls(src=None, label=label)
img = cls(src=None, label=label, disable_updates=disable_updates)
img._error_comp = ErrorComponent(
cls.render_fail_headline("Matplotlib plot's image is not parsable"),
"%s" % traceback.format_exc(),
Expand All @@ -500,20 +593,14 @@ def render(self):
self.render_fail_headline("`Image` Component `src` argument is `None`"), ""
).render()

def update(self, pilimage=None, plot=None, bytes=None, string=None, label=None):
def update(self, image, label=None):
"""
Update the image.
Parameters
----------
pilimage : PIL.Image, optional
a PIL image object.
plot : matplotlib.figure.Figure or matplotlib.axes.Axes or matplotlib.axes._subplots.AxesSubplot, optional
a PIL axes (plot) object.
bytes : bytes, optional
The image data in `bytes`.
string : str, optional
The image data in base64 string.
image : PIL.Image or matplotlib.figure.Figure or matplotlib.axes.Axes or matplotlib.axes._subplots.AxesSubplot or bytes or str
The updated image object
label : str, optional
Optional label for the image.
"""
Expand All @@ -526,25 +613,7 @@ def update(self, pilimage=None, plot=None, bytes=None, string=None, label=None):
return

_label = label if label is not None else self._label
if bytes is not None:
self._set_image_src(bytes, label=_label)
return
elif string is not None:
self._set_image_src(string, label=_label)
return

parsed_image = None
err_comp = None
if pilimage is not None:
parsed_image, err_comp = self._parse_pil_image(pilimage)
elif plot is not None:
parsed_image, err_comp = self._parse_matplotlib(plot)

if parsed_image is not None:
self._set_image_src(parsed_image, label=_label)
else:
self._set_image_src(None, label=_label)
self._error_comp = err_comp
self._update_image(image, label=_label)


class Error(UserComponent):
Expand Down

0 comments on commit 31f714b

Please sign in to comment.