Skip to content

Commit

Permalink
Merge pull request #250 from CITCOM-project/surrogateassisted
Browse files Browse the repository at this point in the history
Surrogate Assisted Causal Testing
  • Loading branch information
rsomers1998 authored Jan 30, 2024
2 parents b82241f + 48fd185 commit 4034519
Show file tree
Hide file tree
Showing 9 changed files with 637 additions and 4 deletions.
20 changes: 17 additions & 3 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def list_all_min_sep(
# 7. Check that there exists at least one neighbour of the treatment nodes that is not in the outcome node set
if treatment_node_set_neighbours.difference(outcome_node_set):
# 7.1. If so, sample a random node from the set of treatment nodes' neighbours not in the outcome node set
node = set(sample(treatment_node_set_neighbours.difference(outcome_node_set), 1))
node = set(sample(sorted(treatment_node_set_neighbours.difference(outcome_node_set)), 1))

# 7.2. Add this node to the treatment node set and recurse (left branch)
yield from list_all_min_sep(
Expand Down Expand Up @@ -125,7 +125,6 @@ def close_separator(


class CausalDAG(nx.DiGraph):

"""A causal DAG is a directed acyclic graph in which nodes represent random variables and edges represent causality
between a pair of random variables. We implement a CausalDAG as a networkx DiGraph with an additional check that
ensures it is acyclic. A CausalDAG must be specified as a dot file.
Expand Down Expand Up @@ -500,11 +499,20 @@ def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
return True
return any((self.depends_on_outputs(n, scenario) for n in self.graph.predecessors(node)))

def identification(self, base_test_case: BaseTestCase):
@staticmethod
def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario):
"""Remove variables labelled as hidden from adjustment set(s)
:param minimal_adjustment_sets: list of minimal adjustment set(s) to have hidden variables removed from
:param scenario: The modelling scenario which informs the variables that are hidden
"""
return [adj for adj in minimal_adjustment_sets if all(not scenario.variables.get(x).hidden for x in adj)]

def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None):
"""Identify and return the minimum adjustment set
:param base_test_case: A base test case instance containing the outcome_variable and the
treatment_variable required for identification.
:param scenario: The modelling scenario relating to the tests
:return minimal_adjustment_set: The smallest set of variables which can be adjusted for to obtain a causal
estimate as opposed to a purely associational estimate.
"""
Expand All @@ -520,6 +528,12 @@ def identification(self, base_test_case: BaseTestCase):
else:
raise ValueError("Causal effect should be 'total' or 'direct'")

if scenario is not None:
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)

if len(minimal_adjustment_sets) == 0:
return set()

minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
return minimal_adjustment_set

Expand Down
Empty file.
142 changes: 142 additions & 0 deletions causal_testing/surrogate/causal_surrogate_assisted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""Module containing classes to define and run causal surrogate assisted test cases"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable

import pandas as pd

from causal_testing.data_collection.data_collector import ObservationalDataCollector
from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.testing.base_test_case import BaseTestCase
from causal_testing.testing.estimators import CubicSplineRegressionEstimator


@dataclass
class SimulationResult:
    """Record of a single simulator run: the data it produced plus metadata describing the result."""

    # Data produced by the simulation run.
    data: dict
    # True when the run exhibited a fault.
    fault: bool
    # Textual description of the causal relationship the result relates to.
    relationship: str


class SearchAlgorithm(ABC):  # pylint: disable=too-few-public-methods
    """Abstract base class for search strategies over surrogate models.

    Subclasses supply the search routine together with the fitness function of the space being searched.
    """

    @abstractmethod
    def search(
        self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
    ) -> list:
        """Search the given surrogate models for the optimal fitness value in the specified scenario.

        :param surrogate_models: The surrogate models to be searched
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        """


class Simulator(ABC):
    """Abstract base class wrapping a simulator.

    Subclasses provide the simulator-specific logic to start it, shut it down, and run it with a given
    configuration.
    """

    @abstractmethod
    def startup(self, **kwargs):
        """Initialise and open the Simulator."""

    @abstractmethod
    def shutdown(self, **kwargs):
        """Safely exit and shut down the Simulator."""

    @abstractmethod
    def run_with_config(self, configuration: dict) -> SimulationResult:
        """Run the simulator with the given configuration.

        :param configuration: The configuration required to initialise the Simulation
        :return: Simulation results in the structure of the SimulationResult data class
        """


class CausalSurrogateAssistedTestCase:
    """A class representing a single causal surrogate assisted test case.

    It ties together a causal specification, a search algorithm that proposes candidate test cases from surrogate
    models, and a simulator against which those candidates are executed.
    """

    def __init__(
        self,
        specification: CausalSpecification,
        search_algorithm: SearchAlgorithm,
        simulator: Simulator,
    ):
        self.specification = specification
        self.search_algorithm = search_algorithm
        self.simulator = simulator

    def execute(
        self,
        data_collector: ObservationalDataCollector,
        max_executions: int = 200,
        custom_data_aggregator: Callable[[dict, dict], dict] = None,
    ):
        """For this specific test case, a search algorithm is used to find the most contradictory point in the input
        space which is, therefore, most likely to indicate incorrect behaviour. This candidate test case is run
        against the simulator, checked for faults and the result returned with collected data
        :param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
        :param max_executions: Maximum number of simulator executions before exiting the search. A value of zero
        or less performs no executions.
        :param custom_data_aggregator: Optional callable taking the data collected so far and the data of the
        latest simulation, returning the combined data. When omitted, the simulation data is appended to the
        collected data frame as one new row.
        :return: tuple containing SimulationResult or str, execution number and collected data"""
        data_collector.collect_data()

        # Initialised up front so the final return is well defined even when the loop body never runs.
        executions = 0
        for executions in range(1, max_executions + 1):
            surrogate_models = self.generate_surrogates(self.specification, data_collector)
            candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification)

            self.simulator.startup()
            try:
                test_result = self.simulator.run_with_config(candidate_test_case)
            finally:
                # Always release the simulator, even if the run itself raised.
                self.simulator.shutdown()

            if custom_data_aggregator is not None:
                if data_collector.data is not None:
                    data_collector.data = custom_data_aggregator(data_collector.data, test_result.data)
            else:
                # DataFrame.append was removed in pandas 2.0; concatenating a one-row frame is the replacement.
                data_collector.data = pd.concat(
                    [data_collector.data, pd.DataFrame([test_result.data])], ignore_index=True
                )

            if test_result.fault:
                print(
                    f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with "
                    f"expected {surrogate.expected_relationship}."
                )
                test_result.relationship = (
                    f"{surrogate.treatment} -> {surrogate.outcome} expected {surrogate.expected_relationship}"
                )
                return test_result, executions, data_collector.data

        print("No fault found")
        return "No fault found", executions, data_collector.data

    def generate_surrogates(
        self, specification: CausalSpecification, data_collector: ObservationalDataCollector
    ) -> list[CubicSplineRegressionEstimator]:
        """Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata.
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario
        :return: A list of surrogate models
        """
        surrogate_models = []

        for u, v in specification.causal_dag.graph.edges:
            edge_metadata = specification.causal_dag.graph.adj[u][v]
            # Only edges explicitly marked as "included" in the DAG metadata get a surrogate.
            if "included" not in edge_metadata:
                continue

            from_var = specification.scenario.variables.get(u)
            to_var = specification.scenario.variables.get(v)
            base_test_case = BaseTestCase(from_var, to_var)

            minimal_adjustment_set = specification.causal_dag.identification(base_test_case, specification.scenario)

            # Treatment and control values start at 0; the search algorithm overwrites them per candidate.
            surrogate = CubicSplineRegressionEstimator(
                u,
                0,
                0,
                minimal_adjustment_set,
                v,
                4,
                df=data_collector.data,
                expected_relationship=edge_metadata["expected"],
            )
            surrogate_models.append(surrogate)

        return surrogate_models
114 changes: 114 additions & 0 deletions causal_testing/surrogate/surrogate_search_algorithms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Module containing implementation of search algorithm for surrogate search """
# pygad fitness functions must be defined inside the search loop so that every variable they use is in their scope.

from operator import itemgetter
from pygad import GA

from causal_testing.specification.causal_specification import CausalSpecification
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm


class GeneticSearchAlgorithm(SearchAlgorithm):
    """Genetic-algorithm implementation of SearchAlgorithm, searching each surrogate model with pygad's GA."""

    def __init__(self, delta=0.05, config: dict = None) -> None:
        super().__init__()

        # Half-width of the interval around a candidate treatment value: control = value - delta and
        # treatment = value + delta.
        self.delta = delta
        # Optional attribute overrides applied to each GA instance after construction.
        self.config = config
        # Maps an expected relationship to a function turning an ATE into a fitness value.
        self.contradiction_functions = {
            "positive": lambda x: -1 * x,
            "negative": lambda x: x,
            "no_effect": abs,
            "some_effect": lambda x: abs(1 / x),
        }

    # pylint: disable=too-many-locals
    def search(
        self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification
    ) -> list:
        """Run one genetic search per surrogate model and return the candidate with the highest fitness.

        :param surrogate_models: The surrogate models to be searched
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :return: The (candidate test case dict, fitness, surrogate) entry with the greatest fitness
        """
        candidates = []

        for surrogate in surrogate_models:
            contradiction_function = self.contradiction_functions[surrogate.expected_relationship]

            # The fitness function is defined here so the variables it needs are in its scope.
            # The unused arguments are required by pygad's fitness function signature.
            # pylint: disable=cell-var-from-loop
            def fitness_function(ga, solution, idx):  # pylint: disable=unused-argument
                surrogate.control_value = solution[0] - self.delta
                surrogate.treatment_value = solution[0] + self.delta

                # Gene 0 is the treatment; genes 1.. follow the adjustment set's iteration order.
                adjustment_dict = {
                    adjustment: solution[position]
                    for position, adjustment in enumerate(surrogate.adjustment_set, start=1)
                }

                return contradiction_function(surrogate.estimate_ate_calculated(adjustment_dict))

            gene_types, gene_space = self.create_gene_types(surrogate, specification)

            genetic_algorithm = GA(
                num_generations=200,
                num_parents_mating=4,
                fitness_func=fitness_function,
                sol_per_pop=10,
                num_genes=1 + len(surrogate.adjustment_set),
                gene_space=gene_space,
                gene_type=gene_types,
            )

            if self.config is not None:
                for option, value in self.config.items():
                    if option == "gene_space":
                        raise ValueError(
                            "Gene space should not be set through config. This is generated from the causal "
                            "specification"
                        )
                    setattr(genetic_algorithm, option, value)

            genetic_algorithm.run()
            best_genes, best_fitness, _ = genetic_algorithm.best_solution()

            test_case = {surrogate.treatment: best_genes[0]}
            for position, adjustment in enumerate(surrogate.adjustment_set, start=1):
                test_case[adjustment] = best_genes[position]
            candidates.append((test_case, best_fitness, surrogate))

        # This can be done better with fitness normalisation between edges
        return max(candidates, key=itemgetter(1))

    @staticmethod
    def create_gene_types(
        surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification
    ) -> tuple[list, list]:
        """Build the gene_types and gene_space for a surrogate model from the specification.

        Genes are ordered treatment first, then the adjustment set variables.
        :param surrogate_model: Instance of a CubicSplineRegressionEstimator
        :param specification: The Causal Specification (combination of Scenario and Causal Dag)
        :return: Tuple of (gene_types, gene_space)
        """
        gene_names = [surrogate_model.treatment, *surrogate_model.adjustment_set]
        bounds = {name: {} for name in gene_names}

        for relationship in list(specification.scenario.constraints):
            # A constraint's string form is split on spaces into "<variable> <operator> <value>".
            tokens = str(relationship).split(" ")

            if tokens[0] not in bounds:
                continue
            if tokens[1] == ">=":
                bounds[tokens[0]]["low"] = int(tokens[2])
            elif tokens[1] == "<=":
                bounds[tokens[0]]["high"] = int(tokens[2])

        gene_space = [bounds[name] for name in gene_names]
        gene_types = [specification.scenario.variables.get(name).datatype for name in gene_names]
        return gene_types, gene_space
52 changes: 52 additions & 0 deletions causal_testing/testing/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,58 @@ def _get_confidence_intervals(self, model, treatment):
return [ci_low, ci_high]


class CubicSplineRegressionEstimator(LinearRegressionEstimator):
    """A Cubic Spline Regression Estimator is a parametric estimator which restricts the variables in the data to a
    combination of parameters and basis functions of the variables.
    """

    def __init__(
        # pylint: disable=too-many-arguments
        self,
        treatment: str,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: str,
        basis: int,
        df: pd.DataFrame = None,
        effect_modifiers: dict[Variable:Any] = None,
        formula: str = None,
        alpha: float = 0.05,
        expected_relationship=None,
    ):
        super().__init__(
            treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, formula, alpha
        )

        self.expected_relationship = expected_relationship

        modifiers = [] if effect_modifiers is None else effect_modifiers

        if formula is None:
            # Default formula: a cr(...) term over the treatment, adjustment set and effect modifiers, with
            # ``basis`` passed as its df argument.
            terms = [treatment, *sorted(adjustment_set), *sorted(modifiers)]
            self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"

    def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
        """Estimate the average treatment effect as the difference between two model predictions.

        The model is predicted once at the treatment value and once at the control value, holding the supplied
        adjustment values and the effect modifiers fixed.
        :param adjustment_config: Optional mapping of adjustment variable name to the value to predict at
        :return: Prediction at the treatment value minus prediction at the control value
        """
        model = self._run_linear_regression()

        point = {"Intercept": 1, self.treatment: self.treatment_value}
        if adjustment_config is not None:
            point.update(adjustment_config)
        if self.effect_modifiers is not None:
            point.update(self.effect_modifiers)

        treated_outcome = model.predict(point).iloc[0]

        # Re-use the same point, changing only the treatment variable to the control value.
        point[self.treatment] = self.control_value
        control_outcome = model.predict(point).iloc[0]

        return treated_outcome - control_outcome


class InstrumentalVariableEstimator(Estimator):
"""
Carry out estimation using instrumental variable adjustment rather than conventional adjustment. This means we do
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ dependencies = [
"scipy~=1.7",
"statsmodels~=0.13",
"tabulate~=0.8",
"pydot~=1.4"
"pydot~=1.4",
"pygad~=3.2"
]
dynamic = ["version"]

Expand Down
Loading

0 comments on commit 4034519

Please sign in to comment.