Skip to content

Commit

Permalink
Merge pull request #103 from CITCOM-project/base-causal-test-case
Browse files Browse the repository at this point in the history
This merge contains both the base_causal_test_case branch as well as the causal_test_case_refactor branch
  • Loading branch information
christopher-wild authored Jan 31, 2023
2 parents b2ce1eb + c8c1d3f commit 1098043
Show file tree
Hide file tree
Showing 18 changed files with 675 additions and 231 deletions.
37 changes: 22 additions & 15 deletions causal_testing/generation/abstract_causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from causal_testing.specification.variable import Variable
from causal_testing.testing.causal_test_case import CausalTestCase
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
from causal_testing.testing.base_test_case import BaseTestCase

from enum import Enum

Expand All @@ -18,7 +19,7 @@

class AbstractCausalTestCase:
"""
An abstract test case serves as a generator for concrete test cases. Instead of having concrete conctrol
An abstract test case serves as a generator for concrete test cases. Instead of having concrete control
and treatment values, we instead just specify the intervention and the treatment variables. This then
enables potentially infinite concrete test cases to be generated between different values of the treatment.
"""
Expand All @@ -33,10 +34,11 @@ def __init__(
estimate_type: str = "ate",
effect: str = "total",
):
assert treatment_variable in scenario.variables.values(), (
"Treatment variables must be a subset of variables."
+ f" Instead got:\ntreatment_variable={treatment_variable}\nvariables={scenario.variables}"
)
if treatment_variable not in scenario.variables.values():
raise ValueError(
"Treatment variables must be a subset of variables."
+ f" Instead got:\ntreatment_variables={treatment_variable}\nvariables={scenario.variables}"
)

assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"

Expand Down Expand Up @@ -119,16 +121,21 @@ def _generate_concrete_tests(
)
model = optimizer.model()

base_test_case = BaseTestCase(
treatment_variable=self.treatment_variable,
outcome_variable=list(self.expected_causal_effect.keys())[0],
effect=self.effect,
)

concrete_test = CausalTestCase(
control_input_configuration={v: v.cast(model[v.z3]) for v in [self.treatment_variable]},
treatment_input_configuration={
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in [self.treatment_variable]
},
base_test_case=base_test_case,
control_value=self.treatment_variable.cast(model[self.treatment_variable.z3]),
treatment_value=self.treatment_variable.cast(
model[self.scenario.treatment_variables[self.treatment_variable.name].z3]
),
expected_causal_effect=list(self.expected_causal_effect.values())[0],
outcome_variables=list(self.expected_causal_effect.keys()),
estimate_type=self.estimate_type,
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
effect=self.effect,
)

