-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
135 additions
and
36 deletions.
There are no files selected for viewing
36 changes: 26 additions & 10 deletions
36
metaflow_extensions/huggingface_dataset/plugins/cards/huggingface_dataset/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,38 @@ | ||
from metaflow.cards import MetaflowCard | ||
from metaflow.cards import MetaflowCard | ||
from metaflow.exception import MetaflowException | ||
import datetime | ||
|
||
note = f'''This dataset was loaded using the HuggingFace Datasets library at {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}. | ||
note = f"""This dataset was loaded using the HuggingFace Datasets library at {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}. | ||
This card uses HuggingFace's dataset viewer, which let's you view the <b>latest</b> version on HuggingFace, not necessarily the version used in this run. | ||
''' | ||
""" | ||
|
||
|
||
class HuggingfaceDatasetCard(MetaflowCard): | ||
|
||
type = "huggingface_dataset" | ||
|
||
def __init__(self, options={'id': None, 'vh': 550}, **kwargs): | ||
if not options.get('id'): | ||
raise MetaflowException("Dataset ID is required for Huggingface Dataset card.") | ||
self.dataset_id = options.get('id', '') | ||
self.vh = options.get('vh', 550) | ||
RUNTIME_UPDATABLE = True | ||
RELOAD_POLICY = "never" | ||
|
||
def __init__(self, options={"id": None, "artifact_id": None, "vh": 550}, **kwargs): | ||
if not options.get("id") and not options.get("artifact_id"): | ||
raise MetaflowException( | ||
"Dataset ID or Metaflow FlowSpec artifact_id is required for Huggingface Dataset card." | ||
) | ||
self.dataset_id = options.get("id", "") | ||
self.artifact_id = options.get("artifact_id", "") | ||
self.vh = options.get("vh", 550) | ||
|
||
def render_runtime(self, task, data): | ||
return self.render(task, runtime=True) | ||
|
||
def render(self, task): | ||
dataset_viewer_url = f'https://huggingface.co/datasets/{self.dataset_id}/embed/viewer' | ||
if not self.dataset_id: | ||
self.dataset_id = getattr(task.data, self.artifact_id) | ||
dataset_viewer_url = ( | ||
f"https://huggingface.co/datasets/{self.dataset_id}/embed/viewer" | ||
) | ||
return f'<html><body><p>{note}</p><iframe src="{dataset_viewer_url}" width="100%" height="{self.vh}vh"></iframe></body></html>' | ||
|
||
CARDS = [HuggingfaceDatasetCard] | ||
|
||
CARDS = [HuggingfaceDatasetCard] |
116 changes: 99 additions & 17 deletions
116
metaflow_extensions/huggingface_dataset/plugins/huggingface_dataset/deco.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,109 @@ | ||
from metaflow.decorators import StepDecorator | ||
from metaflow.exception import MetaflowException | ||
from collections import defaultdict | ||
import json | ||
import uuid | ||
|
||
class huggingface_dataset_deco: | ||
CARD_TYPE = "huggingface_dataset" | ||
VERTICAL_HEIGHT_DEFAULT = 550 | ||
|
||
def __init__(self, **kwargs): | ||
self.kwargs = kwargs | ||
|
||
def __call__(self, step_func): | ||
class CardDecoratorInjector: | ||
""" | ||
Mixin Useful for injecting @card decorators from other first class Metaflow decorators. | ||
""" | ||
|
||
from metaflow import _huggingface_dataset, card, current | ||
_first_time_init = defaultdict(dict) | ||
|
||
_card = card( | ||
type="huggingface_dataset", | ||
id=self.kwargs.get("id").replace("/", "_").replace("-", "_"), | ||
options={ | ||
"id": self.kwargs.get("id"), | ||
"vh": self.kwargs.get("vh", 550) | ||
} | ||
) | ||
@classmethod | ||
def _get_first_time_init_cached_value(cls, step_name, card_id): | ||
return cls._first_time_init.get(step_name, {}).get(card_id, None) | ||
|
||
@classmethod | ||
def _set_first_time_init_cached_value(cls, step_name, card_id, value): | ||
cls._first_time_init[step_name][card_id] = value | ||
|
||
def _card_deco_already_attached(self, step, card_id): | ||
for decorator in step.decorators: | ||
if decorator.name == "card": | ||
if decorator.attributes["id"] and card_id in decorator.attributes["id"]: | ||
return True | ||
return False | ||
|
||
def _get_step(self, flow, step_name): | ||
for step in flow: | ||
if step.name == step_name: | ||
return step | ||
return None | ||
|
||
def _first_time_init_check(self, step_dag_node, card_id): | ||
""" """ | ||
return not self._card_deco_already_attached(step_dag_node, card_id) | ||
|
||
def attach_card_decorator(self, flow, step_name, card_id, card_type, options): | ||
""" | ||
This method is called `step_init` in your StepDecorator code since | ||
this class is used as a Mixin | ||
""" | ||
from metaflow import decorators as _decorators | ||
|
||
if not all([card_id, card_type]): | ||
raise MetaflowException( | ||
"`INJECTED_CARD_ID` and `INJECTED_CARD_TYPE` must be set in the `CardDecoratorInjector` Mixin" | ||
) | ||
|
||
return _card(_huggingface_dataset(**self.kwargs)(step_func)) | ||
step_dag_node = self._get_step(flow, step_name) | ||
if ( | ||
self._get_first_time_init_cached_value(step_name, card_id) is None | ||
): # First check class level setting. | ||
if self._first_time_init_check(step_dag_node, card_id): | ||
self._set_first_time_init_cached_value(step_name, card_id, True) | ||
_decorators._attach_decorators_to_step( | ||
step_dag_node, | ||
[ | ||
"card:type=%s,id=%s,options=%s" | ||
% (CARD_TYPE, card_id, json.dumps(options)) | ||
], | ||
) | ||
else: | ||
self._set_first_time_init_cached_value(step_name, card_id, False) | ||
|
||
|
||
class HuggingfaceDatasetDecorator(StepDecorator): | ||
class HuggingfaceDatasetDecorator(StepDecorator, CardDecoratorInjector): | ||
|
||
name = "_huggingface_dataset" | ||
defaults = {"id": None} | ||
name = CARD_TYPE | ||
defaults = {"id": None, "artifact_id": None, "vh": 550} | ||
|
||
def step_init( | ||
self, flow, graph, step_name, decorators, environment, flow_datastore, logger | ||
): | ||
|
||
if not self.attributes.get("id") and not self.attributes.get("artifact_id"): | ||
raise MetaflowException( | ||
"Dataset ID or Metaflow FlowSpec artifact_id is required for Huggingface Dataset card." | ||
) | ||
|
||
if ( | ||
self.attributes.get("id") is not None | ||
and self.attributes.get("artifact_id") is not None | ||
): | ||
raise MetaflowException( | ||
"Both Dataset ID and Metaflow FlowSpec artifact_id cannot be set at the same time." | ||
) | ||
|
||
if self.attributes.get("id"): | ||
_id = self.attributes.get("id").replace("/", "_").replace("-", "_") | ||
else: | ||
_id = self.attributes.get("artifact_id") | ||
|
||
self.attach_card_decorator( | ||
flow, | ||
step_name, | ||
_id, | ||
CARD_TYPE, | ||
options={ | ||
"id": self.attributes.get("id"), | ||
"artifact_id": self.attributes.get("artifact_id"), | ||
"vh": self.attributes.get("vh", 550), | ||
}, | ||
) |
4 changes: 2 additions & 2 deletions
4
metaflow_extensions/huggingface_dataset/plugins/mfextinit_huggingface_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
STEP_DECORATORS_DESC = [ | ||
("_huggingface_dataset", ".huggingface_dataset.deco.HuggingfaceDatasetDecorator"), | ||
] | ||
("huggingface_dataset", ".huggingface_dataset.deco.HuggingfaceDatasetDecorator"), | ||
] |
2 changes: 1 addition & 1 deletion
2
metaflow_extensions/huggingface_dataset/toplevel/mfextinit_huggingface_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
toplevel = "toplevel" | ||
toplevel = "toplevel" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
__mf_extensions__ = "huggingface_dataset" | ||
|
||
from ..plugins.huggingface_dataset.deco import huggingface_dataset_deco as huggingface_dataset | ||
import pkg_resources | ||
|
||
try: | ||
__version__ = pkg_resources.get_distribution("metaflow-card-hf-dataset").version | ||
except: | ||
# this happens on remote environments since the job package | ||
# does not have a version | ||
__version__ = None | ||
__version__ = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,19 @@ | ||
from setuptools import find_namespace_packages, setup | ||
|
||
|
||
def get_long_description() -> str: | ||
with open("README.md") as fh: | ||
return fh.read() | ||
|
||
|
||
setup( | ||
name="metaflow-card-hf-dataset", | ||
version="0.0.3", | ||
name="metaflow-card-hf-dataset", | ||
version="0.0.4", | ||
description="A metaflow card that renders HTML inputs.", | ||
long_description=get_long_description(), | ||
long_description_content_type="text/markdown", | ||
author="Outerbounds", | ||
author_email="[email protected]", | ||
license="Apache Software License 2.0", | ||
packages=find_namespace_packages(include=['metaflow_extensions.*']), | ||
) | ||
packages=find_namespace_packages(include=["metaflow_extensions.*"]), | ||
) |