diff --git a/data_quality.md b/data_quality.md index 38288605d..3d9ceb56b 100644 --- a/data_quality.md +++ b/data_quality.md @@ -137,7 +137,7 @@ it executes on every column that's extracted. ## Handling the results We utilize tags to index nodes that represent data quality. All data-quality related tags start with the -prefix `hamilton.data_quality`. Currently there are two: +prefix `hamilton.data_quality`. Currently, there are two: 1. `hamilton.data_quality.contains_dq_results` -- this is a boolean that tells whether a node outputs a data quality results. These are nodes that get injected when @@ -149,3 +149,58 @@ Note that these tags will not be present if the node is not related to data qual don't assume they're in every node. To query one can simply filter for all the nodes that contain these tags and access the results! + +## Configuring data quality + +While data quality decorators can be configured in code, we also allow you to configure them as part of the +`config` dictionary passed to the driver. This enables you to do the following, either on a per-node or global level: + +1. Override the importance level +2. Disable data quality + +All configuration keys have two components, joined by a `.`. The first component is the prefix `data_quality`, and the second is either +`node_name` or `global`. The `node_name` component is the name of the node, which indicates that the configuration applies only to that node, and the `global` component indicates that the configuration applies to all nodes. + +The value will be a dictionary with two possible keys: + +1. `importance` -- the importance level of the data quality check. Can be either "warn" or "fail". +2. `enabled` -- a boolean indicating whether the data quality check is enabled or not. + +The specific node name will take precedence, and `global` will apply after that. The information in the code +will take third place (although you are unable to disable through code aside from removing/commenting the decorator out). 
+ + Let's look at some examples: + +```python +# This will globally disable *all* data quality checks +config = { + 'data_quality.global': { + 'enabled': False + }, +} +# This will set the importance of all decorated nodes to "warn" +config = { + 'data_quality.global': { + 'importance': 'warn' + }, +} + +# This will disable the data quality check for the node `foo` +config = { + 'data_quality.foo': { + 'enabled': False + }, +} + +# This will set the importance of the node `foo` to "warn" +config = { + 'data_quality.foo': { + 'importance': 'warn' + }, +} +``` + +Note that the node name refers to the node being decorated. In *most* cases this will be equal to the name of the function, but not in all cases. +If you have `parameterize`, you'll want to use the name of the specific node (which will correspond most likely to the name of the `target` parameter). + +Also note the precedence order. **The node-specific configuration will take precedence over the global configuration.** diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index a701ddb60..47201ef3b 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -118,7 +118,9 @@ def __call__(self, fn: Callable): setattr(fn, lifecycle_name, [self]) return fn - def required_config(self) -> Optional[List[str]]: + def required_config( + self, fn: Callable, nodes: Optional[Collection[node.Node]] + ) -> Optional[List[str]]: """Declares the required configuration keys for this decorator. Note that these configuration keys will be filtered and passed to the `configuration` parameter of the functions that this decorator uses. @@ -130,7 +132,9 @@ def required_config(self) -> Optional[List[str]]: """ return [] - def optional_config(self) -> Dict[str, Any]: + def optional_config( + self, fn: Callable, nodes: Optional[Collection[node.Node]] + ) -> Dict[str, Any]: """Declares the optional configuration keys for this decorator. 
These are configuration keys that can be used by the decorator, but are not required. Along with these we have *defaults*, which we will use to pass to the config. @@ -498,17 +502,24 @@ def resolve_config( return config_out -def filter_config(config: Dict[str, Any], decorator: NodeTransformLifecycle) -> Dict[str, Any]: +def filter_config( + config: Dict[str, Any], + decorator: NodeTransformLifecycle, + fn: Callable, + modifiying_nodes: Optional[Collection[node.Node]] = None, +) -> Dict[str, Any]: """Filters the config to only include the keys in config_required TODO -- break this into two so we can make it easier to test. :param config: The config to filter + :param fn: The function we're calling on + :param modifiying_nodes: The nodes this decorator is modifying (optional) :param config_required: The keys to include :param decorator: The decorator that is utilizing the configuration :return: The filtered config """ - config_required = decorator.required_config() - config_optional_with_defaults = decorator.optional_config() + config_required = decorator.required_config(fn, modifiying_nodes) + config_optional_with_defaults = decorator.optional_config(fn, modifiying_nodes) if config_required is None: # This is an out to allow for backwards compatibility for the config.resolve decorator # Note this is an internal API, but we made the config with the `resolve` parameter public @@ -548,20 +559,26 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node] """ node_resolvers = getattr(fn, NodeResolver.get_lifecycle_name(), [DefaultNodeResolver()]) for resolver in node_resolvers: - fn = resolver.resolve(fn, config=filter_config(config, resolver)) + fn = resolver.resolve(fn, config=filter_config(config, resolver, fn)) if fn is None: return [] (node_creator,) = getattr(fn, NodeCreator.get_lifecycle_name(), [DefaultNodeCreator()]) - nodes = node_creator.generate_nodes(fn, filter_config(config, node_creator)) + nodes = node_creator.generate_nodes(fn, 
filter_config(config, node_creator, fn)) if hasattr(fn, NodeExpander.get_lifecycle_name()): (node_expander,) = getattr(fn, NodeExpander.get_lifecycle_name(), [DefaultNodeExpander()]) - nodes = node_expander.transform_dag(nodes, filter_config(config, node_expander), fn) + nodes = node_expander.transform_dag( + nodes, filter_config(config, node_expander, fn, nodes), fn + ) node_transformers = getattr(fn, NodeTransformer.get_lifecycle_name(), []) for dag_modifier in node_transformers: - nodes = dag_modifier.transform_dag(nodes, filter_config(config, dag_modifier), fn) + nodes = dag_modifier.transform_dag( + nodes, filter_config(config, dag_modifier, fn, nodes), fn + ) node_decorators = getattr(fn, NodeDecorator.get_lifecycle_name(), [DefaultNodeDecorator()]) for node_decorator in node_decorators: - nodes = node_decorator.transform_dag(nodes, filter_config(config, node_decorator), fn) + nodes = node_decorator.transform_dag( + nodes, filter_config(config, node_decorator, fn, nodes), fn + ) return nodes diff --git a/hamilton/function_modifiers/configuration.py b/hamilton/function_modifiers/configuration.py index 0235a5963..62446af81 100644 --- a/hamilton/function_modifiers/configuration.py +++ b/hamilton/function_modifiers/configuration.py @@ -1,5 +1,6 @@ from typing import Any, Callable, Collection, Dict, List, Optional +from .. import node from . import base """Decorators that handle the configuration of a function. 
These can be viewed as @@ -75,11 +76,11 @@ def __init__( self.target_name = target_name self._config_used = config_used - def required_config(self) -> Optional[List[str]]: + def required_config(self, fn: Callable, nodes: Collection[node.Node]) -> Optional[List[str]]: """Nothing is currently required""" return [] # All of these can default to None - def optional_config(self) -> Dict[str, Any]: + def optional_config(self, fn: Callable, nodes: Collection[node.Node]) -> Dict[str, Any]: """Everything is optional with None as the required value""" return {key: None for key in self._config_used} diff --git a/hamilton/function_modifiers/validation.py b/hamilton/function_modifiers/validation.py index e3dbb2fb6..cbb338dfe 100644 --- a/hamilton/function_modifiers/validation.py +++ b/hamilton/function_modifiers/validation.py @@ -1,5 +1,6 @@ import abc -from typing import Any, Callable, Collection, Dict, List, Type +import dataclasses +from typing import Any, Callable, Collection, Dict, List, Optional, Type from hamilton import node from hamilton.data_quality import base as dq_base @@ -12,6 +13,30 @@ DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG = "hamilton.data_quality.source_node" +@dataclasses.dataclass +class ValidatorConfig: + should_run: bool + importance: dq_base.DataValidationLevel + + @staticmethod + def from_validator( + validator: dq_base.DataValidator, + config: Dict[str, Any], + node_name: str, + ) -> "ValidatorConfig": + global_key = "data_quality.global" + node_key = f"data_quality.{node_name}" + global_config = config.get(global_key, {}) + node_config = config.get(node_key, {}) + should_run = node_config.get("enabled", global_config.get("enabled", True)) + importance = node_config.get( + "importance", global_config.get("importance", validator.importance.value) + ) + return ValidatorConfig( + should_run=should_run, importance=dq_base.DataValidationLevel(importance) + ) + + class BaseDataValidationDecorator(base.NodeTransformer): @abc.abstractmethod def get_validators(self, 
node_to_validate: node.Node) -> List[dq_base.DataValidator]: @@ -25,9 +50,17 @@ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValida def transform_node( self, node_: node.Node, config: Dict[str, Any], fn: Callable ) -> Collection[node.Node]: + validators = self.get_validators(node_) + validator_configs = [ + ValidatorConfig.from_validator(validator, config, node_.name) + for validator in validators + ] + # If no validators are enabled, return the original node + if not any(validator_config.should_run for validator_config in validator_configs): + return [node_] raw_node = node.Node( - name=node_.name - + "_raw", # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want + name=node_.name + "_raw", + # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want typ=node_.type, doc_string=node_.documentation, callabl=node_.callable, @@ -35,10 +68,11 @@ def transform_node( input_types=node_.input_types, tags=node_.tags, ) - validators = self.get_validators(node_) validator_nodes = [] - validator_name_map = {} - for validator in validators: + validator_name_config_map = {} + for validator, validator_config in zip(validators, validator_configs): + if not validator_config.should_run: + continue def validation_function(validator_to_call: dq_base.DataValidator = validator, **kwargs): result = list(kwargs.values())[0] # This should just have one kwarg @@ -60,11 +94,11 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, ** }, }, ) - validator_name_map[validator_node_name] = validator + validator_name_config_map[validator_node_name] = (validator, validator_config) validator_nodes.append(validator_node) def final_node_callable( - validator_nodes=validator_nodes, validator_name_map=validator_name_map, **kwargs + validator_nodes=validator_nodes, validator_name_map=validator_name_config_map, **kwargs ): """Callable for the final node. 
First calls the action on every node, then @@ -75,11 +109,11 @@ def final_node_callable( """ failures = [] for validator_node in validator_nodes: - validator: dq_base.DataValidator = validator_name_map[validator_node.name] + validator, config = validator_name_map[validator_node.name] validation_result: dq_base.ValidationResult = kwargs[validator_node.name] - if validator.importance == dq_base.DataValidationLevel.WARN: + if config.importance == dq_base.DataValidationLevel.WARN: dq_base.act_warn(node_.name, validation_result, validator) - else: + elif config.importance == dq_base.DataValidationLevel.FAIL: failures.append((validation_result, validator)) dq_base.act_fail_bulk(node_.name, failures) return kwargs[raw_node.name] @@ -104,6 +138,21 @@ def final_node_callable( def validate(self, fn: Callable): pass + def optional_config( + self, fn: Callable, nodes: Optional[Collection[node.Node]] + ) -> Dict[str, Any]: + """Returns the configuration for the decorator. + + :param fn: Function to validate + :param nodes: Nodes to validate + :return: Configuration that gets passed to the decorator + """ + out = {"data_quality.global": {}} + if nodes is not None: + for node_ in nodes: + out[f"data_quality.{node_.name}"] = {} + return out + class check_output_custom(BaseDataValidationDecorator): """Class to use if you want to implement your own custom validators. 
diff --git a/tests/function_modifiers/test_validation.py b/tests/function_modifiers/test_validation.py index ffd6da143..44010789e 100644 --- a/tests/function_modifiers/test_validation.py +++ b/tests/function_modifiers/test_validation.py @@ -1,15 +1,23 @@ +from typing import Any, Dict + import numpy as np import pandas as pd import pytest -from hamilton import node -from hamilton.data_quality.base import DataValidationError, ValidationResult +from hamilton import ad_hoc_utils, driver, node +from hamilton.data_quality.base import ( + DataValidationError, + DataValidationLevel, + DataValidator, + ValidationResult, +) from hamilton.function_modifiers import ( DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG, IS_DATA_VALIDATOR_TAG, check_output, check_output_custom, ) +from hamilton.function_modifiers.validation import ValidatorConfig from hamilton.node import DependencyType from tests.resources.dq_dummy_examples import ( DUMMY_VALIDATORS_FOR_TESTING, @@ -154,3 +162,171 @@ def test_data_quality_constants_for_api_consistency(): # simple tests to test data quality constants remain the same assert IS_DATA_VALIDATOR_TAG == "hamilton.data_quality.contains_dq_results" assert DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG == "hamilton.data_quality.source_node" + + +@pytest.mark.parametrize( + "validator,config,node_name,expected_result", + [ + ( + SampleDataValidator2(0, "warn"), + {}, + "test", + ValidatorConfig(True, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "fail"), + {}, + "test", + ValidatorConfig(True, DataValidationLevel.FAIL), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.test": {"enabled": False}}, + "test", + ValidatorConfig(False, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.test": {"enabled": True}}, + "test", + ValidatorConfig(True, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "fail"), + {"data_quality.test": {"enabled": False, "importance": "warn"}}, + "test", + ValidatorConfig(False, 
DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.global": {"enabled": False}}, + "test", + ValidatorConfig(False, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.global": {"enabled": False, "importance": "warn"}}, + "test", + ValidatorConfig(False, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + { + "data_quality.global": {"enabled": False, "importance": "warn"}, + "data_quality.test": {"enabled": True}, + }, + "test", + ValidatorConfig(True, DataValidationLevel.WARN), + ), + ( + SampleDataValidator2(0, "warn"), + {"data_quality.global": {"enabled": True}, "data_quality.test": {"importance": "fail"}}, + "test", + ValidatorConfig(True, DataValidationLevel.FAIL), + ), + ( + SampleDataValidator2(0, "warn"), + { + "data_quality.global": {"enabled": False}, + "data_quality.test": {"enabled": True, "importance": "fail"}, + }, + "test", + ValidatorConfig(True, DataValidationLevel.FAIL), + ), + ], +) +def test_validator_config_derive( + validator: DataValidator, + config: Dict[str, Any], + node_name: str, + expected_result: ValidatorConfig, +): + assert ValidatorConfig.from_validator(validator, config, node_name) == expected_result + + +def test_validator_config_produces_no_validation_with_global(): + decorator = check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + SampleDataValidator3(dtype=np.int64, importance="fail"), + ) + + def fn(input: pd.Series) -> pd.Series: + return input + + node_ = node.Node.from_fn(fn) + config = {"data_quality.global": {"enabled": False}} + subdag = decorator.transform_node(node_, config=config, fn=fn) + assert 1 == len(subdag) + node_, *_ = subdag + assert node_.name == "fn" + # Ensure nothing's been messed with + pd.testing.assert_series_equal(node_(pd.Series([1.0, 2.0, 3.0])), pd.Series([1.0, 2.0, 3.0])) + + +def test_validator_config_produces_no_validation_with_node_override(): + decorator = 
check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + SampleDataValidator3(dtype=np.int64, importance="fail"), + ) + + def fn(input: pd.Series) -> pd.Series: + return input + + node_ = node.Node.from_fn(fn) + config = {f"data_quality.{node_.name}": {"enabled": False}} + subdag = decorator.transform_node(node_, config=config, fn=fn) + assert 1 == len(subdag) + node_, *_ = subdag + assert node_.name == "fn" + # Ensure nothing's been messed with + pd.testing.assert_series_equal(node_(pd.Series([1.0, 2.0, 3.0])), pd.Series([1.0, 2.0, 3.0])) + + +def test_validator_config_produces_no_validation_with_node_level_override(): + decorator = check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + SampleDataValidator3(dtype=np.int64, importance="fail"), + ) + + def fn(input: pd.Series) -> pd.Series: + return input + + node_ = node.Node.from_fn(fn) + config = {f"data_quality.{node_.name}": {"importance": "warn"}} + subdag = decorator.transform_node(node_, config=config, fn=fn) + assert 4 == len(subdag) + nodes_by_name = {n.name: n for n in subdag} + # We set this to warn so this should not break + nodes_by_name["fn"].callable( + fn_raw=pd.Series([1.0, 2.0, 3.0]), + fn_dummy_data_validator_2=ValidationResult(False, "", {}), + fn_dummy_data_validator_3=ValidationResult(False, "", {}), + ) + + +def test_data_validator_end_to_end_fails(): + @check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + ) + def data_quality_check_that_doesnt_pass() -> pd.Series: + return pd.Series([1, 2]) + + dr = driver.Driver( + {}, ad_hoc_utils.create_temporary_module(data_quality_check_that_doesnt_pass) + ) + with pytest.raises(DataValidationError): + dr.execute(final_vars=["data_quality_check_that_doesnt_pass"], inputs={}) + + +def test_data_validator_end_to_end_succeed_when_node_disabled(): + @check_output_custom( + SampleDataValidator2(dataset_length=1, importance="fail"), + ) + def 
data_quality_check_that_doesnt_pass() -> pd.Series: + return pd.Series([1, 2]) + + dr = driver.Driver( + {f"data_quality.{data_quality_check_that_doesnt_pass.__name__}": {"enabled": False}}, + ad_hoc_utils.create_temporary_module(data_quality_check_that_doesnt_pass), + ) + dr.execute(final_vars=["data_quality_check_that_doesnt_pass"], inputs={})