Skip to content

Commit

Permalink
refactor to support artifact_id
Browse files Browse the repository at this point in the history
  • Loading branch information
emattia committed Jul 17, 2024
1 parent 658b4c2 commit ff2c1a7
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -1,22 +1,38 @@
from metaflow.cards import MetaflowCard
from metaflow.cards import MetaflowCard
from metaflow.exception import MetaflowException
import datetime

# Explanatory blurb shown at the top of the rendered card. The timestamp is
# formatted once, when this module-level assignment is executed.
note = f"""This dataset was loaded using the HuggingFace Datasets library at {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}.
This card uses HuggingFace's dataset viewer, which lets you view the <b>latest</b> version on HuggingFace, not necessarily the version used in this run.
"""


class HuggingfaceDatasetCard(MetaflowCard):
    """Card that embeds HuggingFace's dataset viewer for a single dataset.

    The dataset is named either directly through the ``id`` option or
    indirectly through ``artifact_id``, the name of an attribute on
    ``task.data`` whose value is the dataset ID.
    """

    type = "huggingface_dataset"

    # Class-level card settings (presumably consumed by Metaflow's card
    # machinery -- confirm against the Metaflow card documentation).
    RUNTIME_UPDATABLE = True
    RELOAD_POLICY = "never"

    def __init__(self, options=None, **kwargs):
        """Store the card options.

        Args:
            options: Mapping with the optional keys ``id`` (dataset ID),
                ``artifact_id`` (attribute name on ``task.data``) and ``vh``
                (iframe height in ``vh`` units, 550 when absent).
            **kwargs: Accepted and ignored.

        Raises:
            MetaflowException: If neither ``id`` nor ``artifact_id`` is set.
        """
        # A None default avoids sharing one mutable dict between calls. An
        # omitted ``options`` still fails the check below, exactly as the
        # former all-None dict default did.
        options = {} if options is None else options
        if not options.get("id") and not options.get("artifact_id"):
            raise MetaflowException(
                "Dataset ID or Metaflow FlowSpec artifact_id is required for Huggingface Dataset card."
            )
        self.dataset_id = options.get("id", "")
        self.artifact_id = options.get("artifact_id", "")
        self.vh = options.get("vh", 550)

    def render_runtime(self, task, data):
        """Render the card for a runtime update.

        ``data`` is unused. ``render`` accepts only ``task``, so no extra
        keyword is forwarded; passing ``runtime=True`` raised ``TypeError``.
        """
        return self.render(task)

    def render(self, task):
        """Return the card HTML: the note followed by the viewer iframe.

        When no dataset ID was supplied, it is read from the attribute of
        ``task.data`` named by ``artifact_id`` and kept on the instance.
        """
        if not self.dataset_id:
            self.dataset_id = getattr(task.data, self.artifact_id)
        dataset_viewer_url = (
            f"https://huggingface.co/datasets/{self.dataset_id}/embed/viewer"
        )
        return f'<html><body><p>{note}</p><iframe src="{dataset_viewer_url}" width="100%" height="{self.vh}vh"></iframe></body></html>'

CARDS = [HuggingfaceDatasetCard]

CARDS = [HuggingfaceDatasetCard]
Original file line number Diff line number Diff line change
@@ -1,27 +1,109 @@
from metaflow.decorators import StepDecorator
from metaflow.exception import MetaflowException
from collections import defaultdict
import json
import uuid

class huggingface_dataset_deco:
CARD_TYPE = "huggingface_dataset"
VERTICAL_HEIGHT_DEFAULT = 550

def __init__(self, **kwargs):
self.kwargs = kwargs

