-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #250 from CITCOM-project/surrogateassisted
Surrogate Assisted Causal Testing
- Loading branch information
Showing
9 changed files
with
637 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
"""Module containing classes to define and run causal surrogate assisted test cases""" | ||
|
||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
from typing import Callable | ||
|
||
from causal_testing.data_collection.data_collector import ObservationalDataCollector | ||
from causal_testing.specification.causal_specification import CausalSpecification | ||
from causal_testing.testing.base_test_case import BaseTestCase | ||
from causal_testing.testing.estimators import CubicSplineRegressionEstimator | ||
|
||
|
||
@dataclass | ||
class SimulationResult: | ||
"""Data class holding the data and result metadata of a simulation""" | ||
|
||
data: dict | ||
fault: bool | ||
relationship: str | ||
|
||
|
||
class SearchAlgorithm(ABC): # pylint: disable=too-few-public-methods | ||
"""Class to be inherited with the search algorithm consisting of a search function and the fitness function of the | ||
space to be searched""" | ||
|
||
@abstractmethod | ||
def search( | ||
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification | ||
) -> list: | ||
"""Function which implements a search routine which searches for the optimal fitness value for the specified | ||
scenario | ||
:param surrogate_models: The surrogate models to be searched | ||
:param specification: The Causal Specification (combination of Scenario and Causal Dag)""" | ||
|
||
|
||
class Simulator(ABC): | ||
"""Class to be inherited with Simulator specific functions to start, shutdown and run the simulation with the give | ||
config file""" | ||
|
||
@abstractmethod | ||
def startup(self, **kwargs): | ||
"""Function that when run, initialises and opens the Simulator""" | ||
|
||
@abstractmethod | ||
def shutdown(self, **kwargs): | ||
"""Function to safely exit and shutdown the Simulator""" | ||
|
||
@abstractmethod | ||
def run_with_config(self, configuration: dict) -> SimulationResult: | ||
"""Run the simulator with the given configuration and return the results in the structure of a | ||
SimulationResult | ||
:param configuration: The configuration required to initialise the Simulation | ||
:return: Simulation results in the structure of the SimulationResult data class""" | ||
|
||
|
||
class CausalSurrogateAssistedTestCase: | ||
"""A class representing a single causal surrogate assisted test case.""" | ||
|
||
def __init__( | ||
self, | ||
specification: CausalSpecification, | ||
search_algorithm: SearchAlgorithm, | ||
simulator: Simulator, | ||
): | ||
self.specification = specification | ||
self.search_algorithm = search_algorithm | ||
self.simulator = simulator | ||
|
||
def execute( | ||
self, | ||
data_collector: ObservationalDataCollector, | ||
max_executions: int = 200, | ||
custom_data_aggregator: Callable[[dict, dict], dict] = None, | ||
): | ||
"""For this specific test case, a search algorithm is used to find the most contradictory point in the input | ||
space which is, therefore, most likely to indicate incorrect behaviour. This cadidate test case is run against | ||
the simulator, checked for faults and the result returned with collected data | ||
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario | ||
:param max_executions: Maximum number of simulator executions before exiting the search | ||
:param custom_data_aggregator: | ||
:return: tuple containing SimulationResult or str, execution number and collected data""" | ||
data_collector.collect_data() | ||
|
||
for i in range(max_executions): | ||
surrogate_models = self.generate_surrogates(self.specification, data_collector) | ||
candidate_test_case, _, surrogate = self.search_algorithm.search(surrogate_models, self.specification) | ||
|
||
self.simulator.startup() | ||
test_result = self.simulator.run_with_config(candidate_test_case) | ||
self.simulator.shutdown() | ||
|
||
if custom_data_aggregator is not None: | ||
if data_collector.data is not None: | ||
data_collector.data = custom_data_aggregator(data_collector.data, test_result.data) | ||
else: | ||
data_collector.data = data_collector.data.append(test_result.data, ignore_index=True) | ||
|
||
if test_result.fault: | ||
print( | ||
f"Fault found between {surrogate.treatment} causing {surrogate.outcome}. Contradiction with " | ||
f"expected {surrogate.expected_relationship}." | ||
) | ||
test_result.relationship = ( | ||
f"{surrogate.treatment} -> {surrogate.outcome} expected {surrogate.expected_relationship}" | ||
) | ||
return test_result, i + 1, data_collector.data | ||
|
||
print("No fault found") | ||
return "No fault found", i + 1, data_collector.data | ||
|
||
def generate_surrogates( | ||
self, specification: CausalSpecification, data_collector: ObservationalDataCollector | ||
) -> list[CubicSplineRegressionEstimator]: | ||
"""Generate a surrogate model for each edge of the dag that specifies it is included in the DAG metadata. | ||
:param specification: The Causal Specification (combination of Scenario and Causal Dag) | ||
:param data_collector: An ObservationalDataCollector which gathers data relevant to the specified scenario | ||
:return: A list of surrogate models | ||
""" | ||
surrogate_models = [] | ||
|
||
for u, v in specification.causal_dag.graph.edges: | ||
edge_metadata = specification.causal_dag.graph.adj[u][v] | ||
if "included" in edge_metadata: | ||
from_var = specification.scenario.variables.get(u) | ||
to_var = specification.scenario.variables.get(v) | ||
base_test_case = BaseTestCase(from_var, to_var) | ||
|
||
minimal_adjustment_set = specification.causal_dag.identification(base_test_case, specification.scenario) | ||
|
||
surrogate = CubicSplineRegressionEstimator( | ||
u, | ||
0, | ||
0, | ||
minimal_adjustment_set, | ||
v, | ||
4, | ||
df=data_collector.data, | ||
expected_relationship=edge_metadata["expected"], | ||
) | ||
surrogate_models.append(surrogate) | ||
|
||
return surrogate_models |
114 changes: 114 additions & 0 deletions
114
causal_testing/surrogate/surrogate_search_algorithms.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
"""Module containing implementation of search algorithm for surrogate search """ | ||
# Fitness functions are required to be iteratively defined, including all variables within. | ||
|
||
from operator import itemgetter | ||
from pygad import GA | ||
|
||
from causal_testing.specification.causal_specification import CausalSpecification | ||
from causal_testing.testing.estimators import CubicSplineRegressionEstimator | ||
from causal_testing.surrogate.causal_surrogate_assisted import SearchAlgorithm | ||
|
||
|
||
class GeneticSearchAlgorithm(SearchAlgorithm): | ||
"""Implementation of SearchAlgorithm class. Implements genetic search algorithm for surrogate models.""" | ||
|
||
def __init__(self, delta=0.05, config: dict = None) -> None: | ||
super().__init__() | ||
|
||
self.delta = delta | ||
self.config = config | ||
self.contradiction_functions = { | ||
"positive": lambda x: -1 * x, | ||
"negative": lambda x: x, | ||
"no_effect": abs, | ||
"some_effect": lambda x: abs(1 / x), | ||
} | ||
|
||
# pylint: disable=too-many-locals | ||
def search( | ||
self, surrogate_models: list[CubicSplineRegressionEstimator], specification: CausalSpecification | ||
) -> list: | ||
solutions = [] | ||
|
||
for surrogate in surrogate_models: | ||
contradiction_function = self.contradiction_functions[surrogate.expected_relationship] | ||
|
||
# The GA fitness function after including required variables into the function's scope | ||
# Unused arguments are required for pygad's fitness function signature | ||
#pylint: disable=cell-var-from-loop | ||
def fitness_function(ga, solution, idx): # pylint: disable=unused-argument | ||
surrogate.control_value = solution[0] - self.delta | ||
surrogate.treatment_value = solution[0] + self.delta | ||
|
||
adjustment_dict = {} | ||
for i, adjustment in enumerate(surrogate.adjustment_set): | ||
adjustment_dict[adjustment] = solution[i + 1] | ||
|
||
ate = surrogate.estimate_ate_calculated(adjustment_dict) | ||
|
||
return contradiction_function(ate) | ||
|
||
gene_types, gene_space = self.create_gene_types(surrogate, specification) | ||
|
||
ga = GA( | ||
num_generations=200, | ||
num_parents_mating=4, | ||
fitness_func=fitness_function, | ||
sol_per_pop=10, | ||
num_genes=1 + len(surrogate.adjustment_set), | ||
gene_space=gene_space, | ||
gene_type=gene_types, | ||
) | ||
|
||
if self.config is not None: | ||
for k, v in self.config.items(): | ||
if k == "gene_space": | ||
raise ValueError( | ||
"Gene space should not be set through config. This is generated from the causal " | ||
"specification" | ||
) | ||
setattr(ga, k, v) | ||
|
||
ga.run() | ||
solution, fitness, _ = ga.best_solution() | ||
|
||
solution_dict = {} | ||
solution_dict[surrogate.treatment] = solution[0] | ||
for idx, adj in enumerate(surrogate.adjustment_set): | ||
solution_dict[adj] = solution[idx + 1] | ||
solutions.append((solution_dict, fitness, surrogate)) | ||
|
||
return max(solutions, key=itemgetter(1)) # This can be done better with fitness normalisation between edges | ||
|
||
@staticmethod | ||
def create_gene_types( | ||
surrogate_model: CubicSplineRegressionEstimator, specification: CausalSpecification | ||
) -> tuple[list, list]: | ||
"""Generate the gene_types and gene_space for a given fitness function and specification | ||
:param surrogate_model: Instance of a CubicSplineRegressionEstimator | ||
:param specification: The Causal Specification (combination of Scenario and Causal Dag)""" | ||
|
||
var_space = {} | ||
var_space[surrogate_model.treatment] = {} | ||
for adj in surrogate_model.adjustment_set: | ||
var_space[adj] = {} | ||
|
||
for relationship in list(specification.scenario.constraints): | ||
rel_split = str(relationship).split(" ") | ||
|
||
if rel_split[0] in var_space: | ||
if rel_split[1] == ">=": | ||
var_space[rel_split[0]]["low"] = int(rel_split[2]) | ||
elif rel_split[1] == "<=": | ||
var_space[rel_split[0]]["high"] = int(rel_split[2]) | ||
|
||
gene_space = [] | ||
gene_space.append(var_space[surrogate_model.treatment]) | ||
for adj in surrogate_model.adjustment_set: | ||
gene_space.append(var_space[adj]) | ||
|
||
gene_types = [] | ||
gene_types.append(specification.scenario.variables.get(surrogate_model.treatment).datatype) | ||
for adj in surrogate_model.adjustment_set: | ||
gene_types.append(specification.scenario.variables.get(adj).datatype) | ||
return gene_types, gene_space |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.