diff --git a/metaflow/cards.py b/metaflow/cards.py index edebaf6896..8a8556fb83 100644 --- a/metaflow/cards.py +++ b/metaflow/cards.py @@ -6,6 +6,8 @@ Image, Error, Markdown, + VegaChart, + ProgressBar, ) from metaflow.plugins.cards.card_modules.basic import ( DefaultCard, diff --git a/metaflow/plugins/cards/card_cli.py b/metaflow/plugins/cards/card_cli.py index de52546a29..2084a36b0e 100644 --- a/metaflow/plugins/cards/card_cli.py +++ b/metaflow/plugins/cards/card_cli.py @@ -1,6 +1,11 @@ from metaflow.client import Task from metaflow import JSONType, namespace -from metaflow.exception import CommandException +from metaflow.util import resolve_identity +from metaflow.exception import ( + CommandException, + MetaflowNotFound, + MetaflowNamespaceMismatch, +) import webbrowser import re from metaflow._vendor import click @@ -945,3 +950,125 @@ def list( show_list_as_json=as_json, file=file, ) + + +@card.command(help="Run local card viewer server") +@click.option( + "--run-id", + default=None, + show_default=True, + type=str, + help="Run ID of the flow", +) +@click.option( + "--port", + default=8324, + show_default=True, + type=int, + help="Port on which Metaflow card viewer server will run", +) +@click.option( + "--namespace", + "user_namespace", + default=None, + show_default=True, + type=str, + help="Namespace of the flow", +) +@click.option( + "--poll-interval", + default=5, + show_default=True, + type=int, + help="Polling interval of the card viewer server.", +) +@click.option( + "--max-cards", + default=30, + show_default=True, + type=int, + help="Maximum number of cards to be shown at any time by the card viewer server", +) +@click.pass_context +def server(ctx, run_id, port, user_namespace, poll_interval, max_cards): + from .card_server import create_card_server, CardServerOptions + + user_namespace = resolve_identity() if user_namespace is None else user_namespace + run, follow_new_runs, _status_message = _get_run_object( + ctx.obj, run_id, user_namespace + ) + if _status_message is not None: + ctx.obj.echo(_status_message, fg="red") + options = CardServerOptions( + flow_name=ctx.obj.flow.name, + run_object=run, + only_running=False, + follow_resumed=False, + flow_datastore=ctx.obj.flow_datastore, + max_cards=max_cards, + follow_new_runs=follow_new_runs, + poll_interval=poll_interval, + ) + create_card_server(options, port, ctx.obj) + + +def _get_run_from_cli_set_runid(obj, run_id): + # This run-id will be set from the command line args. + # So if we hit a MetaflowNotFound exception / Namespace mismatch then + # we should raise an exception + from metaflow import Run + + flow_name = obj.flow.name + if len(run_id.split("/")) > 1: + raise CommandException( + "run_id should NOT be of the form: `/`. Please provide only run-id" + ) + try: + pathspec = "%s/%s" % (flow_name, run_id) + # Since we are looking at all namespaces, + # we will not + namespace(None) + return Run(pathspec) + except MetaflowNotFound: + raise CommandException("No run (%s) found for *%s*." % (run_id, flow_name)) + + +def _get_run_object(obj, run_id, user_namespace): + from metaflow import Flow + + follow_new_runs = True + flow_name = obj.flow.name + + if run_id is not None: + follow_new_runs = False + run = _get_run_from_cli_set_runid(obj, run_id) + obj.echo("Using run-id %s" % run.pathspec, fg="blue", bold=False) + return run, follow_new_runs, None + + _msg = "Searching for runs in namespace: %s" % user_namespace + obj.echo(_msg, fg="blue", bold=False) + + try: + namespace(user_namespace) + flow = Flow(pathspec=flow_name) + run = flow.latest_run + except MetaflowNotFound: + # When we have no runs found for the Flow, we need to ensure that + # if the `follow_new_runs` is set to True; If `follow_new_runs` is set to True then + # we don't raise the Exception and instead we return None and let the + # background Thread wait on the Retrieving the run object. + _status_msg = "No run found for *%s*." % flow_name + return None, follow_new_runs, _status_msg + + except MetaflowNamespaceMismatch: + _status_msg = ( + "No run found for *%s* in namespace *%s*. You can switch the namespace using --namespace" + % ( + flow_name, + user_namespace, + ) + ) + return None, follow_new_runs, _status_msg + + obj.echo("Using run-id %s" % run.pathspec, fg="blue", bold=False) + return run, follow_new_runs, None diff --git a/metaflow/plugins/cards/card_creator.py b/metaflow/plugins/cards/card_creator.py index 04262dd97f..3da252af0c 100644 --- a/metaflow/plugins/cards/card_creator.py +++ b/metaflow/plugins/cards/card_creator.py @@ -57,8 +57,11 @@ def create( logger=None, mode="render", final=False, + sync=False, ): - # warning_message("calling proc for uuid %s" % self._card_uuid, self._logger) + # Setting `final` will affect the Reload token set during the card refresh + # data creation along with synchronous execution of subprocess. + # Setting `sync` will only cause synchronous execution of subprocess. if mode != "render" and not runtime_card: # silently ignore runtime updates for cards that don't support them return @@ -68,7 +71,6 @@ def create( component_strings = [] else: component_strings = current.card._serialize_components(card_uuid) - data = current.card._get_latest_data(card_uuid, final=final, mode=mode) runspec = "/".join([current.run_id, current.step_name, current.task_id]) self._run_cards_subprocess( @@ -82,6 +84,7 @@ def create( logger, data, final=final, + sync=sync, ) def _run_cards_subprocess( @@ -96,9 +99,10 @@ def _run_cards_subprocess( logger, data=None, final=False, + sync=False, ): components_file = data_file = None - wait = final + wait = final or sync if len(component_strings) > 0: # note that we can't delete temporary files here when calling the subprocess diff --git a/metaflow/plugins/cards/card_modules/base.html b/metaflow/plugins/cards/card_modules/base.html index 91c05ca1f4..76ff0b87d2 100644 --- a/metaflow/plugins/cards/card_modules/base.html +++ b/metaflow/plugins/cards/card_modules/base.html @@ -25,10 +25,21 @@ window.__MF_DATA__ = {}; } window.__MF_DATA__["{{{card_data_id}}}"] = "{{{task_data}}}" + + {{#RENDER_COMPLETE}} + + {{/RENDER_COMPLETE}} + {{^RENDER_COMPLETE}} + + {{/RENDER_COMPLETE}} diff --git a/metaflow/plugins/cards/card_modules/basic.py b/metaflow/plugins/cards/card_modules/basic.py index b078a8f82d..a1d992145d 100644 --- a/metaflow/plugins/cards/card_modules/basic.py +++ b/metaflow/plugins/cards/card_modules/basic.py @@ -146,6 +146,8 @@ def render(self): label=self._label, ) datadict.update(img_dict) + if self.component_id is not None: + datadict["id"] = self.component_id return datadict @@ -194,6 +196,8 @@ def render(self): datadict["columns"] = self._headers datadict["data"] = self._data datadict["vertical"] = self._vertical + if self.component_id is not None: + datadict["id"] = self.component_id return datadict @@ -295,6 +299,8 @@ def __init__(self, title=None, subtitle=None, data={}): def render(self): datadict = super().render() datadict["data"] = self._data + if self.component_id is not None: + datadict["id"] = self.component_id return datadict @@ -308,6 +314,8 @@ def __init__(self, text=None): def render(self): datadict = super().render() datadict["source"] = self._text + if self.component_id is not None: + datadict["id"] = self.component_id return datadict @@ -319,7 +327,13 @@ class TaskInfoComponent(MetaflowCardComponent): """ def __init__( - self, task, page_title="Task Info", only_repr=True, graph=None, components=[] + self, + task, + page_title="Task Info", + only_repr=True, + graph=None, + components=[], + runtime=False, ): self._task = task self._only_repr = only_repr @@ -328,6 +342,7 @@ def __init__( self._page_title = page_title self.final_component = None self.page_component = None + self.runtime = runtime def render(self): """ @@ -340,7 +355,8 @@ def render(self): self._task, graph=self._graph ) # ignore the name as an artifact - del task_data_dict["data"]["name"] + if "name" in task_data_dict["data"]: + del task_data_dict["data"]["name"] _metadata = dict(version=1, template="defaultCardTemplate") # try to parse out metaflow version from tags, but let it go if unset @@ -370,11 +386,12 @@ def render(self): "Task Created On": task_data_dict["created_at"], "Task Finished On": task_data_dict["finished_at"], # Remove Microseconds from timedelta - "Task Duration": str(self._task.finished_at - self._task.created_at).split( - "." - )[0], "Tags": ", ".join(tags), } + if not self.runtime: + task_metadata_dict["Task Duration"] = str( + self._task.finished_at - self._task.created_at + ).split(".")[0] if len(user_info) > 0: task_metadata_dict["User"] = user_info[0].split("user:")[1] @@ -434,8 +451,12 @@ def render(self): p.id for p in self._task.parent.parent["_parameters"].task if p.id != "name" ] if len(param_ids) > 0: + # Extract parameter from the Parameter Task. That is less brittle. + parameter_data = TaskToDict( + only_repr=self._only_repr, runtime=self.runtime + )(self._task.parent.parent["_parameters"].task, graph=self._graph) param_component = ArtifactsComponent( - data=[task_data_dict["data"][pid] for pid in param_ids] + data=[parameter_data["data"][pid] for pid in param_ids] ) else: param_component = TitleComponent(text="No Parameters") @@ -580,6 +601,10 @@ class DefaultCard(MetaflowCard): ALLOW_USER_COMPONENTS = True + RUNTIME_UPDATABLE = True + + RELOAD_POLICY = MetaflowCard.RELOAD_POLICY_ONCHANGE + type = "default" def __init__(self, options=dict(only_repr=True), components=[], graph=None): @@ -589,7 +614,7 @@ def __init__(self, options=dict(only_repr=True), components=[], graph=None): self._only_repr = options["only_repr"] self._components = components - def render(self, task): + def render(self, task, runtime=False): RENDER_TEMPLATE = read_file(RENDER_TEMPLATE_PATH) JS_DATA = read_file(JS_PATH) CSS_DATA = read_file(CSS_PATH) @@ -598,6 +623,7 @@ def render(self, task): only_repr=self._only_repr, graph=self._graph, components=self._components, + runtime=runtime, ).render() pt = self._get_mustache() data_dict = dict( @@ -608,14 +634,36 @@ def render(self, task): title=task.pathspec, css=CSS_DATA, card_data_id=uuid.uuid4(), + RENDER_COMPLETE=not runtime, ) return pt.render(RENDER_TEMPLATE, data_dict) + def render_runtime(self, task, data): + return self.render(task, runtime=True) + + def refresh(self, task, data): + return data["components"] + + def reload_content_token(self, task, data): + """ + The reload token will change when the component array has changed in the Metaflow card. + The change in the component array is signified by the change in the component_update_ts. + """ + if task.finished: + return "final" + # `component_update_ts` will never be None. It is set to a default value when the `ComponentStore` is instantiated + # And it is updated when components added / removed / changed from the `ComponentStore`. + return "runtime-%s" % (str(data["component_update_ts"])) + class BlankCard(MetaflowCard): ALLOW_USER_COMPONENTS = True + RUNTIME_UPDATABLE = True + + RELOAD_POLICY = MetaflowCard.RELOAD_POLICY_ONCHANGE + type = "blank" def __init__(self, options=dict(title=""), components=[], graph=None): @@ -625,7 +673,7 @@ def __init__(self, options=dict(title=""), components=[], graph=None): self._title = options["title"] self._components = components - def render(self, task, components=[]): + def render(self, task, components=[], runtime=False): RENDER_TEMPLATE = read_file(RENDER_TEMPLATE_PATH) JS_DATA = read_file(JS_PATH) CSS_DATA = read_file(CSS_PATH) @@ -650,9 +698,27 @@ def render(self, task, components=[]): title=task.pathspec, css=CSS_DATA, card_data_id=uuid.uuid4(), + RENDER_COMPLETE=not runtime, ) return pt.render(RENDER_TEMPLATE, data_dict) + def render_runtime(self, task, data): + return self.render(task, runtime=True) + + def refresh(self, task, data): + return data["components"] + + def reload_content_token(self, task, data): + """ + The reload token will change when the component array has changed in the Metaflow card. + The change in the component array is signified by the change in the component_update_ts. + """ + if task.finished: + return "final" + # `component_update_ts` will never be None. It is set to a default value when the `ComponentStore` is instantiated + # And it is updated when components added / removed / changed from the `ComponentStore`. + return "runtime-%s" % (str(data["component_update_ts"])) + class TaskSpecCard(MetaflowCard): type = "taskspec_card" diff --git a/metaflow/plugins/cards/card_modules/components.py b/metaflow/plugins/cards/card_modules/components.py index 3ffcf0f533..55d032346b 100644 --- a/metaflow/plugins/cards/card_modules/components.py +++ b/metaflow/plugins/cards/card_modules/components.py @@ -82,6 +82,11 @@ class Artifact(UserComponent): Use a truncated representation. """ + REALTIME_UPDATABLE = True + + def update(self, artifact): + self._artifact = artifact + def __init__( self, artifact: Any, name: Optional[str] = None, compressed: bool = True ): @@ -89,13 +94,16 @@ def __init__( self._name = name self._task_to_dict = TaskToDict(only_repr=compressed) + @with_default_component_id @render_safely def render(self): artifact = self._task_to_dict.infer_object(self._artifact) artifact["name"] = None if self._name is not None: artifact["name"] = str(self._name) - return ArtifactsComponent(data=[artifact]).render() + af_component = ArtifactsComponent(data=[artifact]) + af_component.component_id = self.component_id + return af_component.render() class Table(UserComponent): @@ -139,10 +147,21 @@ class Table(UserComponent): Optional header row for the table. """ + REALTIME_UPDATABLE = True + + def update(self, *args, **kwargs): + msg = ( + "`Table` doesn't have an `update` method implemented. " + "Components within a table can be updated individually " + "but the table itself cannot be updated." + ) + _warning_with_component(self, msg) + def __init__( self, data: Optional[List[List[Union[str, MetaflowCardComponent]]]] = None, headers: Optional[List[str]] = None, + disable_updates: bool = False, ): data = data or [[]] headers = headers or [] @@ -154,8 +173,15 @@ def __init__( if data_bool: self._data = data + if disable_updates: + self.REALTIME_UPDATABLE = False + @classmethod - def from_dataframe(cls, dataframe=None, truncate: bool = True): + def from_dataframe( + cls, + dataframe=None, + truncate: bool = True, + ): """ Create a `Table` based on a Pandas dataframe. @@ -172,14 +198,24 @@ def from_dataframe(cls, dataframe=None, truncate: bool = True): table_data = task_to_dict._parse_pandas_dataframe( dataframe, truncate=truncate ) - return_val = cls(data=table_data["data"], headers=table_data["headers"]) + return_val = cls( + data=table_data["data"], + headers=table_data["headers"], + disable_updates=True, + ) return return_val else: return cls( headers=["Object type %s not supported" % object_type], + disable_updates=True, ) def _render_subcomponents(self): + for row in self._data: + for col in row: + if isinstance(col, VegaChart): + col._chart_inside_table = True + return [ SectionComponent.render_subcomponents( row, @@ -198,11 +234,14 @@ def _render_subcomponents(self): for row in self._data ] + @with_default_component_id @render_safely def render(self): - return TableComponent( + table_component = TableComponent( headers=self._headers, data=self._render_subcomponents() - ).render() + ) + table_component.component_id = self.component_id + return table_component.render() class Image(UserComponent): @@ -256,21 +295,62 @@ class Image(UserComponent): Optional label for the image. """ + REALTIME_UPDATABLE = True + + _PIL_IMAGE_MODULE_PATH = "PIL.Image.Image" + + _MATPLOTLIB_FIGURE_MODULE_PATH = "matplotlib.figure.Figure" + + _PLT_MODULE = None + + _PIL_MODULE = None + + @classmethod + def _get_pil_module(cls): + if cls._PIL_MODULE == "NOT_PRESENT": + return None + if cls._PIL_MODULE is None: + try: + import PIL + except ImportError: + cls._PIL_MODULE = "NOT_PRESENT" + return None + cls._PIL_MODULE = PIL + return cls._PIL_MODULE + + @classmethod + def _get_plt_module(cls): + if cls._PLT_MODULE == "NOT_PRESENT": + return None + if cls._PLT_MODULE is None: + try: + import matplotlib.pyplot as pyplt + except ImportError: + cls._PLT_MODULE = "NOT_PRESENT" + return None + cls._PLT_MODULE = pyplt + return cls._PLT_MODULE + @staticmethod def render_fail_headline(msg): return "[IMAGE_RENDER FAIL]: %s" % msg - def __init__(self, src=None, label=None): - self._error_comp = None + def _set_image_src(self, src, label=None): self._label = label - - if type(src) is not str: + self._src = None + self._error_comp = None + if src is None: + self._error_comp = ErrorComponent( + self.render_fail_headline("`Image` Component `src` cannot be `None`"), + "", + ) + elif type(src) is not str: try: self._src = self._bytes_to_base64(src) except TypeError: self._error_comp = ErrorComponent( self.render_fail_headline( - "first argument should be of type `bytes` or valid image base64 string" + "The `Image` `src` argument should be of type `bytes` or valid image base64 string" ), "Type of %s is invalid" % (str(type(src))), ) @@ -291,11 +371,141 @@ def __init__(self, src=None, label=None): else: self._error_comp = ErrorComponent( self.render_fail_headline( - "first argument should be of type `bytes` or valid image base64 string" + "The `Image` `src` argument should be of type `bytes` or valid image base64 string" ), "String %s is invalid base64 string" % src, ) + def __init__(self, src=None, label=None, disable_updates: bool = True): + if disable_updates: + self.REALTIME_UPDATABLE = False + self._set_image_src(src, label=label) + + def _update_image(self, img_obj, label=None): + task_to_dict = TaskToDict() + parsed_image, err_comp = None, None + + # First set image for bytes/string type + if task_to_dict.object_type(img_obj) in ["bytes", "str"]: + self._set_image_src(img_obj, label=label) + return + + if task_to_dict.object_type(img_obj).startswith("PIL"): + parsed_image, err_comp = self._parse_pil_image(img_obj) + elif _full_classname(img_obj) == self._MATPLOTLIB_FIGURE_MODULE_PATH: + parsed_image, err_comp = self._parse_matplotlib(img_obj) + else: + parsed_image, err_comp = None, ErrorComponent( + self.render_fail_headline( + "Invalid Type. Object %s is not supported. Supported types: %s" + % ( + type(img_obj), + ", ".join( + [ + "str", + "bytes", + self._PIL_IMAGE_MODULE_PATH, + self._MATPLOTLIB_FIGURE_MODULE_PATH, + ] + ), + ) + ), + "", + ) + + if parsed_image is not None: + self._set_image_src(parsed_image, label=label) + else: + self._set_image_src(None, label=label) + self._error_comp = err_comp + + @classmethod + def _pil_parsing_error(cls, error_type): + return None, ErrorComponent( + cls.render_fail_headline( + "first argument for `Image` should be of type %s" + % cls._PIL_IMAGE_MODULE_PATH + ), + "Type of %s is invalid. Type of %s required" + % (error_type, cls._PIL_IMAGE_MODULE_PATH), + ) + + @classmethod + def _parse_pil_image(cls, pilimage): + parsed_value = None + error_component = None + import io + + task_to_dict = TaskToDict() + _img_type = task_to_dict.object_type(pilimage) + + if not _img_type.startswith("PIL"): + return cls._pil_parsing_error(_img_type) + + # Set the module as a part of the class so that + # we don't keep reloading the module everytime + pil_module = cls._get_pil_module() + + if pil_module is None: + return parsed_value, ErrorComponent( + cls.render_fail_headline("PIL cannot be imported"), "" + ) + if not isinstance(pilimage, pil_module.Image.Image): + return cls._pil_parsing_error(_img_type) + + img_byte_arr = io.BytesIO() + try: + pilimage.save(img_byte_arr, format="PNG") + except OSError as e: + return parsed_value, ErrorComponent( + cls.render_fail_headline("PIL Image Not Parsable"), "%s" % repr(e) + ) + img_byte_arr = img_byte_arr.getvalue() + parsed_value = task_to_dict.parse_image(img_byte_arr) + return parsed_value, error_component + + @classmethod + def _parse_matplotlib(cls, plot): + import io + import traceback + + parsed_value = None + error_component = None + pyplt = cls._get_plt_module() + if pyplt is None: + return parsed_value, ErrorComponent( + cls.render_fail_headline("Matplotlib cannot be imported"), + "%s" % traceback.format_exc(), + ) + # First check if it is a valid Matplotlib figure. + figure = None + if _full_classname(plot) == cls._MATPLOTLIB_FIGURE_MODULE_PATH: + figure = plot + + # If it is not valid figure then check if it is matplotlib.axes.Axes or a matplotlib.axes._subplots.AxesSubplot + # These contain the `get_figure` function to get the main figure object. + if figure is None: + if getattr(plot, "get_figure", None) is None: + return parsed_value, ErrorComponent( + cls.render_fail_headline( + "Invalid Type. Object %s is not from `matplotlib`" % type(plot) + ), + "", + ) + else: + figure = plot.get_figure() + + task_to_dict = TaskToDict() + img_bytes_arr = io.BytesIO() + figure.savefig(img_bytes_arr, format="PNG") + parsed_value = task_to_dict.parse_image(img_bytes_arr.getvalue()) + pyplt.close(figure) + if parsed_value is not None: + return parsed_value, error_component + return parsed_value, ErrorComponent( + cls.render_fail_headline("Matplotlib plot's image is not parsable"), "" + ) + @staticmethod def _bytes_to_base64(bytes_arr): task_to_dict = TaskToDict() @@ -307,7 +517,9 @@ def _bytes_to_base64(bytes_arr): return parsed_image @classmethod - def from_pil_image(cls, pilimage, label: Optional[str] = None): + def from_pil_image( + cls, pilimage, label: Optional[str] = None, disable_updates: bool = False + ): """ Create an `Image` from a PIL image. @@ -319,43 +531,29 @@ def from_pil_image(cls, pilimage, label: Optional[str] = None): Optional label for the image. """ try: - import io - - PIL_IMAGE_PATH = "PIL.Image.Image" - task_to_dict = TaskToDict() - if task_to_dict.object_type(pilimage) != PIL_IMAGE_PATH: - return ErrorComponent( - cls.render_fail_headline( - "first argument for `Image` should be of type %s" - % PIL_IMAGE_PATH - ), - "Type of %s is invalid. Type of %s required" - % (task_to_dict.object_type(pilimage), PIL_IMAGE_PATH), - ) - img_byte_arr = io.BytesIO() - try: - pilimage.save(img_byte_arr, format="PNG") - except OSError as e: - return ErrorComponent( - cls.render_fail_headline("PIL Image Not Parsable"), "%s" % repr(e) - ) - img_byte_arr = img_byte_arr.getvalue() - parsed_image = task_to_dict.parse_image(img_byte_arr) + parsed_image, error_comp = cls._parse_pil_image(pilimage) if parsed_image is not None: - return cls(src=parsed_image, label=label) - return ErrorComponent( - cls.render_fail_headline("PIL Image Not Parsable"), "" - ) + img = cls( + src=parsed_image, label=label, disable_updates=disable_updates + ) + else: + img = cls(src=None, label=label, disable_updates=disable_updates) + img._error_comp = error_comp + return img except: import traceback - return ErrorComponent( + img = cls(src=None, label=label, disable_updates=disable_updates) + img._error_comp = ErrorComponent( cls.render_fail_headline("PIL Image Not Parsable"), "%s" % traceback.format_exc(), ) + return img @classmethod - def from_matplotlib(cls, plot, label: Optional[str] = None): + def from_matplotlib( + cls, plot, label: Optional[str] = None, disable_updates: bool = False + ): """ Create an `Image` from a Matplotlib plot. @@ -366,64 +564,62 @@ def from_matplotlib(cls, plot, label: Optional[str] = None): label : str, optional Optional label for the image. """ - import io - try: - try: - import matplotlib.pyplot as pyplt - except ImportError: - return ErrorComponent( - cls.render_fail_headline("Matplotlib cannot be imported"), - "%s" % traceback.format_exc(), - ) - # First check if it is a valid Matplotlib figure. - figure = None - if _full_classname(plot) == "matplotlib.figure.Figure": - figure = plot - - # If it is not valid figure then check if it is matplotlib.axes.Axes or a matplotlib.axes._subplots.AxesSubplot - # These contain the `get_figure` function to get the main figure object. - if figure is None: - if getattr(plot, "get_figure", None) is None: - return ErrorComponent( - cls.render_fail_headline( - "Invalid Type. Object %s is not from `matplotlib`" - % type(plot) - ), - "", - ) - else: - figure = plot.get_figure() - - task_to_dict = TaskToDict() - img_bytes_arr = io.BytesIO() - figure.savefig(img_bytes_arr, format="PNG") - parsed_image = task_to_dict.parse_image(img_bytes_arr.getvalue()) - pyplt.close(figure) + parsed_image, error_comp = cls._parse_matplotlib(plot) if parsed_image is not None: - return cls(src=parsed_image, label=label) - return ErrorComponent( - cls.render_fail_headline("Matplotlib plot's image is not parsable"), "" - ) + img = cls( + src=parsed_image, label=label, disable_updates=disable_updates + ) + else: + img = cls(src=None, label=label, disable_updates=disable_updates) + img._error_comp = error_comp + return img except: import traceback - return ErrorComponent( + img = cls(src=None, label=label, disable_updates=disable_updates) + img._error_comp = ErrorComponent( cls.render_fail_headline("Matplotlib plot's image is not parsable"), "%s" % traceback.format_exc(), ) + return img + @with_default_component_id @render_safely def render(self): if self._error_comp is not None: return self._error_comp.render() if self._src is not None: - return ImageComponent(src=self._src, label=self._label).render() + img_comp = ImageComponent(src=self._src, label=self._label) + img_comp.component_id = self.component_id + return img_comp.render() return ErrorComponent( self.render_fail_headline("`Image` Component `src` argument is `None`"), "" ).render() + def update(self, image, label=None): + """ + Update the image. + + Parameters + ---------- + image : PIL.Image or matplotlib.figure.Figure or matplotlib.axes.Axes or matplotlib.axes._subplots.AxesSubplot or bytes or str + The updated image object + label : str, optional + Optional label for the image. + """ + if not self.REALTIME_UPDATABLE: + msg = ( + "The `Image` component is disabled for realtime updates. " + "Please set `disable_updates` to `False` while creating the `Image` object." + ) + _warning_with_component(self, msg) + return + + _label = label if label is not None else self._label + self._update_image(image, label=_label) + class Error(UserComponent): """ @@ -477,9 +673,135 @@ class Markdown(UserComponent): Text formatted in Markdown. """ + REALTIME_UPDATABLE = True + + def update(self, text=None): + self._text = text + def __init__(self, text=None): self._text = text + @with_default_component_id + @render_safely + def render(self): + comp = MarkdownComponent(self._text) + comp.component_id = self.component_id + return comp.render() + + +class ProgressBar(UserComponent): + """ + A Progress bar for tracking progress of any task. + + Example: + ``` + progress_bar = ProgressBar( + max=100, + label="Progress Bar", + value=0, + unit="%", + metadata="0.1 items/s" + ) + current.card.append( + progress_bar + ) + for i in range(100): + progress_bar.update(i, metadata="%s items/s" % i) + + ``` + + Parameters + ---------- + text : str + Text formatted in Markdown. + """ + + type = "progressBar" + + REALTIME_UPDATABLE = True + + def __init__( + self, + max: int = 100, + label: str = None, + value: int = 0, + unit: str = None, + metadata: str = None, + ): + self._label = label + self._max = max + self._value = value + self._unit = unit + self._metadata = metadata + + def update(self, new_value: int, metadata: str = None): + self._value = new_value + if metadata is not None: + self._metadata = metadata + + @with_default_component_id + @render_safely + def render(self): + data = { + "type": self.type, + "id": self.component_id, + "max": self._max, + "value": self._value, + } + if self._label: + data["label"] = self._label + if self._unit: + data["unit"] = self._unit + if self._metadata: + data["details"] = self._metadata + return data + + +class VegaChart(UserComponent): + type = "vegaChart" + + REALTIME_UPDATABLE = True + + def __init__(self, spec: dict, show_controls=False): + self._spec = spec + self._show_controls = show_controls + self._chart_inside_table = False + + def update(self, spec=None): + if spec is not None: + self._spec = spec + + @classmethod + def from_altair_chart(cls, altair_chart): + from metaflow.plugins.cards.card_modules.convert_to_native_type import ( + _full_classname, + ) + + # This will feel slightly hacky but I am unable to find a natural way of determining the class + # name of the Altair chart. The only way I can think of is to use the full class name and then + # match with heuristics + + fulclsname = _full_classname(altair_chart) + if not all([x in fulclsname for x in ["altair", "vegalite", "Chart"]]): + raise ValueError(fulclsname + " is not an altair chart") + + altair_chart_dict = altair_chart.to_dict() + + cht = cls(spec=altair_chart_dict) + return cht + + @with_default_component_id @render_safely def render(self): - return MarkdownComponent(self._text).render() + data = { + "type": self.type, + "id": self.component_id, + "spec": self._spec, + } + if not self._show_controls: + data["options"] = {"actions": False} + if "width" not in self._spec and not self._chart_inside_table: + data["spec"]["width"] = "container" + if self._chart_inside_table and "autosize" not in self._spec: + data["spec"]["autosize"] = "fit-x" + return data diff --git a/metaflow/plugins/cards/card_modules/convert_to_native_type.py b/metaflow/plugins/cards/card_modules/convert_to_native_type.py index df93ae6c48..5a94872d70 100644 --- a/metaflow/plugins/cards/card_modules/convert_to_native_type.py +++ b/metaflow/plugins/cards/card_modules/convert_to_native_type.py @@ -44,7 +44,7 @@ def _full_classname(obj): class TaskToDict: - def __init__(self, only_repr=False): + def __init__(self, only_repr=False, runtime=False): # this dictionary holds all the supported functions import reprlib import pprint @@ -59,6 +59,7 @@ def __init__(self, only_repr=False): r.maxlist = 100 r.maxlevel = 3 self._repr = r + self._runtime = runtime self._only_repr = only_repr self._supported_types = { "tuple": self._parse_tuple, @@ -90,11 +91,16 @@ def __call__(self, task, graph=None): stderr=task.stderr, stdout=task.stdout, created_at=task.created_at.strftime(TIME_FORMAT), - finished_at=task.finished_at.strftime(TIME_FORMAT), + finished_at=None, pathspec=task.pathspec, graph=graph, data={}, ) + if not self._runtime: + if task.finished_at is not None: + task_dict.update( + dict(finished_at=task.finished_at.strftime(TIME_FORMAT)) + ) task_dict["data"], type_infered_objects = self._create_task_data_dict(task) task_dict.update(type_infered_objects) return task_dict diff --git a/metaflow/plugins/cards/card_modules/test_cards.py b/metaflow/plugins/cards/card_modules/test_cards.py index 9a6e2ff7dc..39fb06fe9d 100644 --- a/metaflow/plugins/cards/card_modules/test_cards.py +++ b/metaflow/plugins/cards/card_modules/test_cards.py @@ -154,13 +154,16 @@ class TestRefreshCard(MetaflowCard): type = "test_refresh_card" - def render(self, task, data) -> str: + def render(self, task) -> str: + return self._render_func(task, self.runtime_data) + + def _render_func(self, task, data): return self.HTML_TEMPLATE.replace( "[REPLACE_CONTENT_HERE]", json.dumps(data["user"]) ).replace("[PATHSPEC]", task.pathspec) def render_runtime(self, task, data): - return self.render(task, data) + return self._render_func(task, data) def refresh(self, task, data): return data @@ -195,14 +198,14 @@ class TestRefreshComponentCard(MetaflowCard): def __init__(self, options={}, components=[], graph=None): self._components = components - def render(self, task, data) -> str: + def render(self, task) -> str: # Calling `render`/`render_runtime` wont require the `data` object return self.HTML_TEMPLATE.replace( "[REPLACE_CONTENT_HERE]", json.dumps(self._components) ).replace("[PATHSPEC]", task.pathspec) def render_runtime(self, task, data): - return self.render(task, data) + return self.render(task) def refresh(self, task, data): # Govers the information passed in the data update diff --git a/metaflow/plugins/cards/card_server.py b/metaflow/plugins/cards/card_server.py new file mode 100644 index 0000000000..7b5008410b --- /dev/null +++ b/metaflow/plugins/cards/card_server.py @@ -0,0 +1,361 @@ +import os +import json +from http.server import BaseHTTPRequestHandler +from threading import Thread +from multiprocessing import Pipe +from multiprocessing.connection import Connection +from urllib.parse import urlparse +import time + +try: + from http.server import ThreadingHTTPServer +except ImportError: + from socketserver import ThreadingMixIn + from http.server import HTTPServer + + class ThreadingHTTPServer(ThreadingMixIn, HTTPServer): + daemon_threads = True + + +from .card_client import CardContainer +from .exception import CardNotPresentException +from .card_resolver import resolve_paths_from_task +from metaflow.metaflow_config import DATASTORE_LOCAL_DIR +from metaflow import namespace +from metaflow.exception import ( + CommandException, + MetaflowNotFound, + MetaflowNamespaceMismatch, +) + + +VIEWER_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "card_viewer", "viewer.html" +) + +CARD_VIEWER_HTML = open(VIEWER_PATH).read() + +TASK_CACHE = {} + +_ClickLogger = None + + +class RunWatcher(Thread): + """ + A thread that watches for new runs and sends the run_id to the + card server when a new run is detected. It observes the `latest_run` + file in the `.metaflow/` directory. + """ + + def __init__(self, flow_name, connection: Connection): + super().__init__() + + self._watch_file = os.path.join( + os.getcwd(), DATASTORE_LOCAL_DIR, flow_name, "latest_run" + ) + self._current_run_id = self.get_run_id() + self.daemon = True + self._connection = connection + + def get_run_id(self): + if not os.path.exists(self._watch_file): + return None + with open(self._watch_file, "r") as f: + return f.read().strip() + + def watch(self): + while True: + run_id = self.get_run_id() + if run_id != self._current_run_id: + self._current_run_id = run_id + self._connection.send(run_id) + time.sleep(2) + + def run(self): + self.watch() + + +class CardServerOptions: + def __init__( + self, + flow_name, + run_object, + only_running, + follow_resumed, + flow_datastore, + follow_new_runs, + max_cards=20, + poll_interval=5, + ): + from metaflow import Run + + self.RunClass = Run + self.run_object = run_object + + self.flow_name = flow_name + self.only_running = only_running + self.follow_resumed = follow_resumed + self.flow_datastore = flow_datastore + self.max_cards = max_cards + self.follow_new_runs = follow_new_runs + self.poll_interval = poll_interval + + self._parent_conn, self._child_conn = Pipe() + + def refresh_run(self): + if not self.follow_new_runs: + return False + if not self.parent_conn.poll(): + return False + run_id = self.parent_conn.recv() + if run_id is None: + return False + namespace(None) + try: + self.run_object = self.RunClass(f"{self.flow_name}/{run_id}") + return True + except MetaflowNotFound: + return False + + @property + def parent_conn(self): + return self._parent_conn + + @property + def child_conn(self): + return self._child_conn + + +def cards_for_task( + flow_datastore, task_pathspec, card_type=None, card_hash=None, card_id=None +): + try: + paths, card_ds = resolve_paths_from_task( + flow_datastore, + task_pathspec, + type=card_type, + hash=card_hash, + card_id=card_id, + ) + except CardNotPresentException: + return None + for card in CardContainer(paths, card_ds, origin_pathspec=None): + yield card + + +def cards_for_run( + flow_datastore, + run_object, + only_running, + card_type=None, + card_hash=None, + card_id=None, + max_cards=20, +): + curr_idx = 0 + for step in run_object.steps(): + for task in step.tasks(): + if only_running and task.finished: + continue + card_generator = cards_for_task( + flow_datastore, + task.pathspec, + card_type=card_type, + card_hash=card_hash, + card_id=card_id, + ) + if card_generator is None: + continue + for card in card_generator: + curr_idx += 1 + if curr_idx >= max_cards: + raise StopIteration + yield task.pathspec, card + + +class CardViewerRoutes(BaseHTTPRequestHandler): + + card_options: CardServerOptions = None + + run_watcher: RunWatcher = None + + def do_GET(self): + try: + _, path = self.path.split("/", 1) + try: + prefix, suffix = path.split("/", 1) + except: + prefix = path + suffix = None + except: + prefix = None + if prefix in self.ROUTES: + self.ROUTES[prefix](self, suffix) + else: + self._response(open(VIEWER_PATH).read().encode("utf-8")) + + def get_runinfo(self, suffix): + run_id_changed = self.card_options.refresh_run() + if run_id_changed: + self.log_message( + "RunID changed in the background to %s" + % self.card_options.run_object.pathspec + ) + _ClickLogger( + "RunID changed in the background to %s" + % self.card_options.run_object.pathspec, + fg="blue", + ) + + if self.card_options.run_object is None: + self._response( + {"status": "No Run Found", "flow": self.card_options.flow_name}, + code=404, + is_json=True, + ) + return + + task_card_generator = cards_for_run( + self.card_options.flow_datastore, + self.card_options.run_object, + self.card_options.only_running, + max_cards=self.card_options.max_cards, + ) + flow_name = self.card_options.run_object.parent.id + run_id = self.card_options.run_object.id + cards = [] + for pathspec, card in task_card_generator: + step, task = pathspec.split("/")[-2:] + _task = self.card_options.run_object[step][task] + task_finished = True if _task.finished else False + cards.append( + dict( + task=pathspec, + label="%s/%s %s" % (step, task, card.hash), + card_object=dict( + hash=card.hash, + type=card.type, + path=card.path, + id=card.id, + ), + finished=task_finished, + card="%s/%s" % (pathspec, card.hash), + ) + ) + resp = { + "status": "ok", + "flow": flow_name, + "run_id": run_id, + "cards": cards, + "poll_interval": self.card_options.poll_interval, + } + self._response(resp, is_json=True) + + def get_card(self, suffix): + _suffix = urlparse(self.path).path + _, flow, run_id, step, task_id, card_hash = _suffix.strip("/").split("/") + + pathspec = "/".join([flow, run_id, step, task_id]) + cards = list( + cards_for_task( + self.card_options.flow_datastore, pathspec, card_hash=card_hash + ) + ) + if len(cards) == 0: + self._response({"status": "Card Not Found"}, code=404) + return + selected_card = cards[0] + self._response(selected_card.get().encode("utf-8")) + + def get_data(self, suffix): + _suffix = urlparse(self.path).path + _, flow, run_id, step, task_id, card_hash = _suffix.strip("/").split("/") + pathspec = "/".join([flow, run_id, step, task_id]) + cards = list( + cards_for_task( + self.card_options.flow_datastore, pathspec, card_hash=card_hash + ) + ) + if len(cards) == 0: + self._response( + { + "status": "Card Not Found", + }, + is_json=True, + code=404, + ) + return + + status = "ok" + try: + task_object = self.card_options.run_object[step][task_id] + except KeyError: + return self._response( + {"status": "Task Not Found", "is_complete": False}, + is_json=True, + code=404, + ) + + is_complete = task_object.finished + selected_card = cards[0] + card_data = selected_card.get_data() + if card_data is not None: + self.log_message( + "Task Success: %s, Task Finished: %s" + % (task_object.successful, is_complete) + ) + if not task_object.successful and is_complete: + status = "Task Failed" + self._response( + {"status": status, "payload": card_data, "is_complete": is_complete}, + is_json=True, + ) + else: + self._response( + {"status": "ok", "is_complete": is_complete}, + is_json=True, + code=404, + ) + + def _response(self, body, is_json=False, code=200): + self.send_response(code) + mime = "application/json" if is_json else "text/html" + self.send_header("Content-type", mime) + self.end_headers() + if is_json: + self.wfile.write(json.dumps(body).encode("utf-8")) + else: + self.wfile.write(body) + + ROUTES = {"runinfo": get_runinfo, "card": get_card, "data": get_data} + + +def _is_debug_mode(): + debug_flag = os.environ.get("METAFLOW_DEBUG_CARD_SERVER") + if debug_flag is None: + return False + return debug_flag.lower() in ["true", "1"] + + +def create_card_server(card_options: CardServerOptions, port, ctx_obj): + CardViewerRoutes.card_options = card_options + global _ClickLogger + _ClickLogger = ctx_obj.echo + if card_options.follow_new_runs: + CardViewerRoutes.run_watcher = RunWatcher( + card_options.flow_name, card_options.child_conn + ) + CardViewerRoutes.run_watcher.start() + server_addr = ("", port) + ctx_obj.echo( + "Starting card server on port %d " % (port), + fg="green", + bold=True, + ) + # Disable logging if not in debug mode + if not _is_debug_mode(): + CardViewerRoutes.log_request = lambda *args, **kwargs: None + CardViewerRoutes.log_message = lambda *args, **kwargs: None + + server = ThreadingHTTPServer(server_addr, CardViewerRoutes) + server.serve_forever() diff --git a/metaflow/plugins/cards/card_viewer/viewer.html b/metaflow/plugins/cards/card_viewer/viewer.html new file mode 100644 index 0000000000..00a794facd --- /dev/null +++ b/metaflow/plugins/cards/card_viewer/viewer.html @@ -0,0 +1,344 @@ + + + + + + Metaflow Run + + + +
+

Metaflow Run

+
+ +
+
+
Initializing
+ +
+
+
+ +
+ + + + diff --git a/metaflow/plugins/cards/component_serializer.py b/metaflow/plugins/cards/component_serializer.py index 5d79c8f710..66f0b62ba1 100644 --- a/metaflow/plugins/cards/component_serializer.py +++ b/metaflow/plugins/cards/component_serializer.py @@ -263,12 +263,13 @@ def extend(self, components): def clear(self): self._components.clear() - def _card_proc(self, mode): - self._card_creator.create(**self._card_creator_args, mode=mode) + def _card_proc(self, mode, sync=False): + self._card_creator.create(**self._card_creator_args, mode=mode, sync=sync) def refresh(self, data=None, force=False): self._latest_user_data = data nu = time.time() + first_render = True if self._last_render == 0 else False if nu - self._last_refresh < self._refresh_interval: # rate limit refreshes: silently ignore requests that @@ -287,11 +288,24 @@ def refresh(self, data=None, force=False): self._last_layout_change != self.components.layout_last_changed_on or self._last_layout_change is None ) - if force or last_rendered_before_minimum_interval or layout_has_changed: self._render_seq += 1 self._last_render = nu self._card_proc("render_runtime") + # The below `if not first_render` condition is a special case for the following scenario: + # Lets assume the case that the user is only doing `current.card.append` followed by `refresh`. + # In this case, there will be no process executed in `refresh` mode since `layout_has_changed` + # will always be true and as a result there will be no data update that informs the UI of the RELOAD_TOKEN change. + # This will cause the UI to seek for the data update object but will constantly find None. So if it is not + # the first render then we should also have a `refresh` call followed by a `render_runtime` call so + # that the UI can always be updated with the latest data. + if not first_render: + # For the general case, the CardCreator's ProcessManager run's the `refresh` / `render_runtime` in a asynchronous manner. + # Due to this when the `render_runtime` call is happening, an immediately subsequent call to `refresh` will not be able to + # execute since the card-process manager will be busy executing the `render_runtime` call and ignore the `refresh` call. + # Hence we need to pass the `sync=True` argument to the `refresh` call so that the `refresh` call is executed synchronously and waits for the + # `render_runtime` call to finish. + self._card_proc("refresh", sync=True) # We set self._last_layout_change so that when self._last_layout_change is not the same # as `self.components.layout_last_changed_on`, then the component array itself # has been modified. So we should force a re-render of the card.