for v in self.scenario.inputs():
Expand All @@ -150,7 +157,7 @@ def _generate_concrete_tests(
# Treatment run
if rct:
treatment_run = control_run.copy()
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
treatment_run.update({concrete_test.treatment_variable.name: concrete_test.treatment_value})
treatment_run["bin"] = index
runs.append(treatment_run)

Expand Down Expand Up @@ -197,7 +204,7 @@ def generate_concrete_tests(
runs = pd.concat([runs, runs_])
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"

control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
control_configs = pd.DataFrame([{test.treatment_variable: test.control_value} for test in concrete_tests])
ks_stats = {
var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
for var in control_configs.columns
Expand All @@ -220,8 +227,8 @@ def generate_concrete_tests(
for var in effect_modifier_configs.columns
}
)
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
control_values = [test.control_value for test in concrete_tests]
treatment_values = [test.treatment_value for test in concrete_tests]

if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
set(zip(control_values, treatment_values))
Expand Down
16 changes: 8 additions & 8 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
logger.info("Executing test: %s", test["name"])
logger.info(abstract_test)
logger.info([(v.name, v.distribution) for v in [abstract_test.treatment_variable]])
logger.info([abstract_test.treatment_variable.name, abstract_test.treatment_variable.distribution])
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])
Expand Down Expand Up @@ -201,15 +201,15 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
"""
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
causal_test_engine.identification(causal_test_case)
treatment_vars = list(causal_test_case.treatment_input_configuration)
minimal_adjustment_set = causal_test_engine.minimal_adjustment_set - {v.name for v in treatment_vars}
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
treatment_var = causal_test_case.treatment_variable
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
estimation_model = estimator(
(list(treatment_vars)[0].name,),
[causal_test_case.treatment_input_configuration[v] for v in treatment_vars][0],
[causal_test_case.control_input_configuration[v] for v in treatment_vars][0],
(treatment_var.name,),
causal_test_case.treatment_value,
causal_test_case.control_value,
minimal_adjustment_set,
(list(causal_test_case.outcome_variables)[0].name,),
(causal_test_case.outcome_variable.name,),
causal_test_engine.scenario_execution_data_df,
effect_modifiers=causal_test_case.effect_modifier_configuration,
)
Expand Down
23 changes: 23 additions & 0 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,5 +468,28 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
return True
return any([self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)])

def identification(self, base_test_case):
    """Identify and return the minimum adjustment set.

    The kind of effect on the base test case selects which enumeration is used:
    'total' calls ``enumerate_minimal_adjustment_sets`` and 'direct' calls
    ``direct_effect_adjustment_sets``. Both are called with the treatment and
    outcome variable names wrapped in single-element lists.

    :param base_test_case: A base test case instance containing the outcome_variable and the
    treatment_variable required for identification.
    :return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
    estimate as opposed to a purely associational estimate.
    :raises ValueError: If the effect is neither 'total' nor 'direct', or if no adjustment set
    could be found for the treatment/outcome pair.
    """
    if base_test_case.effect == "total":
        find_adjustment_sets = self.enumerate_minimal_adjustment_sets
    elif base_test_case.effect == "direct":
        find_adjustment_sets = self.direct_effect_adjustment_sets
    else:
        raise ValueError("Causal effect should be 'total' or 'direct'")

    treatment_name = base_test_case.treatment_variable.name
    outcome_name = base_test_case.outcome_variable.name
    minimal_adjustment_sets = find_adjustment_sets([treatment_name], [outcome_name])

    # An empty adjustment set is a legitimate answer (nothing needs adjusting), so the
    # "nothing found" case is detected with a None sentinel rather than by truthiness.
    minimal_adjustment_set = min(minimal_adjustment_sets, key=len, default=None)
    if minimal_adjustment_set is None:
        raise ValueError(
            f"No minimal adjustment set found for treatment '{treatment_name}' and outcome "
            f"'{outcome_name}' (effect='{base_test_case.effect}')"
        )
    return minimal_adjustment_set

def __str__(self):
    """Return a two-line summary listing the graph's nodes and then its edges."""
    nodes = self.graph.nodes
    edges = self.graph.edges
    return f"Nodes: {nodes}\nEdges: {edges}"
14 changes: 14 additions & 0 deletions causal_testing/testing/base_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from causal_testing.specification.variable import Variable
from causal_testing.testing.effect import Effect


@dataclass(frozen=True)
class BaseTestCase:
    """
    A base causal test case represents the relationship of an edge on a causal DAG.

    It records which variable is the treatment, which is the outcome, and which kind of
    causal effect links them; it holds no concrete control or treatment values. The
    dataclass is declared with ``frozen=True``, so fields cannot be reassigned after
    construction.
    """

    # Variable at the source of the relationship (the one being treated).
    treatment_variable: Variable
    # Variable at the target of the relationship (the one being observed).
    outcome_variable: Variable
    # Kind of causal effect, stored as a string. Defaults to Effect.total.value
    # (presumably "total" -- confirm against the Effect enum).
    effect: str = Effect.total.value
80 changes: 41 additions & 39 deletions causal_testing/testing/causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,73 +3,75 @@

from causal_testing.specification.variable import Variable
from causal_testing.testing.causal_test_outcome import CausalTestOutcome
from causal_testing.testing.base_test_case import BaseTestCase

logger = logging.getLogger(__name__)


