Adds basic ability to configure data quality
This works both globally and locally. Currently it's only at the node
level, but it shouldn't be too hard to disable specific validators or to set
warnings on them.
elijahbenizzy committed Feb 28, 2023
1 parent e5f7839 commit 64fecf2
Showing 5 changed files with 228 additions and 26 deletions.
data_quality.md (3 additions, 3 deletions)
@@ -164,7 +164,7 @@ All configuration keys have two components, joined by a `.` The first component
The value will be a dictionary with two possible keys:

1. `importance` -- the importance level of the data quality check. Can be either "warn" or "fail"
-2. `enable` -- a boolean indicating whether the data quality check is enabled or not.
+2. `enabled` -- a boolean indicating whether the data quality check is enabled or not.

The node-specific setting takes precedence, and the `global` setting applies after that. The setting
in the code comes third (note that you cannot disable a check through code, aside from removing or
commenting out the decorator). A combined example showing this precedence follows the snippets below.
@@ -175,7 +175,7 @@ will take third place (although you are unable to disable through code aside from
# This will globally disable *all* data quality checks
config = {
    'data_quality.global': {
-        'enable': False
+        'enabled': False
    },
}
# This will set the importance of all decorated nodes to "warn"
@@ -188,7 +188,7 @@ config = {
# This will disable the data quality check for the node `foo`
config = {
    'data_quality.foo': {
-        'enable': False
+        'enabled': False
    },
}
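
Because node-level settings take precedence over `global` ones, the two can be combined. A hedged sketch (the node name `foo` is illustrative, not part of this diff):

# This disables all checks globally, but re-enables the check on `foo` as a
# warning -- the node-level entry takes precedence over the global one.
config = {
    'data_quality.global': {
        'enabled': False
    },
    'data_quality.foo': {
        'enabled': True,
        'importance': 'warn'
    },
}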

hamilton/function_modifiers/base.py (27 additions, 10 deletions)
@@ -118,7 +118,9 @@ def __call__(self, fn: Callable):
            setattr(fn, lifecycle_name, [self])
        return fn

-    def required_config(self) -> Optional[List[str]]:
+    def required_config(
+        self, fn: Callable, nodes: Optional[Collection[node.Node]]
+    ) -> Optional[List[str]]:
        """Declares the required configuration keys for this decorator.
        Note that these configuration keys will be filtered and passed to the `configuration`
        parameter of the functions that this decorator uses.
@@ -130,7 +132,9 @@ def required_config(self) -> Optional[List[str]]:
"""
return []

-    def optional_config(self) -> Dict[str, Any]:
+    def optional_config(
+        self, fn: Callable, nodes: Optional[Collection[node.Node]]
+    ) -> Dict[str, Any]:
        """Declares the optional configuration keys for this decorator.
        These are configuration keys that can be used by the decorator, but are not required.
        Along with these we have *defaults*, which we will use to pass to the config.
@@ -498,17 +502,24 @@ def resolve_config(
    return config_out


-def filter_config(config: Dict[str, Any], decorator: NodeTransformLifecycle) -> Dict[str, Any]:
+def filter_config(
+    config: Dict[str, Any],
+    decorator: NodeTransformLifecycle,
+    fn: Callable,
+    modifiying_nodes: Optional[Collection[node.Node]] = None,
+) -> Dict[str, Any]:
    """Filters the config to only include the keys in config_required
    TODO -- break this into two so we can make it easier to test.
    :param config: The config to filter
+    :param fn: The function we're calling on
+    :param modifiying_nodes: The nodes this decorator is modifying (optional)
    :param config_required: The keys to include
    :param decorator: The decorator that is utilizing the configuration
    :return: The filtered config
    """
-    config_required = decorator.required_config()
-    config_optional_with_defaults = decorator.optional_config()
+    config_required = decorator.required_config(fn, modifiying_nodes)
+    config_optional_with_defaults = decorator.optional_config(fn, modifiying_nodes)
    if config_required is None:
        # This is an out to allow for backwards compatibility for the config.resolve decorator
        # Note this is an internal API, but we made the config with the `resolve` parameter public
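
To see what this per-decorator filtering amounts to, here is a standalone sketch of the required/optional resolution (an illustration of the contract described above, not Hamilton's actual implementation; names are made up):

from typing import Any, Dict, List

def filter_config_sketch(
    config: Dict[str, Any],
    required: List[str],
    optional_with_defaults: Dict[str, Any],
) -> Dict[str, Any]:
    # Required keys must be present in the incoming config.
    out = {key: config[key] for key in required}
    # Optional keys fall back to their declared defaults when absent.
    for key, default in optional_with_defaults.items():
        out[key] = config.get(key, default)
    return out

# A decorator requiring "region" with an optional "threshold" would see:
print(filter_config_sketch({"region": "US"}, ["region"], {"threshold": 0.5}))
# -> {'region': 'US', 'threshold': 0.5}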
@@ -548,20 +559,26 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]
"""
node_resolvers = getattr(fn, NodeResolver.get_lifecycle_name(), [DefaultNodeResolver()])
for resolver in node_resolvers:
fn = resolver.resolve(fn, config=filter_config(config, resolver))
fn = resolver.resolve(fn, config=filter_config(config, resolver, fn))
if fn is None:
return []
(node_creator,) = getattr(fn, NodeCreator.get_lifecycle_name(), [DefaultNodeCreator()])
nodes = node_creator.generate_nodes(fn, filter_config(config, node_creator))
nodes = node_creator.generate_nodes(fn, filter_config(config, node_creator, fn))
if hasattr(fn, NodeExpander.get_lifecycle_name()):
(node_expander,) = getattr(fn, NodeExpander.get_lifecycle_name(), [DefaultNodeExpander()])
nodes = node_expander.transform_dag(nodes, filter_config(config, node_expander), fn)
nodes = node_expander.transform_dag(
nodes, filter_config(config, node_expander, fn, nodes), fn
)
node_transformers = getattr(fn, NodeTransformer.get_lifecycle_name(), [])
for dag_modifier in node_transformers:
nodes = dag_modifier.transform_dag(nodes, filter_config(config, dag_modifier), fn)
nodes = dag_modifier.transform_dag(
nodes, filter_config(config, dag_modifier, fn, nodes), fn
)
node_decorators = getattr(fn, NodeDecorator.get_lifecycle_name(), [DefaultNodeDecorator()])
for node_decorator in node_decorators:
nodes = node_decorator.transform_dag(nodes, filter_config(config, node_decorator), fn)
nodes = node_decorator.transform_dag(
nodes, filter_config(config, node_decorator, fn, nodes), fn
)
return nodes


hamilton/function_modifiers/configuration.py (3 additions, 2 deletions)
@@ -1,5 +1,6 @@
from typing import Any, Callable, Collection, Dict, List, Optional

+from .. import node
from . import base

"""Decorators that handle the configuration of a function. These can be viewed as
@@ -75,11 +76,11 @@ def __init__(
        self.target_name = target_name
        self._config_used = config_used

-    def required_config(self) -> Optional[List[str]]:
+    def required_config(self, fn: Callable, nodes: Collection[node.Node]) -> Optional[List[str]]:
        """Nothing is currently required"""
        return []  # All of these can default to None

-    def optional_config(self) -> Dict[str, Any]:
+    def optional_config(self, fn: Callable, nodes: Collection[node.Node]) -> Dict[str, Any]:
        """Everything is optional, with None as the default value"""
        return {key: None for key in self._config_used}
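
As a concrete illustration of where `_config_used` comes from (hedged; the function and key names below are made up): Hamilton's `@config.when(...)` decorator references config keys, and each referenced key becomes an optional key defaulting to None, so a driver config that simply omits the key still resolves cleanly.

import pandas as pd

from hamilton.function_modifiers import config

# Hypothetical: this function is only included in the DAG when config["region"] == "US".
# optional_config would report {"region": None}, so filter_config passes
# config.get("region") through to the resolver.
@config.when(region="US")
def my_metric(raw_data: pd.Series) -> pd.Series:
    return raw_data * 2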

hamilton/function_modifiers/validation.py (41 additions, 9 deletions)
@@ -1,6 +1,6 @@
import abc
import dataclasses
-from typing import Any, Callable, Collection, Dict, List, Type
+from typing import Any, Callable, Collection, Dict, List, Optional, Type

from hamilton import node
from hamilton.data_quality import base as dq_base
@@ -20,9 +20,21 @@ class ValidatorConfig:

    @staticmethod
    def from_validator(
-        validator: dq_base.DataValidator, config: Dict[str, Any]
+        validator: dq_base.DataValidator,
+        config: Dict[str, Any],
+        node_name: str,
    ) -> "ValidatorConfig":
-        return ValidatorConfig(should_run=True, importance=validator.importance)
+        global_key = "data_quality.global"
+        node_key = f"data_quality.{node_name}"
+        global_config = config.get(global_key, {})
+        node_config = config.get(node_key, {})
+        should_run = node_config.get("enabled", global_config.get("enabled", True))
+        importance = node_config.get(
+            "importance", global_config.get("importance", validator.importance.value)
+        )
+        return ValidatorConfig(
+            should_run=should_run, importance=dq_base.DataValidationLevel(importance)
+        )
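
To make the resolution order concrete, a worked example (the node name and config values are illustrative):

# For a validator declared in code with importance="fail" on a node named "foo":
config = {
    "data_quality.global": {"importance": "warn"},
    "data_quality.foo": {"enabled": True},
}
# enabled    -> True   (node-level value; defaults to True when neither level sets it)
# importance -> "warn" (no node-level value, so the global one applies; the
#                       code-level "fail" is only the last-resort default)
# i.e. ValidatorConfig(should_run=True, importance=DataValidationLevel.WARN)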


class BaseDataValidationDecorator(base.NodeTransformer):
@@ -38,22 +50,27 @@ def get_validators(self, node_to_validate: node.Node) -> List[dq_base.DataValidator]:
    def transform_node(
        self, node_: node.Node, config: Dict[str, Any], fn: Callable
    ) -> Collection[node.Node]:
+        validators = self.get_validators(node_)
+        validator_configs = [
+            ValidatorConfig.from_validator(validator, config, node_.name)
+            for validator in validators
+        ]
+        # If no validators are enabled, return the original node
+        if not any(validator_config.should_run for validator_config in validator_configs):
+            return [node_]
        raw_node = node.Node(
-            name=node_.name
-            + "_raw",  # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want
+            name=node_.name + "_raw",
+            # TODO -- make this unique -- this will break with multiple validation decorators, which we *don't* want
            typ=node_.type,
            doc_string=node_.documentation,
            callabl=node_.callable,
            node_source=node_.node_source,
            input_types=node_.input_types,
            tags=node_.tags,
        )
-        validators = self.get_validators(node_)
        validator_nodes = []
        validator_name_config_map = {}
-        for validator in validators:
-            validator_config = ValidatorConfig.from_validator(validator, config)
+        for validator, validator_config in zip(validators, validator_configs):
            if not validator_config.should_run:
                continue

@@ -121,6 +138,21 @@ def final_node_callable(
    def validate(self, fn: Callable):
        pass

+    def optional_config(
+        self, fn: Callable, nodes: Optional[Collection[node.Node]]
+    ) -> Dict[str, Any]:
+        """Returns the configuration for the decorator.
+        :param fn: Function to validate
+        :param nodes: Nodes to validate
+        :return: Configuration that gets passed to the decorator
+        """
+        out = {"data_quality.global": {}}
+        if nodes is not None:
+            for node_ in nodes:
+                out[f"data_quality.{node_.name}"] = {}
+        return out
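
For example (hypothetical node names), decorating a function whose DAG yields nodes `foo` and `bar` would declare these optional keys, each defaulting to an empty override:

# What optional_config would return for nodes "foo" and "bar":
expected_optional_config = {
    "data_quality.global": {},
    "data_quality.foo": {},
    "data_quality.bar": {},
}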


class check_output_custom(BaseDataValidationDecorator):
"""Class to use if you want to implement your own custom validators.
tests/function_modifiers/test_validation.py (154 additions, 2 deletions)
@@ -1,15 +1,23 @@
from typing import Any, Dict

import numpy as np
import pandas as pd
import pytest

-from hamilton import node
-from hamilton.data_quality.base import DataValidationError, ValidationResult
+from hamilton import ad_hoc_utils, driver, node
+from hamilton.data_quality.base import (
+    DataValidationError,
+    DataValidationLevel,
+    DataValidator,
+    ValidationResult,
+)
from hamilton.function_modifiers import (
    DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG,
    IS_DATA_VALIDATOR_TAG,
    check_output,
    check_output_custom,
)
+from hamilton.function_modifiers.validation import ValidatorConfig
from hamilton.node import DependencyType
from tests.resources.dq_dummy_examples import (
    DUMMY_VALIDATORS_FOR_TESTING,
@@ -154,3 +162,147 @@ def test_data_quality_constants_for_api_consistency():
    # simple tests to test data quality constants remain the same
    assert IS_DATA_VALIDATOR_TAG == "hamilton.data_quality.contains_dq_results"
    assert DATA_VALIDATOR_ORIGINAL_OUTPUT_TAG == "hamilton.data_quality.source_node"


+@pytest.mark.parametrize(
+    "validator,config,node_name,expected_result",
+    [
+        (
+            SampleDataValidator2(0, "warn"),
+            {},
+            "test",
+            ValidatorConfig(True, DataValidationLevel.WARN),
+        ),
+        (
+            SampleDataValidator2(0, "fail"),
+            {},
+            "test",
+            ValidatorConfig(True, DataValidationLevel.FAIL),
+        ),
+        (
+            SampleDataValidator2(0, "warn"),
+            {"data_quality.test": {"enabled": False}},
+            "test",
+            ValidatorConfig(False, DataValidationLevel.WARN),
+        ),
+        (
+            SampleDataValidator2(0, "warn"),
+            {"data_quality.test": {"enabled": True}},
+            "test",
+            ValidatorConfig(True, DataValidationLevel.WARN),
+        ),
+        (
+            SampleDataValidator2(0, "fail"),
+            {"data_quality.test": {"enabled": False, "importance": "warn"}},
+            "test",
+            ValidatorConfig(False, DataValidationLevel.WARN),
+        ),
+        (
+            SampleDataValidator2(0, "warn"),
+            {"data_quality.global": {"enabled": False}},
+            "test",
+            ValidatorConfig(False, DataValidationLevel.WARN),
+        ),
+        (
+            SampleDataValidator2(0, "warn"),
+            {"data_quality.global": {"enabled": False, "importance": "warn"}},
+            "test",
+            ValidatorConfig(False, DataValidationLevel.WARN),
+        ),
+    ],
+)
+def test_validator_config_derive(
+    validator: DataValidator,
+    config: Dict[str, Any],
+    node_name: str,
+    expected_result: ValidatorConfig,
+):
+    assert ValidatorConfig.from_validator(validator, config, node_name) == expected_result


+def test_validator_config_produces_no_validation_with_global():
+    decorator = check_output_custom(
+        SampleDataValidator2(dataset_length=1, importance="fail"),
+        SampleDataValidator3(dtype=np.int64, importance="fail"),
+    )

+    def fn(input: pd.Series) -> pd.Series:
+        return input

+    node_ = node.Node.from_fn(fn)
+    config = {"data_quality.global": {"enabled": False}}
+    subdag = decorator.transform_node(node_, config=config, fn=fn)
+    assert 1 == len(subdag)
+    node_, *_ = subdag
+    assert node_.name == "fn"
+    # Ensure nothing's been messed with
+    pd.testing.assert_series_equal(node_(pd.Series([1.0, 2.0, 3.0])), pd.Series([1.0, 2.0, 3.0]))


+def test_validator_config_produces_no_validation_with_node_override():
+    decorator = check_output_custom(
+        SampleDataValidator2(dataset_length=1, importance="fail"),
+        SampleDataValidator3(dtype=np.int64, importance="fail"),
+    )

+    def fn(input: pd.Series) -> pd.Series:
+        return input

+    node_ = node.Node.from_fn(fn)
+    config = {f"data_quality.{node_.name}": {"enabled": False}}
+    subdag = decorator.transform_node(node_, config=config, fn=fn)
+    assert 1 == len(subdag)
+    node_, *_ = subdag
+    assert node_.name == "fn"
+    # Ensure nothing's been messed with
+    pd.testing.assert_series_equal(node_(pd.Series([1.0, 2.0, 3.0])), pd.Series([1.0, 2.0, 3.0]))


+def test_validator_config_produces_no_validation_with_node_level_override():
+    decorator = check_output_custom(
+        SampleDataValidator2(dataset_length=1, importance="fail"),
+        SampleDataValidator3(dtype=np.int64, importance="fail"),
+    )

+    def fn(input: pd.Series) -> pd.Series:
+        return input

+    node_ = node.Node.from_fn(fn)
+    config = {f"data_quality.{node_.name}": {"importance": "warn"}}
+    subdag = decorator.transform_node(node_, config=config, fn=fn)
+    assert 4 == len(subdag)
+    nodes_by_name = {n.name: n for n in subdag}
+    # We set this to warn so this should not break
+    nodes_by_name["fn"].callable(
+        fn_raw=pd.Series([1.0, 2.0, 3.0]),
+        fn_dummy_data_validator_2=ValidationResult(False, "", {}),
+        fn_dummy_data_validator_3=ValidationResult(False, "", {}),
+    )


+def test_data_validator_end_to_end_fails():
+    @check_output_custom(
+        SampleDataValidator2(dataset_length=1, importance="fail"),
+    )
+    def data_quality_check_that_doesnt_pass() -> pd.Series:
+        return pd.Series([1, 2])

+    dr = driver.Driver(
+        {}, ad_hoc_utils.create_temporary_module(data_quality_check_that_doesnt_pass)
+    )
+    with pytest.raises(DataValidationError):
+        dr.execute(final_vars=["data_quality_check_that_doesnt_pass"], inputs={})


+def test_data_validator_end_to_end_succeed_when_node_disabled():
+    @check_output_custom(
+        SampleDataValidator2(dataset_length=1, importance="fail"),
+    )
+    def data_quality_check_that_doesnt_pass() -> pd.Series:
+        return pd.Series([1, 2])

+    dr = driver.Driver(
+        {f"data_quality.{data_quality_check_that_doesnt_pass.__name__}": {"enabled": False}},
+        ad_hoc_utils.create_temporary_module(data_quality_check_that_doesnt_pass),
+    )
+    dr.execute(final_vars=["data_quality_check_that_doesnt_pass"], inputs={})
