diff --git a/data_quality.md b/data_quality.md index 4fe99a6ec..924d1e05b 100644 --- a/data_quality.md +++ b/data_quality.md @@ -164,7 +164,7 @@ All configuration keys have two components, joined by a `.` The first component The value will be a dictionary with two possible values: 1. `importance` -- the importance level of the data quality check. Can be either "warn" or "fail" -2. `enable` -- a boolean indicating whether the data quality check is enabled or not. +2. `enabled` -- a boolean indicating whether the data quality check is enabled or not. The specific node name will take precedence, and `global` will apply after that. The information in the code will take third place (although you are unable to disable through code aside from removing/commenting the decorator out). @@ -175,7 +175,7 @@ will take third place (although you are unable to disable through code aside fro # This will globally disable *all* data quality checks config = { 'data_quality.global': { - 'enable': False + 'enabled': False }, } # This will set the importance of all decorated nodes to "warn" @@ -188,7 +188,7 @@ config = { # This will disable the data quality check for the node `foo` config = { 'data_quality.foo': { - 'enable': False + 'enabled': False }, } diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index a701ddb60..47201ef3b 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -118,7 +118,9 @@ def __call__(self, fn: Callable): setattr(fn, lifecycle_name, [self]) return fn - def required_config(self) -> Optional[List[str]]: + def required_config( + self, fn: Callable, nodes: Optional[Collection[node.Node]] + ) -> Optional[List[str]]: """Declares the required configuration keys for this decorator. Note that these configuration keys will be filtered and passed to the `configuration` parameter of the functions that this decorator uses. @@ -130,7 +132,9 @@ def required_config(self) -> Optional[List[str]]: """ return [] - def optional_config(self) -> Dict[str, Any]: + def optional_config( + self, fn: Callable, nodes: Optional[Collection[node.Node]] + ) -> Dict[str, Any]: """Declares the optional configuration keys for this decorator. These are configuration keys that can be used by the decorator, but are not required. Along with these we have *defaults*, which we will use to pass to the config. @@ -498,17 +502,24 @@ def resolve_config( return config_out -def filter_config(config: Dict[str, Any], decorator: NodeTransformLifecycle) -> Dict[str, Any]: +def filter_config( + config: Dict[str, Any], + decorator: NodeTransformLifecycle, + fn: Callable, + modifiying_nodes: Optional[Collection[node.Node]] = None, +) -> Dict[str, Any]: """Filters the config to only include the keys in config_required TODO -- break this into two so we can make it easier to test. :param config: The config to filter + :param fn: The function we're calling on + :param modifiying_nodes: The nodes this decorator is modifying (optional) :param config_required: The keys to include :param decorator: The decorator that is utilizing the configuration :return: The filtered config """ - config_required = decorator.required_config() - config_optional_with_defaults = decorator.optional_config() + config_required = decorator.required_config(fn, modifiying_nodes) + config_optional_with_defaults = decorator.optional_config(fn, modifiying_nodes) if config_required is None: # This is an out to allow for backwards compatibility for the config.resolve decorator # Note this is an internal API, but we made the config with the `resolve` parameter public @@ -548,20 +559,26 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node] """ node_resolvers = getattr(fn, NodeResolver.get_lifecycle_name(), [DefaultNodeResolver()]) for resolver in node_resolvers: - fn = resolver.resolve(fn, config=filter_config(config, resolver)) + fn = resolver.resolve(fn, config=filter_config(config, resolver, fn)) if fn is None: return [] (node_creator,) = getattr(fn, NodeCreator.get_lifecycle_name(), [DefaultNodeCreator()]) - nodes = node_creator.generate_nodes(fn, filter_config(config, node_creator)) + nodes = node_creator.generate_nodes(fn, filter_config(config, node_creator, fn)) if hasattr(fn, NodeExpander.get_lifecycle_name()): (node_expander,) = getattr(fn, NodeExpander.get_lifecycle_name(), [DefaultNodeExpander()]) - nodes = node_expander.transform_dag(nodes, filter_config(config, node_expander), fn) + nodes = node_expander.transform_dag( + nodes, filter_config(config, node_expander, fn, nodes), fn + ) node_transformers = getattr(fn, NodeTransformer.get_lifecycle_name(), []) for dag_modifier in node_transformers: - nodes = dag_modifier.transform_dag(nodes, filter_config(config, dag_modifier), fn) + nodes = dag_modifier.transform_dag( + nodes, filter_config(config, dag_modifier, fn, nodes), fn + ) node_decorators = getattr(fn, NodeDecorator.get_lifecycle_name(), [DefaultNodeDecorator()]) for node_decorator in node_decorators: - nodes = node_decorator.transform_dag(nodes, filter_config(config, node_decorator), fn) + nodes = node_decorator.transform_dag( + nodes, filter_config(config, node_decorator, fn, nodes), fn + ) return nodes diff --git a/hamilton/function_modifiers/configuration.py b/hamilton/function_modifiers/configuration.py index 0235a5963..62446af81 100644 --- a/hamilton/function_modifiers/configuration.py +++ b/hamilton/function_modifiers/configuration.py @@ -1,5 +1,6 @@ from typing import Any, Callable, Collection, Dict, List, Optional +from .. import node from . import base """Decorators that handle the configuration of a function. These can be viewed as @@ -75,11 +76,11 @@ def __init__( self.target_name = target_name self._config_used = config_used - def required_config(self) -> Optional[List[str]]: + def required_config(self, fn: Callable, nodes: Collection[node.Node]) -> Optional[List[str]]: """Nothing is currently required""" return [] # All of these can default to None - def optional_config(self) -> Dict[str, Any]: + def optional_config(self, fn: Callable, nodes: Collection[node.Node]) -> Dict[str, Any]: """Everything is optional with None as the required value""" return {key: None for key in self._config_used} diff --git a/hamilton/function_modifiers/validation.py b/hamilton/function_modifiers/validation.py index cee02c0a7..e7042558c 100644 --- a/hamilton/function_modifiers/validation.py +++ b/hamilton/function_modifiers/validation.py @@ -1,6 +1,6 @@ import abc import dataclasses -from typing import Any, Callable, Collection, Dict, List, Type +from typing import Any, Callable, Collection, Dict, List, Optional, Type from hamilton import node from hamilton.data_quality import base as dq_base @@ -20,9 +20,21 @@ class ValidatorConfig: @staticmethod def from_validator( - validator: dq_base.DataValidator, config: Dict[str, Any] + validator: dq_base.DataValidator, + config: Dict[str, Any], + node_name: str, ) -> "ValidatorConfig": - return ValidatorConfig(should_run=True, importance=validator.importance) + global_key = "data_quality.global" + node_key = f"data_quality.{node_name}" + global_config = config.get(global_key, {}) + node_config = config.get(node_key, {}) + should_run = global_config.get("enabled", node_config.get("enabled", True)) + importance = node_config.get( + "importance", global_config.get("importance", validator.importance.value) + ) + return ValidatorConfig( + should_run=should_run, importance=dq_base.DataValidationLevel(importance) + ) class BaseDataValidationDecorator(base.NodeTransformer): @@ -38,10 +50,17 @@ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValida def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: - + validators = self.get_validators(node_) + validator_configs = [ + ValidatorConfig.from_validator(validator, config, node_.name) + for validator in validators + ] + # If no validators are enabled, return the original node + if not any(validator_config.should_run for validator_config in validator_configs): + return [node_] raw_node = node.Node( - name=node_.name - + "_raw", # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want + name=node_.name + "_raw", + # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want typ=node_.type, doc_string=node_.documentation, callabl=node_.callable, @@ -49,11 +68,9 @@ def transform_node( input_types=node_.input_types, tags=node_.tags, ) - validators = self.get_validators(node_) validator_nodes = [] validator_name_config_map = {} - for validator in validators: - validator_config = ValidatorConfig.from_validator(validator, config) + for validator, validator_config in zip(validators, validator_configs): if not validator_config.should_run: continue @@ -121,6 +138,21 @@ def final_node_callable( def validate(self, fn: Callable): pass + def optional_config( + self, fn: Callable, nodes: Optional[Collection[node.Node]] + ) -> Dict[str, Any]: + """Returns the configuration for the decorator. + + :param fn: Function to validate + :param nodes: Nodes to validate + :return: Configuration that gets passed to the decorator + """ + out = {"data_quality.global": {}} + if nodes is not None: + for node_ in nodes: + out[f"data_quality.{node_.name}"] = {} + return out + class check_output_custom(BaseDataValidationDecorator): """Class to use if you want to implement your own custom validators. diff --git a/tests/function_modifiers/test_validation.py b/tests/function_modifiers/test_validation.py index ffd6da143..d544414eb 100644 --- a/tests/function_modifiers/test_validation.py +++ b/tests/function_modifiers/test_validation.py @@ -1,15 +1,23 @@ +from typing import Any, Dict + import numpy as np import pandas as pd import pytest -from hamilton import node -from hamilton.data_quality.base import DataValidationError, ValidationResult +from hamilton import ad_hoc_utils, driver, node +from hamilton.data_quality.base import ( + DataValidationError, + DataValidationLevel, + DataValidator, + ValidationResult, +) from hamilton.function_modifiers import ( DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG, IS_DATA_VALIDATOR_TAG, check_output, check_output_custom, ) +from hamilton.function_modifiers.validation import ValidatorConfig from hamilton.node import DependencyType from tests.resources.dq_dummy_examples import ( DUMMY_VALIDATORS_FOR_TESTING, @@ -154,3 +162,147 @@ def test_data_quality_constants_for_api_consistency(): # simple tests to test data quality constants remain the same assert IS_DATA_VALIDATOR_TAG == "hamilton.data_quality.contains_dq_results" assert DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG == "hamilton.data_quality.source_node" + + +@pytest.mark.parametrize( + "validator,config,node_name,expected_result", + [ + ( + SampleDataValidator2(0, "warn"), + {}, + "test", + ValidatorConfig(True, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "fail"), + {}, + "test", + ValidatorConfig(True, DataValidationLevel.FAIL), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.test": {"enabled": False}}, + "test", + ValidatorConfig(False, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.test": {"enabled": True}}, + "test", + ValidatorConfig(True, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "fail"), + {"data_quality.test": {"enabled": False, "importance": "warn"}}, + "test", + ValidatorConfig(False, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.global": {"enabled": False}}, + "test", + ValidatorConfig(False, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.global": {"enabled": False, "importance": "warn"}}, + "test", + ValidatorConfig(False, DataValidationLevel.WARN), + ), + ], +) +def test_validator_config_derive( + validator: DataValidator, + config: Dict[str, Any], + node_name: str, + expected_result: ValidatorConfig, +): + assert ValidatorConfig.from_validator(validator, config, node_name) == expected_result + + +def test_validator_config_produces_no_validation_with_global(): + decorator = check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + SampleDataValidator3(dtype=np.int64, importance="fail"), + ) + + def fn(input: pd.Series) -> pd.Series: + return input + + node_ = node.Node.from_fn(fn) + config = {"data_quality.global": {"enabled": False}} + subdag = decorator.transform_node(node_, config=config, fn=fn) + assert 1 == len(subdag) + node_, *_ = subdag + assert node_.name == "fn" + # Ensure nothing's been messed with + pd.testing.assert_series_equal(node_(pd.Series([1.0, 2.0, 3.0])), pd.Series([1.0, 2.0, 3.0])) + + +def test_validator_config_produces_no_validation_with_node_override(): + decorator = check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + SampleDataValidator3(dtype=np.int64, importance="fail"), + ) + + def fn(input: pd.Series) -> pd.Series: + return input + + node_ = node.Node.from_fn(fn) + config = {f"data_quality.{node_.name}": {"enabled": False}} + subdag = decorator.transform_node(node_, config=config, fn=fn) + assert 1 == len(subdag) + node_, *_ = subdag + assert node_.name == "fn" + # Ensure nothing's been messed with + pd.testing.assert_series_equal(node_(pd.Series([1.0, 2.0, 3.0])), pd.Series([1.0, 2.0, 3.0])) + + +def test_validator_config_produces_no_validation_with_node_level_override(): + decorator = check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + SampleDataValidator3(dtype=np.int64, importance="fail"), + ) + + def fn(input: pd.Series) -> pd.Series: + return input + + node_ = node.Node.from_fn(fn) + config = {f"data_quality.{node_.name}": {"importance": "warn"}} + subdag = decorator.transform_node(node_, config=config, fn=fn) + assert 4 == len(subdag) + nodes_by_name = {n.name: n for n in subdag} + # We set this to warn so this should not break + nodes_by_name["fn"].callable( + fn_raw=pd.Series([1.0, 2.0, 3.0]), + fn_dummy_data_validator_2=ValidationResult(False, "", {}), + fn_dummy_data_validator_3=ValidationResult(False, "", {}), + ) + + +def test_data_validator_end_to_end_fails(): + @check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + ) + def data_quality_check_that_doesnt_pass() -> pd.Series: + return pd.Series([1, 2]) + + dr = driver.Driver( + {}, ad_hoc_utils.create_temporary_module(data_quality_check_that_doesnt_pass) + ) + with pytest.raises(DataValidationError): + dr.execute(final_vars=["data_quality_check_that_doesnt_pass"], inputs={}) + + +def test_data_validator_end_to_end_succeed_when_node_disabled(): + @check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + ) + def data_quality_check_that_doesnt_pass() -> pd.Series: + return pd.Series([1, 2]) + + dr = driver.Driver( + {f"data_quality.{data_quality_check_that_doesnt_pass.__name__}": {"enabled": False}}, + ad_hoc_utils.create_temporary_module(data_quality_check_that_doesnt_pass), + ) + dr.execute(final_vars=["data_quality_check_that_doesnt_pass"], inputs={})