Skip to content

Commit 4034519

Browse files
authored
Merge pull request #250 from CITCOM-project/surrogateassisted
Surrogate Assisted Causal Testing
2 parents b82241f + 48fd185 commit 4034519

File tree

9 files changed

+637
-4
lines changed

9 files changed

+637
-4
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def list_all_min_sep(
6666
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
6767
if treatment_node_set_neighbours.difference(outcome_node_set):
6868
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
69-
node = set(sample(treatment_node_set_neighbours.difference(outcome_node_set), 1))
69+
node = set(sample(sorted(treatment_node_set_neighbours.difference(outcome_node_set)), 1))
7070

7171
# 7.2. Add this node to the treatment node set and recurse (left branch)
7272
yield from list_all_min_sep(
@@ -125,7 +125,6 @@ def close_separator(
125125

126126

127127
class CausalDAG(nx.DiGraph):
128-
129128
"""A causal DAG is a directed acyclic graph in which nodes represent random variables and edges represent causality
130129
between a pair of random variables. We implement a CausalDAG as a networkx DiGraph with an additional check that
131130
ensures it is acyclic. A CausalDAG must be specified as a dot file.
@@ -500,11 +499,20 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
500499
return True
501500
return any((self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)))
502501

503-
def identification(self, base_test_case: BaseTestCase):
502+
@staticmethod
503+
def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[set[str]], scenario: Scenario) -> list[set[str]]:
    """Discard every adjustment set that contains at least one hidden variable.

    Note that whole sets are dropped; individual hidden variables are not stripped out of a set. A set is kept
    only when none of its variables is flagged ``hidden`` in ``scenario.variables``.

    :param minimal_adjustment_sets: list of minimal adjustment set(s), each a collection of variable names
    :param scenario: The modelling scenario which informs the variables that are hidden
    :return: The adjustment sets, in their original order, that contain no hidden variables
    """
    return [
        adjustment_set
        for adjustment_set in minimal_adjustment_sets
        if not any(scenario.variables.get(name).hidden for name in adjustment_set)
    ]
509+
510+
def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None):
504511
"""Identify and return the minimum adjustment set
505512
506513
:param base_test_case: A base test case instance containing the outcome_variable and the
507514
treatment_variable required for identification.
515+
:param scenario: The modelling scenario relating to the tests
508516
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
509517
estimate as opposed to a purely associational estimate.
510518
"""
@@ -520,6 +528,12 @@ def identification(self, base_test_case: BaseTestCase):
520528
else:
521529
raise ValueError("Causal effect should be 'total' or 'direct'")
522530

531+
if scenario is not None:
532+
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
533+
534+
if len(minimal_adjustment_sets) == 0:
535+
return set()
536+
523537
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
524538
return minimal_adjustment_set
525539

causal_testing/surrogate/__init__.py

Whitespace-only changes.
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Module containing classes to define and run causal surrogate assisted test cases"""
2+
3+
from abc import ABC, abstractmethod
4+
from dataclasses import dataclass
5+
from typing import Callable
6+
7+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
8+
from causal_testing.specification.causal_specification import CausalSpecification
9+
from causal_testing.testing.base_test_case import BaseTestCase
10+
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
11+
12+
13+
@dataclass
class SimulationResult:
    """Data class holding the data and result metadata of a simulation"""

    # Data produced by the simulation run; the test case merges this into the collected data.
    data: dict
    # True when the simulation run exposed a fault.
    fault: bool
    # Description of the contradicted causal relationship. Optional: CausalSurrogateAssistedTestCase.execute
    # fills it in when a fault is found, so simulators do not need to supply a value.
    relationship: str = None
20+
21+
22+
class SearchAlgorithm(ABC):  # pylint: disable=too-few-public-methods
    """Class to be inherited with the search algorithm consisting of a search function and the fitness function of the
    space to be searched"""

    @abstractmethod
    def search(
        self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
    ) -> tuple:
        """Function which implements a search routine which searches for the optimal fitness value for the specified
        scenario

        :param surrogate_models: The surrogate models to be searched
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :return: A ``(candidate_test_case, fitness, surrogate)`` triple. The caller unpacks exactly three values:
            the candidate configuration passed to the simulator, its fitness, and the surrogate model it was
            found on."""
34+
35+
36+
class Simulator(ABC):
    """Abstract interface to a simulator: subclasses supply the simulator-specific logic to start it, shut it down
    and run it with a given configuration"""

    @abstractmethod
    def startup(self, **kwargs):
        """Initialise and open the Simulator"""

    @abstractmethod
    def shutdown(self, **kwargs):
        """Safely exit and shut down the Simulator"""

    @abstractmethod
    def run_with_config(self, configuration: dict) -> SimulationResult:
        """Execute the simulator using the supplied configuration and report the outcome as a SimulationResult

        :param configuration: The configuration required to initialise the Simulation
        :return: Simulation results in the structure of the SimulationResult data class"""
54+
55+
56+
class CausalSurrogateAssistedTestCase:
    """A class representing a single causal surrogate assisted test case."""

    def __init__(
        self,
        specification: CausalSpecification,
        search_algorithm: SearchAlgorithm,
        simulator: Simulator,
    ):
        self.specification = specification
        self.search_algorithm = search_algorithm
        self.simulator = simulator

    def execute(
        self,
        data_collector: ObservationalDataCollector,
        max_executions: int = 200,
        custom_data_aggregator: Callable[[dict, dict], dict] = None,
    ):
        """For this specific test case, a search algorithm is used to find the most contradictory point in the input
        space which is, therefore, most likely to indicate incorrect behaviour. This candidate test case is run against
        the simulator, checked for faults and the result returned with collected data

        :param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
        :param max_executions: Maximum number of simulator executions before exiting the search
        :param custom_data_aggregator: Optional function called with the collected data and the data of the latest
            simulation run; its return value replaces the collected data. When omitted, the simulation data is
            appended to the collected data as a new row.
        :return: tuple containing SimulationResult or str, execution number and collected data"""
        data_collector.collect_data()

        # Counted explicitly so that the value is defined even when max_executions is 0.
        executions = 0
        for executions in range(1, max_executions + 1):
            surrogate_models = self.generate_surrogates(self.specification, data_collector)
            candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification)

            self.simulator.startup()
            try:
                test_result = self.simulator.run_with_config(candidate_test_case)
            finally:
                # Always release the simulator, even if the run itself raised.
                self.simulator.shutdown()

            if custom_data_aggregator is not None:
                if data_collector.data is not None:
                    data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
            else:
                data_collector.data = self._append_row(data_collector.data, test_result.data)

            if test_result.fault:
                print(
                    f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
                    f"expected {surrogate.expected_relationship}."
                )
                test_result.relationship = (
                    f"{surrogate.treatment} -> {surrogate.outcome} expected {surrogate.expected_relationship}"
                )
                return test_result, executions, data_collector.data

        print("No fault found")
        return "No fault found", executions, data_collector.data

    @staticmethod
    def _append_row(data, row: dict):
        """Return the collected data with one simulation result appended as a new row.

        ``DataFrame.append`` was removed in pandas 2.0, so ``pandas.concat`` is used instead.

        :param data: The data collected so far (assumed to be a pandas DataFrame)
        :param row: The data of a single simulation run
        :return: A new DataFrame with a fresh index containing the extra row
        """
        import pandas as pd  # pylint: disable=import-outside-toplevel

        return pd.concat([data, pd.DataFrame([row])], ignore_index=True)

    def generate_surrogates(
        self, specification: CausalSpecification, data_collector: ObservationalDataCollector
    ) -> list[CubicSplineRegressionEstimator]:
        """Generate a surrogate model for each edge of the DAG whose metadata contains the "included" key.

        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
        :return: A list of surrogate models
        """
        surrogate_models = []

        for u, v in specification.causal_dag.graph.edges:
            edge_metadata = specification.causal_dag.graph.adj[u][v]
            if "included" in edge_metadata:
                from_var = specification.scenario.variables.get(u)
                to_var = specification.scenario.variables.get(v)
                base_test_case = BaseTestCase(from_var, to_var)

                minimal_adjustment_set = specification.causal_dag.identification(base_test_case, specification.scenario)

                surrogate = CubicSplineRegressionEstimator(
                    treatment=u,
                    # Placeholder treatment/control values; a search algorithm is expected to overwrite them.
                    treatment_value=0,
                    control_value=0,
                    adjustment_set=minimal_adjustment_set,
                    outcome=v,
                    basis=4,
                    df=data_collector.data,
                    # An included edge must also declare its "expected" relationship.
                    expected_relationship=edge_metadata["expected"],
                )
                surrogate_models.append(surrogate)

        return surrogate_models
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""Module containing implementation of search algorithm for surrogate search """
2+
# Fitness functions are required to be iteratively defined, including all variables within.
3+
4+
from operator import itemgetter
5+
from pygad import GA
6+
7+
from causal_testing.specification.causal_specification import CausalSpecification
8+
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
9+
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm
10+
11+
12+
class GeneticSearchAlgorithm(SearchAlgorithm):
    """Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models."""

    def __init__(self, delta=0.05, config: dict = None) -> None:
        """
        :param delta: Half-width of the interval around a candidate treatment value; the fitness function compares
            ``value + delta`` (treatment) against ``value - delta`` (control)
        :param config: Optional attribute overrides applied to each GA instance before it is run. Must not contain
            "gene_space", which is derived from the causal specification
        """
        super().__init__()

        self.delta = delta
        self.config = config
        # Each function maps an estimated effect to a fitness that grows the more the estimate contradicts the
        # expected relationship.
        self.contradiction_functions = {
            "positive": lambda x: -1 * x,
            "negative": lambda x: x,
            "no_effect": abs,
            "some_effect": self._inverse_magnitude,
        }

    @staticmethod
    def _inverse_magnitude(effect):
        """Return ``abs(1 / effect)``, treating an effect of exactly zero as infinitely contradictory instead of
        raising ZeroDivisionError."""
        if effect == 0:
            return float("inf")
        return abs(1 / effect)

    def _make_fitness_function(self, surrogate: CubicSplineRegressionEstimator, contradiction_function):
        """Build the GA fitness function for one surrogate model.

        A factory is used so that the returned function is bound to this specific surrogate rather than to a loop
        variable, while keeping the three-parameter signature the GA is given.
        """

        # Unused arguments are required for pygad's fitness function signature
        def fitness_function(ga, solution, idx):  # pylint: disable=unused-argument
            # The first gene is the treatment value; compare a small interval around it.
            surrogate.control_value = solution[0] - self.delta
            surrogate.treatment_value = solution[0] + self.delta

            # Remaining genes follow the iteration order of the adjustment set.
            adjustment_dict = {
                adjustment: solution[position]
                for position, adjustment in enumerate(surrogate.adjustment_set, start=1)
            }

            ate = surrogate.estimate_ate_calculated(adjustment_dict)

            return contradiction_function(ate)

        return fitness_function

    # pylint: disable=too-many-locals
    def search(
        self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
    ) -> tuple:
        """Run a genetic search on every surrogate model and return the most contradictory candidate found.

        :param surrogate_models: The surrogate models to be searched
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :return: A ``(solution, fitness, surrogate)`` triple with the highest fitness, where ``solution`` maps the
            treatment and adjustment variable names to their values
        :raises ValueError: If no surrogate models are given or if the config attempts to set "gene_space"
        """
        if not surrogate_models:
            raise ValueError("At least one surrogate model is required to search for a candidate test case")
        # Validate the config once, before any GA is constructed.
        if self.config is not None and "gene_space" in self.config:
            raise ValueError(
                "Gene space should not be set through config. This is generated from the causal specification"
            )

        solutions = []

        for surrogate in surrogate_models:
            contradiction_function = self.contradiction_functions[surrogate.expected_relationship]
            fitness_function = self._make_fitness_function(surrogate, contradiction_function)

            gene_types, gene_space = self.create_gene_types(surrogate, specification)

            ga = GA(
                num_generations=200,
                num_parents_mating=4,
                fitness_func=fitness_function,
                sol_per_pop=10,
                num_genes=1 + len(surrogate.adjustment_set),
                gene_space=gene_space,
                gene_type=gene_types,
            )

            if self.config is not None:
                # Overrides are applied as attributes on the already constructed GA instance.
                for k, v in self.config.items():
                    setattr(ga, k, v)

            ga.run()
            solution, fitness, _ = ga.best_solution()

            solution_dict = {surrogate.treatment: solution[0]}
            for idx, adj in enumerate(surrogate.adjustment_set, start=1):
                solution_dict[adj] = solution[idx]
            solutions.append((solution_dict, fitness, surrogate))

        return max(solutions, key=itemgetter(1))  # This can be done better with fitness normalisation between edges

    @staticmethod
    def _parse_bound(text: str):
        """Parse a constraint bound, preferring an int and falling back to a float for decimal bounds.

        :raises ValueError: If the text is not a number
        """
        try:
            return int(text)
        except ValueError:
            return float(text)

    @staticmethod
    def create_gene_types(
        surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
    ) -> tuple[list, list]:
        """Generate the gene_types and gene_space for a given fitness function and specification

        The first gene is the treatment and the remaining genes follow the iteration order of the adjustment set.
        Bounds are read from scenario constraints whose string form is ``"<variable> >= <number>"`` or
        ``"<variable> <= <number>"``; other constraints are ignored.

        :param surrogate_model: Instance of a CubicSplineRegressionEstimator
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :return: ``(gene_types, gene_space)`` where gene_types holds each variable's datatype and gene_space holds a
            dict per variable with optional "low" and "high" bounds"""

        var_space = {surrogate_model.treatment: {}}
        for adj in surrogate_model.adjustment_set:
            var_space[adj] = {}

        for relationship in list(specification.scenario.constraints):
            rel_split = str(relationship).split(" ")

            # Only "<variable> <operator> <value>" constraints on searched variables carry usable bounds.
            if len(rel_split) < 3 or rel_split[0] not in var_space:
                continue

            if rel_split[1] == ">=":
                var_space[rel_split[0]]["low"] = GeneticSearchAlgorithm._parse_bound(rel_split[2])
            elif rel_split[1] == "<=":
                var_space[rel_split[0]]["high"] = GeneticSearchAlgorithm._parse_bound(rel_split[2])

        gene_space = [var_space[surrogate_model.treatment]]
        gene_space.extend(var_space[adj] for adj in surrogate_model.adjustment_set)

        gene_types = [specification.scenario.variables.get(surrogate_model.treatment).datatype]
        gene_types.extend(specification.scenario.variables.get(adj).datatype for adj in surrogate_model.adjustment_set)
        return gene_types, gene_space

causal_testing/testing/estimators.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,58 @@ def _get_confidence_intervals(self, model, treatment):
439439
return [ci_low, ci_high]
440440

441441

442+
class CubicSplineRegressionEstimator(LinearRegressionEstimator):
    """A Cubic Spline Regression Estimator is a parametric estimator which restricts the variables in the data to a
    combination of parameters and basis functions of the variables.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        basis: int,
        df: pd.DataFrame = None,
        effect_modifiers: dict[Variable:Any] = None,
        formula: str = None,
        alpha: float = 0.05,
        expected_relationship=None,
    ):
        """Set up the estimator and, when no formula is supplied, build a default cubic spline formula.

        :param basis: Value used as the ``df`` argument of the ``cr(...)`` term in the default formula
        :param expected_relationship: Stored on the instance as ``expected_relationship``
        The remaining parameters are forwarded unchanged to the parent estimator.
        """
        super().__init__(
            treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
        )

        self.expected_relationship = expected_relationship

        if formula is None:
            # Default formula: treatment first, then adjustment variables and effect modifiers, each sorted.
            modifier_terms = [] if effect_modifiers is None else sorted(effect_modifiers)
            terms = [treatment, *sorted(adjustment_set), *modifier_terms]
            self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"

    def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
        """Estimate the effect as the difference between the model's predictions at the treatment and control values.

        :param adjustment_config: Optional mapping of variable name to the value at which it is held for both
            predictions
        :return: Predicted outcome at ``treatment_value`` minus predicted outcome at ``control_value``
        """
        model = self._run_linear_regression()

        inputs = {"Intercept": 1, self.treatment: self.treatment_value}
        if adjustment_config is not None:
            inputs.update(adjustment_config.items())
        if self.effect_modifiers is not None:
            inputs.update(self.effect_modifiers.items())

        treated_prediction = model.predict(inputs).iloc[0]

        # Same covariates, only the treatment variable switched to its control value.
        inputs[self.treatment] = self.control_value
        control_prediction = model.predict(inputs).iloc[0]

        return treated_prediction - control_prediction
492+
493+
442494
class InstrumentalVariableEstimator(Estimator):
443495
"""
444496
Carry out estimation using instrumental variable adjustment rather than conventional adjustment. This means we do

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ dependencies = [
2626
"scipy~=1.7",
2727
"statsmodels~=0.13",
2828
"tabulate~=0.8",
29-
"pydot~=1.4"
29+
"pydot~=1.4",
30+
"pygad~=3.2"
3031
]
3132
dynamic = ["version"]
3233

0 commit comments

Comments
 (0)