diff --git a/docs/reference/decorators/index.rst b/docs/reference/decorators/index.rst
index dfd818465..c389a021f 100644
--- a/docs/reference/decorators/index.rst
+++ b/docs/reference/decorators/index.rst
@@ -29,5 +29,6 @@ For reference we list available decorators for Hamilton here. Note: use
    resolve
    save_to
+   schema
    subdag
    tag
    with_columns
diff --git a/docs/reference/decorators/schema.rst b/docs/reference/decorators/schema.rst
new file mode 100644
index 000000000..85713f853
--- /dev/null
+++ b/docs/reference/decorators/schema.rst
@@ -0,0 +1,28 @@
+=======================
+schema
+=======================
+
+`@schema` is a function modifier that allows you to specify a schema for a
+function's inputs/outputs -- currently outputs, via `@schema.output`. This can
+be used to validate data at runtime, display field-level schemas when
+visualizing the DAG, etc.
+
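+For example -- a minimal sketch (the function name is illustrative; the
+decorated function must return a registered dataframe type, e.g. a pandas
+DataFrame):
+
+.. code-block:: python
+
+    import pandas as pd
+
+    from hamilton.function_modifiers import schema
+
+    @schema.output(("amount", "float"), ("currency", "str"))
+    def orders() -> pd.DataFrame:
+        # "orders" is a hypothetical node, purely for illustration
+        return pd.DataFrame.from_records([{"amount": 9.99, "currency": "USD"}])
+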
+**Reference Documentation**
+
+.. autoclass:: hamilton.function_modifiers.schema
+   :members: output
diff --git a/examples/spark/pyspark_feature_catalog/example_usage.ipynb b/examples/spark/pyspark_feature_catalog/example_usage.ipynb
index 4c312c06f..f75b6082b 100644
--- a/examples/spark/pyspark_feature_catalog/example_usage.ipynb
+++ b/examples/spark/pyspark_feature_catalog/example_usage.ipynb
@@ -12,10 +12,10 @@
     "As user you only need to specify which features (groups of features) you want to compute.\n",
     "\n",
     "## Including column level lineage\n",
-    "By using the `@tag` decorator with the `spark_schema` key you can extend the lineage up to column level, e.g.:\n",
+    "By using the `@schema.output` decorator you can extend the lineage down to the column level, e.g.:\n",
     "\n",
-    "```\n",
-    "@tag(spark_schema=\"zone:str, level:int, avatarId:int\")\n",
+    "```python\n",
+    "@schema.output((\"zone\", \"str\"), (\"level\", \"int\"), (\"avatarId\", \"int\"))\n",
     "```\n",
     "\n",
     "## Download data\n",
@@ -33,14 +33,18 @@
     "end_time": "2023-05-22T22:27:16.952106Z",
     "start_time": "2023-05-22T22:27:15.116241Z"
    },
-   "collapsed": false
+   "collapsed": false,
+   "jupyter": {
+    "outputs_hidden": false
+   }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "WARNING:root:'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n"
+     "/Users/elijahbenizzy/.pyenv/versions/3.11.4/envs/hamilton-3-11-fresh/lib/python3.11/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
+     "  warnings.warn(\n"
     ]
    }
   ],
@@ -101,7 +105,7 @@
   "metadata": {},
   "outputs": [
    {
-    "name": "stderr",
+    "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:hamilton.telemetry:Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/dagworks-inc/hamilton#usage-analytics--data-privacy for details.\n"
@@ -113,245 +117,245 @@
[... large SVG output diff omitted: the regenerated DAG render (execution_count 3) now labels schema fields as name=type (e.g. total_count=int) instead of name: type, renames the legend entry "column" to "field", and updates the cell's text/plain graphviz.Digraph repr accordingly ...]
+ "\n", + "avatarId=int      \n", + "\n", + "avatarId=int      \n", + "\n", + "\n", + "\n", + "darkshore_flag=int\n", + "\n", + "darkshore_flag=int\n", + "\n", + "\n", + "\n", + "durotar_flag=int  \n", + "\n", + "durotar_flag=int  \n", "\n", "\n", - "\n", + "\n", "world_of_warcraft->with_flags.durotar_flag\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "world_of_warcraft->with_flags.darkshore_flag\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", + "\n", + "\n", + "world_of_warcraft->level_info\n", + "\n", + "\n", + "\n", + "\n", "\n", - "zone: str    \n", - "\n", - "zone: str    \n", + "zone=str    \n", + "\n", + "zone=str    \n", "\n", - "\n", + "\n", "\n", - "level: int   \n", - "\n", - "level: int   \n", + "level=int   \n", + "\n", + "level=int   \n", "\n", - "\n", + "\n", "\n", - "avatarId: int\n", - "\n", - "avatarId: int\n", + "avatarId=int\n", + "\n", + "avatarId=int\n", "\n", - "\n", + "\n", "\n", - "_level_info_inputs\n", - "\n", - "aggregation_level\n", - "str\n", + "_zone_counts_inputs\n", + "\n", + "aggregation_level\n", + "str\n", "\n", - "\n", - "\n", - "_level_info_inputs->level_info\n", - "\n", - "\n", + "\n", + "\n", + "_zone_counts_inputs->zone_counts\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "_level_info_inputs->zone_counts\n", - "\n", - "\n", + "_zone_counts_inputs->level_info\n", + "\n", + "\n", "\n", "\n", "\n", "input\n", - "\n", - "input\n", + "\n", + "input\n", "\n", "\n", "\n", "function\n", - "\n", - "function\n", + "\n", + "function\n", "\n", "\n", "\n", "cluster\n", - "\n", - "cluster\n", + "\n", + "cluster\n", "\n", - "\n", + "\n", "\n", - "column\n", - "\n", - "column\n", + "field\n", + "\n", + "field\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -376,77 +380,16 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Setting default log level to \"WARN\".\n", - "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "23/12/22 15:52:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n", - "23/12/22 15:52:32 WARN Utils: Service 'SparkUI' could not bind on port 4040. 
Attempting port 4041.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------+-----------+---------------+-------------+------------------+\n", - "|avatarId|total_count|darkshore_count|durotar_count| mean_level|\n", - "+--------+-----------+---------------+-------------+------------------+\n", - "| 29| 5597| 3| 101| 58.09862426299804|\n", - "| 1806| 477| 0| 4|20.041928721174003|\n", - "| 9233| 1611| 0| 10| 50.16139044072005|\n", - "| 2040| 748| 0| 16|18.259358288770052|\n", - "| 26| 2103| 0| 40|55.342368045649074|\n", - "| 10422| 217| 0| 47| 14.32258064516129|\n", - "| 9978| 53| 0| 9| 7.490566037735849|\n", - "| 5385| 370| 0| 5|15.062162162162162|\n", - "| 4823| 45| 0| 32| 8.8|\n", - "| 9458| 934| 0| 90|28.927194860813703|\n", - "| 474| 272| 2| 13|22.327205882352942|\n", - "| 2453| 3359| 4| 105| 45.28966954450729|\n", - "| 3764| 43| 0| 0| 7.511627906976744|\n", - "| 5409| 3| 0| 0| 1.0|\n", - "| 6721| 3| 0| 0|1.6666666666666667|\n", - "| 4590| 3| 0| 3|2.6666666666666665|\n", - "| 5556| 46| 0| 46| 5.434782608695652|\n", - "| 4894| 5| 0| 5| 5.4|\n", - "| 1950| 61| 0| 0| 8.459016393442623|\n", - "| 3506| 1| 0| 0| 2.0|\n", - "+--------+-----------+---------------+-------------+------------------+\n", - "only showing top 20 rows\n", - "\n" - ] - } - ], + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], "source": [ "feature_groups_of_interest = [\"level_info\", \"zone_counts\"]\n", "features = dr.execute(feature_groups_of_interest, inputs={\"aggregation_level\": aggregation_level})\n", "features.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -465,9 +408,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.11.4" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/examples/spark/pyspark_feature_catalog/features.py b/examples/spark/pyspark_feature_catalog/features.py index 06a19c2ad..d4b14d7c4 100644 --- a/examples/spark/pyspark_feature_catalog/features.py +++ b/examples/spark/pyspark_feature_catalog/features.py @@ -2,28 +2,28 @@ from pyspark.sql import functions as sf from with_columns import darkshore_flag, durotar_flag -from hamilton.function_modifiers import tag +from hamilton.function_modifiers import schema from hamilton.plugins.h_spark import with_columns -WORLD_OF_WARCRAFT__SCHEMA = "zone:str, level:int, avatarId:int" +WORLD_OF_WARCRAFT_SCHEMA = (("zone", "str"), ("level", "int"), ("avatarId", "int")) def spark_session() -> ps.SparkSession: return ps.SparkSession.builder.master("local[1]").getOrCreate() -@tag(spark_schema=WORLD_OF_WARCRAFT__SCHEMA) +@schema.output(*WORLD_OF_WARCRAFT_SCHEMA) def world_of_warcraft(spark_session: ps.SparkSession) -> ps.DataFrame: return spark_session.read.parquet("data/wow.parquet") @with_columns(darkshore_flag, durotar_flag, columns_to_pass=["zone"]) -@tag(spark_schema=WORLD_OF_WARCRAFT__SCHEMA + ", darkshore_flag:int, durotar_flag:int") +@schema.output(*WORLD_OF_WARCRAFT_SCHEMA, ("darkshore_flag", "int"), ("durotar_flag", "int")) def with_flags(world_of_warcraft: ps.DataFrame) -> ps.DataFrame: return world_of_warcraft -@tag(spark_schema="total_count:int, darkshore_count:int, durotar_count:int") +@schema.output(("total_count", "int"), ("darkshore_count", "int"), ("durotar_count", "int")) def zone_counts(with_flags: ps.DataFrame, 
aggregation_level: str) -> ps.DataFrame: return with_flags.groupby(aggregation_level).agg( sf.count("*").alias("total_count"), @@ -32,7 +32,7 @@ def zone_counts(with_flags: ps.DataFrame, aggregation_level: str) -> ps.DataFram ) -@tag(spark_schema="mean_level:float") +@schema.output(("mean_level", "float")) def level_info(world_of_warcraft: ps.DataFrame, aggregation_level: str) -> ps.DataFrame: return world_of_warcraft.groupby(aggregation_level).agg( sf.mean("level").alias("mean_level"), diff --git a/hamilton/driver.py b/hamilton/driver.py index 33969dbdd..f1dbe3151 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -663,6 +663,7 @@ def display_all_functions( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + show_schema: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Displays the graph of all functions loaded! @@ -681,6 +682,8 @@ def display_all_functions( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. + :param show_schema: If True, display the schema of the DAG if + the nodes have schema data provided :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. """ @@ -693,6 +696,7 @@ def display_all_functions( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + display_fields=show_schema, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -711,6 +715,7 @@ def _visualize_execution_helper( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + show_schema: bool = True, ): """Helper function to visualize execution, using a passed-in function graph. @@ -723,6 +728,7 @@ def _visualize_execution_helper( :param orient: `LR` stands for "left to right". Accepted values are TB, LR, BT, RL. :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. + :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :return: the graphviz object if you want to do more with it. """ # TODO should determine if the visualization logic should live here or in the graph.py module @@ -754,6 +760,7 @@ def _visualize_execution_helper( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + display_fields=show_schema, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -771,6 +778,7 @@ def visualize_execution( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + show_schema: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Visualizes Execution. @@ -800,6 +808,7 @@ def visualize_execution( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. + :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. 
""" @@ -817,6 +826,7 @@ def visualize_execution( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + show_schema=show_schema, ) @capture_function_usage @@ -859,6 +869,7 @@ def display_downstream_of( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + show_schema: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Creates a visualization of the DAG starting from the passed in function name(s). @@ -879,6 +890,7 @@ def display_downstream_of( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. + :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. """ @@ -903,6 +915,7 @@ def display_downstream_of( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + display_fields=show_schema, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -918,6 +931,7 @@ def display_upstream_of( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + show_schema: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Creates a visualization of the DAG going backwards from the passed in function name(s). @@ -938,6 +952,7 @@ def display_upstream_of( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. + :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :return: the graphviz object if you want to do more with it. If returned as the result in a Jupyter Notebook cell, it will render. """ @@ -957,6 +972,7 @@ def display_upstream_of( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + display_fields=show_schema, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -1029,6 +1045,7 @@ def visualize_path_between( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + show_schema: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Visualizes the path between two nodes. @@ -1051,6 +1068,7 @@ def visualize_path_between( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. + :param show_schema: If True, display the schema of the DAG if nodes have schema data provided :return: graphviz object. :raise ValueError: if the upstream or downstream node names are not found in the graph, or there is no path between them. @@ -1108,6 +1126,7 @@ def visualize_path_between( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + display_fields=show_schema, ) except ImportError as e: logger.warning(f"Unable to import {e}", exc_info=True) @@ -1367,6 +1386,7 @@ def visualize_materialization( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + show_schema: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Visualizes materialization. This helps give you a sense of how materialization will impact the DAG. 
@@ -1385,6 +1405,8 @@ def visualize_materialization( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. + :param show_schema: If True, show the schema of the materialized nodes + if nodes have schema metadata attached. :return: The graphviz graph, if you want to do something with it """ if additional_vars is None: @@ -1409,6 +1431,7 @@ def visualize_materialization( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + show_schema=show_schema, ) def validate_execution( diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py index 058e4bcef..2bfcc779b 100644 --- a/hamilton/function_modifiers/__init__.py +++ b/hamilton/function_modifiers/__init__.py @@ -70,6 +70,7 @@ # Metadata-specifying decorators tag = metadata.tag tag_outputs = metadata.tag_outputs +schema = metadata.schema # data quality + associated tags check_output = validation.check_output diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index 645128738..ec2ddb5d2 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -618,6 +618,16 @@ def __init__(self, target: TargetType): """ super().__init__(target=target) + def validate_node(self, node_: node.Node): + """Validates that a node is valid for this decorator. This is + not the same as validation on the function, as this is done + during node-resolution. + + :param node_: Node to validate + :raises InvalidDecoratorException: if the node is not valid for this decorator + """ + pass + def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: @@ -628,6 +638,7 @@ def transform_node( :param fn: Function we're decorating :return: The nodes produced by the transformation """ + self.validate_node(node_) return [self.decorate_node(node_)] @classmethod diff --git a/hamilton/function_modifiers/metadata.py b/hamilton/function_modifiers/metadata.py index 9e74dfd8c..33451b2e7 100644 --- a/hamilton/function_modifiers/metadata.py +++ b/hamilton/function_modifiers/metadata.py @@ -1,6 +1,6 @@ -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Tuple -from hamilton import node +from hamilton import htypes, node, registry from hamilton.function_modifiers import base """Decorators that attach metadata to nodes""" @@ -166,3 +166,95 @@ def decorate_node(self, node_: node.Node) -> node.Node: new_tags = node_.tags.copy() new_tags.update(self.tag_mapping.get(node_.name, {})) return tag(**new_tags).decorate_node(node_) + + +# These represent a generic schema type -- E.G. one that will +# be supported across the entire set of usable dataframe/dataset types +# Eventually we'll be integrating mappings of these into the registry, +# but for now this serves largely as a placeholder/documentation +# GENERIC_SCHEMA_TYPES = ( +# "int", +# "float", +# "str", +# "bool", +# "dict", +# "list", +# "object", +# "datetime", +# "date", +# ) + + +class SchemaOutput(tag): + def __init__(self, *fields: Tuple[str, str], target_: Optional[str] = None): + """Initializes SchemaOutput. 
See docs for `@schema.output` for more details."""
+
+        tag_value = ",".join([f"{key}={value}" for key, value in fields])
+        super(SchemaOutput, self).__init__(
+            **{schema.INTERNAL_SCHEMA_OUTPUT_KEY: tag_value}, target_=target_
+        )
+
+    def validate_node(self, node_: node.Node):
+        """Validates that the node has a return type of a registered dataframe.
+
+        :param node_: Node to validate
+        :raises InvalidDecoratorException: if the node does not have a return type of a registered dataframe.
+        """
+        output_type = node_.type
+        available_types = registry.get_registered_dataframe_types()
+        for _, type_ in available_types.items():
+            if htypes.custom_subclass_check(output_type, type_):
+                return
+        raise base.InvalidDecoratorException(
+            f"Node {node_.name} has type {output_type}, which is not a registered dataset type. "
+            f"Registered types are {available_types}. If you hit this, either (a) ensure you have the "
+            f"right package installed, or (b) reach out to the team to figure out how to add yours."
+        )
+
+    @classmethod
+    def allows_multiple(cls) -> bool:
+        """Currently this only applies to a single output. If it is a set of nodes with multiple outputs,
+        it will apply to the "final" (sink) one. We can change this if there's need."""
+        return False
+
+    def validate(self, fn: Callable):
+        """Bypassed for now -- we have no function-level or class-level validations yet.
+        Validation is currently handled by `@tag`, which this inherits from; we will be
+        moving away from inheriting `tag`."""
+        pass
+
+
+class schema:
+    """Namespace class for schema decorators. This exists purely to provide a nice API, e.g. `schema.output`."""
+
+    INTERNAL_SCHEMA_OUTPUT_KEY = "hamilton.internal.schema_output"
+
+    @staticmethod
+    def output(*fields: Tuple[str, str], target_: Optional[str] = None) -> SchemaOutput:
+        """Initializes a `@schema.output` decorator. This takes in a list of fields, which are tuples of the form
+        `(field_name, field_type)`. Field types are expressed as strings naming generic types
+        (e.g. "int", "float", "str" -- see the generic schema types note above).
+
+        :param fields: Fields to add to the schema. Each field is a tuple of the form `(field_name, field_type)`.
+        :param target_: Target node to decorate -- if `None` it will decorate all final nodes (i.e. sinks in the
+            subdag), otherwise it will decorate the specified node.
+
+        This is implemented using tags, but that might change. Thus you should not
+        rely on the tags created by this decorator (which is why they are prefixed with `internal`).
+
+        To use this, decorate a node with `@schema.output`.
+
+        Example usage:
+
+        .. code-block:: python
+
+            @schema.output(
+                ("a", "int"),
+                ("b", "float"),
+                ("c", "str")
+            )
+            def example_schema() -> pd.DataFrame:
+                return pd.DataFrame.from_records([{"a": 1, "b": 2.0, "c": "3"}])
+
+        Then, when drawing the DAG, the schema will be displayed as sub-elements of the function's node
+        (if `show_schema=True` is passed to the driver's visualization methods, which is the default). 
+ """ + return SchemaOutput(*fields, target_=target_) diff --git a/hamilton/graph.py b/hamilton/graph.py index b27e1aae0..d0b001094 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -15,6 +15,7 @@ from hamilton.execution import graph_functions from hamilton.execution.graph_functions import combine_config_and_inputs, execute_subdag from hamilton.function_modifiers import base as fm_base +from hamilton.function_modifiers.metadata import schema from hamilton.graph_utils import find_functions from hamilton.htypes import get_type_as_string, types_match from hamilton.lifecycle.base import LifecycleAdapterSet @@ -189,6 +190,7 @@ def create_graphviz_graph( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + display_fields: bool = True, ) -> "graphviz.Digraph": # noqa: F821 """Helper function to create a graphviz graph. @@ -311,7 +313,7 @@ def _get_function_modifier_style(modifier: str) -> Dict[str, str]: modifier_style = dict(style="filled,diagonals") elif modifier == "materializer": modifier_style = dict(shape="cylinder") - elif modifier == "column": + elif modifier == "field": modifier_style = dict(fillcolor="#c8dae0", fontname="Courier") elif modifier == "cluster": modifier_style = dict( @@ -359,7 +361,7 @@ def _get_legend(node_types: Set[str]): "input", "function", "cluster", - "column", + "field", "output", "materializer", "override", @@ -444,13 +446,14 @@ def _get_legend(node_types: Set[str]): digraph.node(n.name, label=label, **node_style) - if n.tags.get("spark_schema"): - # When a node is tagged with spark_schema -> add a cluster with a node for each column + # only do field-level visualization if there's a schema specified and we want to display it + if n.tags.get(schema.INTERNAL_SCHEMA_OUTPUT_KEY) and display_fields: + # When a node has attached schema data -> add a cluster with a node for each field seen_node_types.add("cluster") - seen_node_types.add("column") + seen_node_types.add("field") - def _create_equal_length_cols(spark_schema_tag: str) -> list[str]: - cols = spark_schema_tag.split(",") + def _create_equal_length_cols(schema_tag: str) -> List[str]: + cols = schema_tag.split(",") for i in range(len(cols)): def _insert_space_after_colon(col: str) -> str: @@ -472,8 +475,8 @@ def _insert_space_after_colon(col: str) -> str: "margin": "10", } ) - column_node_style = node_style.copy() - column_node_style.update( + field_node_style = node_style.copy() + field_node_style.update( { "fillcolor": "#c8dae0", "fontname": "Courier", @@ -482,9 +485,9 @@ def _insert_space_after_colon(col: str) -> str: ) with digraph.subgraph(name="cluster_" + n.name) as c: c.attr(**cluster_node_style) - cols = _create_equal_length_cols(n.tags.get("spark_schema")) + cols = _create_equal_length_cols(n.tags.get(schema.INTERNAL_SCHEMA_OUTPUT_KEY)) for i in range(len(cols)): - c.node(cols[i], **column_node_style) + c.node(cols[i], **field_node_style) c.node(n.name) # create edges @@ -653,6 +656,7 @@ def display_all( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + display_fields: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Displays & saves a dot file of the entire DAG structure constructed. @@ -668,6 +672,7 @@ def display_all( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. 
+ :param display_fields: If True, display fields in the graph if node has attached schema metadata :return: the graphviz graph object if it was created. None if not. """ all_nodes = set() @@ -690,6 +695,7 @@ def display_all( orient=orient, hide_inputs=hide_inputs, deduplicate_inputs=deduplicate_inputs, + display_fields=display_fields, ) def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool: @@ -733,6 +739,7 @@ def display( orient: str = "LR", hide_inputs: bool = False, deduplicate_inputs: bool = False, + display_fields: bool = True, ) -> Optional["graphviz.Digraph"]: # noqa F821 """Function to display the graph represented by the passed in nodes. @@ -752,6 +759,8 @@ def display( :param hide_inputs: If True, no input nodes are displayed. :param deduplicate_inputs: If True, remove duplicate input nodes. Can improve readability depending on the specifics of the DAG. + :param display_fields: If True, display fields in the graph if node has attached + schema metadata :return: the graphviz graph object if it was created. None if not. """ # Check to see if optional dependencies have been installed. @@ -777,6 +786,7 @@ def display( orient, hide_inputs, deduplicate_inputs, + display_fields=display_fields, ) kwargs = {"view": True} if render_kwargs and isinstance(render_kwargs, dict): diff --git a/tests/function_modifiers/test_metadata.py b/tests/function_modifiers/test_metadata.py index 17fb70125..225dfdbb6 100644 --- a/tests/function_modifiers/test_metadata.py +++ b/tests/function_modifiers/test_metadata.py @@ -2,6 +2,7 @@ import pytest from hamilton import function_modifiers, node +from hamilton.function_modifiers import base as fm_base def test_tags(): @@ -162,3 +163,32 @@ def data() -> pd.DataFrame: assert node_map["a"].tags["target"] == "column" assert node_map["b"].tags["target"] == "column" assert node_map["data"].tags.get("target") is None + + +def test_decorate_node_with_schema_output(): + # quick test to decorate node with schemas + # this tests an internal implementation, so we will likely change + # in the future, but we'll want to keep the same behavior for now + @function_modifiers.schema.output(("foo", "int"), ("bar", "float"), ("baz", "str")) + def foo() -> pd.DataFrame: + return pd.DataFrame.from_records([{"foo": 1, "bar": 2.0, "baz": "3"}]) + + nodes = function_modifiers.base.resolve_nodes(foo, {}) + node_map = {node_.name: node_ for node_ in nodes} + node_ = node_map["foo"] + assert ( + node_.tags[function_modifiers.schema.INTERNAL_SCHEMA_OUTPUT_KEY] + == "foo=int,bar=float,baz=str" + ) + + +def test_decorate_node_with_schema_output_invalid_type(): + # quick test to decorate node with schemas + # this tests an internal implementation, so we will likely change + # in the future, but we'll want to keep the same behavior for now + @function_modifiers.schema.output(("foo", "int"), ("bar", "float"), ("baz", "str")) + def foo() -> int: # int has no columns/fields + return 10 + + with pytest.raises(fm_base.InvalidDecoratorException): + function_modifiers.base.resolve_nodes(foo, {}) diff --git a/tests/test_graph.py b/tests/test_graph.py index def83ba51..f50c71889 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -10,6 +10,7 @@ import hamilton.htypes from hamilton import ad_hoc_utils, base, graph, node from hamilton.execution import graph_functions +from hamilton.function_modifiers import schema from hamilton.lifecycle import base as lifecycle_base from hamilton.node import NodeType @@ -774,7 +775,7 @@ def test_function_graph_display_orient(orient: 
str, tmp_path: pathlib.Path):
     assert f"rankdir={orient}" in dot


-@pytest.mark.parametrize("hide_inputs", [(True), (False)])
+@pytest.mark.parametrize("hide_inputs", [True, False])
 def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path):
     dot_file_path = tmp_path / "dag.dot"
     fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
@@ -807,6 +808,36 @@
     assert isinstance(digraph, graphviz.Digraph)


+@pytest.mark.parametrize("display_fields", [True, False])
+def test_function_graph_display_fields(display_fields: bool, tmp_path: pathlib.Path):
+    dot_file_path = tmp_path / "dag.dot"
+
+    @schema.output(("foo", "int"), ("bar", "float"), ("baz", "str"))
+    def df_with_schema() -> pd.DataFrame:
+        pass
+
+    mod = ad_hoc_utils.create_temporary_module(df_with_schema)
+    fg = graph.FunctionGraph.from_modules(mod, config={})
+
+    fg.display(
+        set(fg.get_nodes()),
+        output_file_path=str(dot_file_path),
+        render_kwargs={"view": False},
+        display_fields=display_fields,
+    )
+    dot_lines = dot_file_path.open("r").readlines()
+    if display_fields:
+        assert any("foo" in line for line in dot_lines)
+        assert any("bar" in line for line in dot_lines)
+        assert any("baz" in line for line in dot_lines)
+        assert any("cluster" in line for line in dot_lines)
+    else:
+        assert not any("foo" in line for line in dot_lines)
+        assert not any("bar" in line for line in dot_lines)
+        assert not any("baz" in line for line in dot_lines)
+        assert not any("cluster" in line for line in dot_lines)
+
+
 def test_create_graphviz_graph():
     """Tests that we create a graphviz graph"""
     fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
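
For completeness, a minimal end-to-end sketch of the API introduced here (the module and function names are illustrative, not part of this change; assumes pandas and graphviz are installed):

```python
import pandas as pd

from hamilton import ad_hoc_utils, driver
from hamilton.function_modifiers import schema


@schema.output(("foo", "int"), ("bar", "float"))
def df_with_schema() -> pd.DataFrame:
    # Hypothetical node -- the decorated function must return a registered
    # dataframe type (here, a pandas DataFrame).
    return pd.DataFrame.from_records([{"foo": 1, "bar": 2.0}])


mod = ad_hoc_utils.create_temporary_module(df_with_schema)
dr = driver.Builder().with_modules(mod).build()
# Fields render as a cluster of sub-nodes on the function's node;
# pass show_schema=False to hide them.
dr.display_all_functions("./dag.dot", show_schema=True)
```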