def __call__(self, step_func):
class CardDecoratorInjector:
    """
    Mixin useful for injecting @card decorators from other first class Metaflow decorators.
    """

    # step name -> {card id -> bool}. Defined on the class, so the cache is
    # shared by every instance. True means this mixin attached the card,
    # False means a matching card decorator was already present.
    _first_time_init = defaultdict(dict)

    @classmethod
    def _get_first_time_init_cached_value(cls, step_name, card_id):
        """Return the cached decision for (step_name, card_id), or None if unset."""
        # Plain ``.get`` so that a lookup never creates a defaultdict entry.
        return cls._first_time_init.get(step_name, {}).get(card_id, None)

    @classmethod
    def _set_first_time_init_cached_value(cls, step_name, card_id, value):
        """Record the decision for (step_name, card_id) in the class-level cache."""
        cls._first_time_init[step_name][card_id] = value

    def _card_deco_already_attached(self, step, card_id):
        """Return True if ``step`` has a ``card`` decorator whose id contains ``card_id``.

        Decorators with a falsy ``id`` attribute never match.
        """
        for decorator in step.decorators:
            if decorator.name == "card":
                if decorator.attributes["id"] and card_id in decorator.attributes["id"]:
                    return True
        return False

    def _get_step(self, flow, step_name):
        """Return the first step in ``flow`` named ``step_name``, or None."""
        for step in flow:
            if step.name == step_name:
                return step
        return None

    def _first_time_init_check(self, step_dag_node, card_id):
        """Return True when no matching card decorator is attached to the step yet."""
        return not self._card_deco_already_attached(step_dag_node, card_id)

    def attach_card_decorator(self, flow, step_name, card_id, card_type, options):
        """
        Attach a ``card`` decorator to a step unless a matching one exists.

        Call this method from `step_init` in your StepDecorator code, since
        this class is used as a Mixin.

        Args:
            flow: Iterable of steps; each step exposes ``name`` and ``decorators``.
            step_name: Name of the step that should receive the card.
            card_id: Id given to the injected card.
            card_type: Card type written into the decorator spec.
            options: JSON-serializable card options.

        Raises:
            MetaflowException: If ``card_id`` or ``card_type`` is falsy.
        """
        from metaflow import decorators as _decorators

        if not all([card_id, card_type]):
            raise MetaflowException(
                "`INJECTED_CARD_ID` and `INJECTED_CARD_TYPE` must be set in the `CardDecoratorInjector` Mixin"
            )

        step_dag_node = self._get_step(flow, step_name)
        # First check class level setting: any cached decision means this
        # (step, card id) pair has already been handled.
        if self._get_first_time_init_cached_value(step_name, card_id) is not None:
            return

        if self._first_time_init_check(step_dag_node, card_id):
            self._set_first_time_init_cached_value(step_name, card_id, True)
            _decorators._attach_decorators_to_step(
                step_dag_node,
                [
                    # Use the validated ``card_type`` argument rather than a
                    # module-level constant, so the parameter is honoured.
                    "card:type=%s,id=%s,options=%s"
                    % (card_type, card_id, json.dumps(options))
                ],
            )
        else:
            self._set_first_time_init_cached_value(step_name, card_id, False)


class HuggingfaceDatasetDecorator(StepDecorator, CardDecoratorInjector):
    """Step decorator that injects a HuggingFace dataset card into its step.

    Exactly one of the ``id`` (dataset ID) and ``artifact_id`` attributes
    must be provided; ``vh`` sets the height passed on to the card.
    """

    name = CARD_TYPE
    defaults = {"id": None, "artifact_id": None, "vh": 550}

    def step_init(
        self, flow, graph, step_name, decorators, environment, flow_datastore, logger
    ):
        """Validate the attributes and attach the dataset card to the step.

        Raises:
            MetaflowException: If neither ``id`` nor ``artifact_id`` is set,
                or if both are given.
        """
        dataset_id = self.attributes.get("id")
        artifact_id = self.attributes.get("artifact_id")

        if not dataset_id and not artifact_id:
            raise MetaflowException(
                "Dataset ID or Metaflow FlowSpec artifact_id is required for Huggingface Dataset card."
            )

        if dataset_id is not None and artifact_id is not None:
            raise MetaflowException(
                "Both Dataset ID and Metaflow FlowSpec artifact_id cannot be set at the same time."
            )

        if dataset_id:
            # Replace "/" and "-" with "_" to derive the card id from the
            # dataset ID.
            _id = dataset_id.replace("/", "_").replace("-", "_")
        else:
            _id = artifact_id

        self.attach_card_decorator(
            flow,
            step_name,
            _id,
            CARD_TYPE,
            options={
                "id": dataset_id,
                "artifact_id": artifact_id,
                "vh": self.attributes.get("vh", 550),
            },
        )
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Step decorator registrations as (decorator name, relative dotted path of
# the implementing class) pairs.
STEP_DECORATORS_DESC = [
    ("huggingface_dataset", ".huggingface_dataset.deco.HuggingfaceDatasetDecorator"),
]
Original file line number Diff line number Diff line change
@@ -1 +1 @@
toplevel = "toplevel"
toplevel = "toplevel"
3 changes: 1 addition & 2 deletions metaflow_extensions/huggingface_dataset/toplevel/toplevel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
__mf_extensions__ = "huggingface_dataset"

from ..plugins.huggingface_dataset.deco import huggingface_dataset_deco as huggingface_dataset
import pkg_resources

try:
    __version__ = pkg_resources.get_distribution("metaflow-card-hf-dataset").version
except Exception:
    # this happens on remote environments since the job package
    # does not have a version; catching Exception (not a bare except)
    # lets KeyboardInterrupt and SystemExit propagate.
    __version__ = None
10 changes: 6 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from setuptools import find_namespace_packages, setup


def get_long_description() -> str:
    """Return the text of ``README.md`` from the current working directory."""
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open("README.md", encoding="utf-8") as fh:
        return fh.read()


# Package metadata. Each keyword is passed exactly once; repeating a
# keyword argument in a call is a SyntaxError.
setup(
    name="metaflow-card-hf-dataset",
    version="0.0.4",
    description="A metaflow card that renders HTML inputs.",
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    author="Outerbounds",
    author_email="[email protected]",
    license="Apache Software License 2.0",
    packages=find_namespace_packages(include=["metaflow_extensions.*"]),
)

0 comments on commit ff2c1a7

Please sign in to comment.