Schema management for dataset-types
As discussed, we're going with a list of tuples and are not validating the types.
That said, we define a set of generic types that we can assume are supported,
whereas any other type is allowed but will not be fully supported.

Note that we also made the following decisions:

1. We implement the types as a single tag internally. This will change,
   so it's not exposed to the user.
2. We only allow this to decorate registered dataframe types (see the usage
   sketch below).
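
For illustration, here's a minimal sketch of the resulting API (it mirrors the docstring example added to metadata.py below; pandas stands in for any registered dataframe type):

```python
import pandas as pd

from hamilton.function_modifiers import schema


# Fields are (name, type) tuples; types are not validated, though only the
# generic set (int, float, str, bool, ...) is guaranteed full support.
@schema.output(("a", "int"), ("b", "float"), ("c", "str"))
def example_schema() -> pd.DataFrame:
    return pd.DataFrame({"a": [1], "b": [2.0], "c": ["3"]})
```

Passing tuples rather than a parsed string keeps the API explicit and sidesteps type validation for now.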
elijahbenizzy committed Dec 27, 2023
1 parent c4babe4 commit e8f83d3
Showing 11 changed files with 413 additions and 259 deletions.
1 change: 1 addition & 0 deletions docs/reference/decorators/index.rst
@@ -29,5 +29,6 @@ For reference we list available decorators for Hamilton here. Note: use
     resolve
     save_to
     subdag
+    schema
     tag
     with_columns
12 changes: 12 additions & 0 deletions docs/reference/decorators/schema.rst
@@ -0,0 +1,12 @@
+=======================
+schema
+=======================
+
+`@schema` is a function modifier that allows you to specify a schema for the
+function's inputs/outputs. This can be used to validate data at runtime, visualize, etc.
+
+
+**Reference Documentation**
+
+.. autoclass:: hamilton.function_modifiers.schema
+   :members: output
421 changes: 182 additions & 239 deletions examples/spark/pyspark_feature_catalog/example_usage.ipynb

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions examples/spark/pyspark_feature_catalog/features.py
@@ -2,28 +2,28 @@
 from pyspark.sql import functions as sf
 from with_columns import darkshore_flag, durotar_flag
 
-from hamilton.function_modifiers import tag
+from hamilton.function_modifiers import schema
 from hamilton.plugins.h_spark import with_columns
 
-WORLD_OF_WARCRAFT__SCHEMA = "zone:str, level:int, avatarId:int"
+WORLD_OF_WARCRAFT_SCHEMA = (("zone", "str"), ("level", "int"), ("avatarId", "int"))
 
 
 def spark_session() -> ps.SparkSession:
     return ps.SparkSession.builder.master("local[1]").getOrCreate()
 
 
-@tag(spark_schema=WORLD_OF_WARCRAFT__SCHEMA)
+@schema.output(*WORLD_OF_WARCRAFT_SCHEMA)
 def world_of_warcraft(spark_session: ps.SparkSession) -> ps.DataFrame:
     return spark_session.read.parquet("data/wow.parquet")
 
 
 @with_columns(darkshore_flag, durotar_flag, columns_to_pass=["zone"])
-@tag(spark_schema=WORLD_OF_WARCRAFT__SCHEMA + ", darkshore_flag:int, durotar_flag:int")
+@schema.output(*WORLD_OF_WARCRAFT_SCHEMA, ("darkshore_flag", "int"), ("durotar_flag", "int"))
 def with_flags(world_of_warcraft: ps.DataFrame) -> ps.DataFrame:
     return world_of_warcraft
 
 
-@tag(spark_schema="total_count:int, darkshore_count:int, durotar_count:int")
+@schema.output(("total_count", "int"), ("darkshore_count", "int"), ("durotar_count", "int"))
 def zone_counts(with_flags: ps.DataFrame, aggregation_level: str) -> ps.DataFrame:
     return with_flags.groupby(aggregation_level).agg(
         sf.count("*").alias("total_count"),
@@ -32,7 +32,7 @@ def zone_counts(with_flags: ps.DataFrame, aggregation_level: str) -> ps.DataFrame:
     )
 
 
-@tag(spark_schema="mean_level:float")
+@schema.output(("mean_level", "float"))
 def level_info(world_of_warcraft: ps.DataFrame, aggregation_level: str) -> ps.DataFrame:
     return world_of_warcraft.groupby(aggregation_level).agg(
         sf.mean("level").alias("mean_level"),
23 changes: 23 additions & 0 deletions hamilton/driver.py
@@ -663,6 +663,7 @@ def display_all_functions(
         orient: str = "LR",
         hide_inputs: bool = False,
         deduplicate_inputs: bool = False,
+        show_schema: bool = True,
     ) -> Optional["graphviz.Digraph"]:  # noqa F821
         """Displays the graph of all functions loaded!
@@ -681,6 +682,8 @@
         :param hide_inputs: If True, no input nodes are displayed.
         :param deduplicate_inputs: If True, remove duplicate input nodes.
             Can improve readability depending on the specifics of the DAG.
+        :param show_schema: If True, display the schema of the DAG if
+            the nodes have schema data provided
         :return: the graphviz object if you want to do more with it.
             If returned as the result in a Jupyter Notebook cell, it will render.
         """
@@ -693,6 +696,7 @@
                 orient=orient,
                 hide_inputs=hide_inputs,
                 deduplicate_inputs=deduplicate_inputs,
+                display_fields=show_schema,
             )
         except ImportError as e:
             logger.warning(f"Unable to import {e}", exc_info=True)
@@ -711,6 +715,7 @@ def _visualize_execution_helper(
         orient: str = "LR",
         hide_inputs: bool = False,
         deduplicate_inputs: bool = False,
+        show_schema: bool = True,
     ):
         """Helper function to visualize execution, using a passed-in function graph.
@@ -723,6 +728,7 @@
         :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL.
         :param hide_inputs: If True, no input nodes are displayed.
         :param deduplicate_inputs: If True, remove duplicate input nodes.
+        :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
         :return: the graphviz object if you want to do more with it.
         """
         # TODO should determine if the visualization logic should live here or in the graph.py module
@@ -754,6 +760,7 @@ def _visualize_execution_helper(
                 orient=orient,
                 hide_inputs=hide_inputs,
                 deduplicate_inputs=deduplicate_inputs,
+                display_fields=show_schema,
             )
         except ImportError as e:
             logger.warning(f"Unable to import {e}", exc_info=True)
@@ -771,6 +778,7 @@ def visualize_execution(
         orient: str = "LR",
         hide_inputs: bool = False,
         deduplicate_inputs: bool = False,
+        show_schema: bool = True,
     ) -> Optional["graphviz.Digraph"]:  # noqa F821
         """Visualizes Execution.
@@ -800,6 +808,7 @@ def visualize_execution(
         :param hide_inputs: If True, no input nodes are displayed.
         :param deduplicate_inputs: If True, remove duplicate input nodes.
             Can improve readability depending on the specifics of the DAG.
+        :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
         :return: the graphviz object if you want to do more with it.
             If returned as the result in a Jupyter Notebook cell, it will render.
         """
@@ -817,6 +826,7 @@
             orient=orient,
             hide_inputs=hide_inputs,
             deduplicate_inputs=deduplicate_inputs,
+            show_schema=show_schema,
         )
 
     @capture_function_usage
@@ -859,6 +869,7 @@ def display_downstream_of(
         orient: str = "LR",
         hide_inputs: bool = False,
         deduplicate_inputs: bool = False,
+        show_schema: bool = True,
     ) -> Optional["graphviz.Digraph"]:  # noqa F821
         """Creates a visualization of the DAG starting from the passed in function name(s).
@@ -879,6 +890,7 @@
         :param hide_inputs: If True, no input nodes are displayed.
         :param deduplicate_inputs: If True, remove duplicate input nodes.
             Can improve readability depending on the specifics of the DAG.
+        :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
         :return: the graphviz object if you want to do more with it.
             If returned as the result in a Jupyter Notebook cell, it will render.
         """
@@ -903,6 +915,7 @@
                 orient=orient,
                 hide_inputs=hide_inputs,
                 deduplicate_inputs=deduplicate_inputs,
+                display_fields=show_schema,
             )
         except ImportError as e:
             logger.warning(f"Unable to import {e}", exc_info=True)
@@ -918,6 +931,7 @@ def display_upstream_of(
         orient: str = "LR",
         hide_inputs: bool = False,
         deduplicate_inputs: bool = False,
+        show_schema: bool = True,
     ) -> Optional["graphviz.Digraph"]:  # noqa F821
         """Creates a visualization of the DAG going backwards from the passed in function name(s).
@@ -938,6 +952,7 @@
         :param hide_inputs: If True, no input nodes are displayed.
         :param deduplicate_inputs: If True, remove duplicate input nodes.
             Can improve readability depending on the specifics of the DAG.
+        :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
         :return: the graphviz object if you want to do more with it.
             If returned as the result in a Jupyter Notebook cell, it will render.
         """
@@ -957,6 +972,7 @@
                 orient=orient,
                 hide_inputs=hide_inputs,
                 deduplicate_inputs=deduplicate_inputs,
+                display_fields=show_schema,
             )
         except ImportError as e:
             logger.warning(f"Unable to import {e}", exc_info=True)
@@ -1029,6 +1045,7 @@ def visualize_path_between(
         orient: str = "LR",
         hide_inputs: bool = False,
         deduplicate_inputs: bool = False,
+        show_schema: bool = True,
     ) -> Optional["graphviz.Digraph"]:  # noqa F821
         """Visualizes the path between two nodes.
@@ -1051,6 +1068,7 @@
         :param hide_inputs: If True, no input nodes are displayed.
         :param deduplicate_inputs: If True, remove duplicate input nodes.
             Can improve readability depending on the specifics of the DAG.
+        :param show_schema: If True, display the schema of the DAG if nodes have schema data provided
         :return: graphviz object.
         :raise ValueError: if the upstream or downstream node names are not found in the graph,
             or there is no path between them.
@@ -1108,6 +1126,7 @@ def visualize_path_between(
                 orient=orient,
                 hide_inputs=hide_inputs,
                 deduplicate_inputs=deduplicate_inputs,
+                display_fields=show_schema,
             )
         except ImportError as e:
             logger.warning(f"Unable to import {e}", exc_info=True)
@@ -1367,6 +1386,7 @@ def visualize_materialization(
         orient: str = "LR",
         hide_inputs: bool = False,
         deduplicate_inputs: bool = False,
+        show_schema: bool = True,
     ) -> Optional["graphviz.Digraph"]:  # noqa F821
         """Visualizes materialization. This helps give you a sense of how materialization
         will impact the DAG.
@@ -1385,6 +1405,8 @@
         :param hide_inputs: If True, no input nodes are displayed.
         :param deduplicate_inputs: If True, remove duplicate input nodes.
             Can improve readability depending on the specifics of the DAG.
+        :param show_schema: If True, show the schema of the materialized nodes
+            if nodes have schema metadata attached.
         :return: The graphviz graph, if you want to do something with it
         """
         if additional_vars is None:
@@ -1409,6 +1431,7 @@
             orient=orient,
             hide_inputs=hide_inputs,
             deduplicate_inputs=deduplicate_inputs,
+            show_schema=show_schema,
        )
 
     def validate_execution(
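
To show how the new flag surfaces to users, here's a hedged sketch (the `Builder` setup is illustrative; `show_schema` defaults to `True` and maps to the internal `display_fields` option on the graph renderer):

```python
from hamilton import driver

import features  # the Spark example module rewritten above

# Illustrative setup -- any module whose nodes carry @schema.output metadata works.
dr = driver.Builder().with_modules(features).build()

# Schema fields render as sub-elements of each node in the visualization.
dot = dr.display_all_functions(show_schema=True)

# Pass show_schema=False for a compact view without the fields:
# dr.display_all_functions(show_schema=False)
```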
1 change: 1 addition & 0 deletions hamilton/function_modifiers/__init__.py
@@ -70,6 +70,7 @@
 # Metadata-specifying decorators
 tag = metadata.tag
 tag_outputs = metadata.tag_outputs
+schema = metadata.schema
 
 # data quality + associated tags
 check_output = validation.check_output
11 changes: 11 additions & 0 deletions hamilton/function_modifiers/base.py
@@ -618,6 +618,16 @@ def __init__(self, target: TargetType):
         """
         super().__init__(target=target)
 
+    def validate_node(self, node_: node.Node):
+        """Validates that a node is valid for this decorator. This is
+        not the same as validation on the function, as this is done
+        during node-resolution.
+
+        :param node_: Node to validate
+        :raises InvalidDecoratorException: if the node is not valid for this decorator
+        """
+        pass
+
     def transform_node(
         self, node_: node.Node, config: Dict[str, Any], fn: Callable
     ) -> Collection[node.Node]:
@@ -628,6 +638,7 @@ def transform_node(
         :param fn: Function we're decorating
         :return: The nodes produced by the transformation
         """
+        self.validate_node(node_)
         return [self.decorate_node(node_)]
 
     @classmethod
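
For clarity on the new hook: `validate_node` runs during node resolution (via `transform_node`), just before `decorate_node`, so subclasses can reject nodes they cannot decorate. A minimal sketch of a hypothetical subclass -- not part of this commit, just an illustration of the hook:

```python
from hamilton import node
from hamilton.function_modifiers import base


class MyNodeDecorator(base.NodeDecorator):
    """Hypothetical decorator that only accepts nodes returning str."""

    def validate_node(self, node_: node.Node):
        # Called by transform_node() before decorate_node(); raise to reject.
        if node_.type is not str:
            raise base.InvalidDecoratorException(
                f"Node {node_.name} must return str, got {node_.type}."
            )

    def decorate_node(self, node_: node.Node) -> node.Node:
        return node_  # no-op, for illustration only
```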
96 changes: 94 additions & 2 deletions hamilton/function_modifiers/metadata.py
@@ -1,6 +1,6 @@
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Optional, Tuple
 
-from hamilton import node
+from hamilton import htypes, node, registry
 from hamilton.function_modifiers import base
 
 """Decorators that attach metadata to nodes"""
@@ -166,3 +166,95 @@ def decorate_node(self, node_: node.Node) -> node.Node:
         new_tags = node_.tags.copy()
         new_tags.update(self.tag_mapping.get(node_.name, {}))
         return tag(**new_tags).decorate_node(node_)
+
+
+# These represent a generic schema type -- E.G. one that will
+# be supported across the entire set of usable dataframe/dataset types.
+# Eventually we'll be integrating mappings of these into the registry,
+# but for now this serves largely as a placeholder/documentation.
+# GENERIC_SCHEMA_TYPES = (
+#     "int",
+#     "float",
+#     "str",
+#     "bool",
+#     "dict",
+#     "list",
+#     "object",
+#     "datetime",
+#     "date",
+# )
+
+
+class SchemaOutput(tag):
+    def __init__(self, *fields: Tuple[str, str], target_: Optional[str] = None):
+        """Initializes SchemaOutput. See docs for `@schema.output` for more details."""
+        tag_value = ",".join([f"{key}={value}" for key, value in fields])
+        super(SchemaOutput, self).__init__(
+            **{schema.INTERNAL_SCHEMA_OUTPUT_KEY: tag_value}, target_=target_
+        )
+
+    def validate_node(self, node_: node.Node):
+        """Validates that the node has a return type of a registered dataframe.
+
+        :param node_: Node to validate
+        :raises InvalidDecoratorException: if the node does not have a return type of a registered dataframe.
+        """
+        output_type = node_.type
+        available_types = registry.get_registered_dataframe_types()
+        for _, type_ in available_types.items():
+            if htypes.custom_subclass_check(output_type, type_):
+                return
+        raise base.InvalidDecoratorException(
+            f"Node {node_.name} has type {output_type} which is not a registered type for a dataset. "
+            f"Registered types are {available_types}. If you found this, either (a) ensure you have the "
+            f"right package installed, or (b) reach out to the team to figure out how to add yours."
+        )
+
+    @classmethod
+    def allows_multiple(cls) -> bool:
+        """Currently this only applies to a single output. If it is a set of nodes with multiple outputs,
+        it will apply to the "final" (sink) one. We can change this if there's need."""
+        return False
+
+    def validate(self, fn: Callable):
+        """Bypassed for now -- we have no function-level or class-level validations yet,
+        but this is done at `@tag`, which this inherits. We will be moving away from inheriting tag.
+        """
+        pass
+
+
+class schema:
+    """Container class for schema constructs. This is purely so we can have a nice API for it -- E.G. `schema.output`."""
+
+    INTERNAL_SCHEMA_OUTPUT_KEY = "hamilton.internal.schema_output"
+
+    @staticmethod
+    def output(*fields: Tuple[str, str], target_: Optional[str] = None) -> SchemaOutput:
+        """Initializes a `@schema.output` decorator. This takes in a list of fields, which are tuples of the form
+        `(field_name, field_type)`. The field type should be one of the generic schema types listed above;
+        types are not currently validated.
+
+        :param target_: Target node to decorate -- if `None` it'll decorate all final nodes (E.G. sinks in the subdag),
+            otherwise it will decorate the specified node.
+        :param fields: List of fields to add to the schema. Each field is a tuple of the form `(field_name, field_type)`.
+
+        This is implemented using tags, but that might change. Thus you should not
+        rely on the tags created by this decorator (which is why they are prefixed with `internal`).
+        To use this, you should decorate a node with `@schema.output`.
+
+        Example usage:
+
+        .. code-block:: python
+
+            @schema.output(
+                ("a", "int"),
+                ("b", "float"),
+                ("c", "str"),
+            )
+            def example_schema() -> pd.DataFrame:
+                return pd.DataFrame.from_records({"a": [1], "b": [2.0], "c": ["3"]})
+
+        Then, when drawing the DAG, the schema will be displayed as sub-elements of the node
+        (if `show_schema` is enabled).
+        """
+        return SchemaOutput(*fields, target_=target_)
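
Since the schema is stored as a single tag internally (decision 1 above), here's a quick sketch of the encoding as implemented in `SchemaOutput.__init__` -- an internal detail that may change, so don't rely on it:

```python
# `@schema.output(("a", "int"), ("b", "float"))` encodes its fields as one
# comma-separated tag value under the internal key:
fields = (("a", "int"), ("b", "float"))
tag_value = ",".join([f"{key}={value}" for key, value in fields])
assert tag_value == "a=int,b=float"

# i.e., decorated nodes carry:
#   node_.tags["hamilton.internal.schema_output"] == "a=int,b=float"
```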