class CausalTestCase:
"""
A causal test case is a triple (X, Delta, Y), where X is an input configuration, Delta is an intervention, and
Y is the expected causal effect on a particular output. The goal of a causal test case is to test whether the
intervention Delta made to the input configuration X causes the model-under-test to produce the expected change
in Y.
A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are
specified.
The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the
model-under-test to produce the expected change. The CausalTestCase structure is designed for execution using the
CausalTestEngine, either by using the execute_test() function to execute a single test case or by packing
CausalTestCases into a CausalTestSuite and executing them as a batch using the execute_test_suite() function.
"""

def __init__(
self,
control_input_configuration: dict[Variable:Any],
base_test_case: BaseTestCase,
expected_causal_effect: CausalTestOutcome,
outcome_variables: dict[Variable],
treatment_input_configuration: dict[Variable:Any] = None,
control_value: Any,
treatment_value: Any = None,
estimate_type: str = "ate",
effect_modifier_configuration: dict[Variable:Any] = None,
effect: str = "total",
):
"""
When a CausalTestCase is initialised, it takes the intervention and applies it to the input configuration to
create two distinct input configurations: a control input configuration and a treatment input configuration.
The former is the input configuration before applying the intervention and the latter is the input configuration
after applying the intervention.
:param control_input_configuration: The input configuration representing the control values of the treatment
variables.
:param treatment_input_configuration: The input configuration representing the treatment values of the treatment
variables. That is, the input configuration *after* applying the intervention.
:param base_test_case: A BaseTestCase object consisting of a treatment variable, outcome variable and effect
:param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
:param control_value: The control value for the treatment variable (before intervention).
:param treatment_value: The treatment value for the treatment variable (after intervention).
:param estimate_type: A string which denotes the type of estimate to return
:param effect_modifier_configuration:
"""
self.control_input_configuration = control_input_configuration
self.base_test_case = base_test_case
self.control_value = control_value
self.expected_causal_effect = expected_causal_effect
self.outcome_variables = outcome_variables
self.treatment_input_configuration = treatment_input_configuration
self.outcome_variable = base_test_case.outcome_variable
self.treatment_variable = base_test_case.treatment_variable
self.treatment_value = treatment_value
self.estimate_type = estimate_type
self.effect = effect
self.effect = base_test_case.effect

if effect_modifier_configuration:
self.effect_modifier_configuration = effect_modifier_configuration
else:
self.effect_modifier_configuration = dict()
assert (
self.control_input_configuration.keys() == self.treatment_input_configuration.keys()
), "Control and treatment input configurations must have the same keys."

def get_treatment_variables(self):
"""Return a list of the treatment variables (as strings) for this causal test case."""
return [v.name for v in self.control_input_configuration]
def get_treatment_variable(self):
"""Return the treatment variable name (as string) for this causal test case"""
return self.treatment_variable.name

def get_outcome_variables(self):
"""Return a list of the outcome variables (as strings) for this causal test case."""
return [v.name for v in self.outcome_variables]
def get_outcome_variable(self):
"""Return the outcome variable name (as string) for this causal test case."""
return self.outcome_variable.name

def get_control_values(self):
"""Return a list of the control values for each treatment variable in this causal test case."""
return list(self.control_input_configuration.values())
def get_control_value(self):
"""Return a the control value of the treatment variable in this causal test case."""
return self.control_value

def get_treatment_values(self):
"""Return a list of the treatment values for each treatment variable in this causal test case."""
return list(self.treatment_input_configuration.values())
def get_treatment_value(self):
"""Return the treatment value of the treatment variable in this causal test case."""
return self.treatment_value

def __str__(self):
treatment_config = {k.name: v for k, v in self.treatment_input_configuration.items()}
control_config = {k.name: v for k, v in self.control_input_configuration.items()}
treatment_config = {self.treatment_variable.name: self.treatment_value}
control_config = {self.treatment_variable.name: self.control_value}
outcome_variable = {self.outcome_variable}
return (
f"Running {treatment_config} instead of {control_config} should cause the following "
f"changes to {self.outcome_variables}: {self.expected_causal_effect}."
f"changes to {outcome_variable}: {self.expected_causal_effect}."
)
Loading

0 comments on commit 1098043

Please sign in to comment.