diff --git a/.github/workflows/ci-tests-drafts.yaml b/.github/workflows/ci-tests-drafts.yaml index 8575f760..cc504b58 100644 --- a/.github/workflows/ci-tests-drafts.yaml +++ b/.github/workflows/ci-tests-drafts.yaml @@ -13,7 +13,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest", "windows-latest", "macos-latest"] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python diff --git a/.github/workflows/ci-tests.yaml b/.github/workflows/ci-tests.yaml index a218ab19..49d1fab3 100644 --- a/.github/workflows/ci-tests.yaml +++ b/.github/workflows/ci-tests.yaml @@ -18,7 +18,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest", "windows-latest", "macos-latest"] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python diff --git a/causal_testing/json_front/json_class.py b/causal_testing/json_front/json_class.py index 6be7fa68..fd617228 100644 --- a/causal_testing/json_front/json_class.py +++ b/causal_testing/json_front/json_class.py @@ -70,13 +70,13 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None) data_paths = [] self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths) - def setup(self, scenario: Scenario, data=None): + def setup(self, scenario: Scenario, data=None, ignore_cycles=False): """Function to populate all the necessary parts of the json_class needed to execute tests""" self.scenario = scenario self._get_scenario_variables() self.scenario.setup_treatment_variables() self.causal_specification = CausalSpecification( - scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path) + scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path, ignore_cycles=ignore_cycles) ) # Parse the JSON test plan with open(self.input_paths.json_path, encoding="utf-8") as f: diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index 08f8e91a..f00d4ad8 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -130,7 +130,7 @@ class CausalDAG(nx.DiGraph): ensures it is acyclic. A CausalDAG must be specified as a dot file. """ - def __init__(self, dot_path: str = None, **attr): + def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr): super().__init__(**attr) if dot_path: with open(dot_path, "r", encoding="utf-8") as file: @@ -144,7 +144,12 @@ def __init__(self, dot_path: str = None, **attr): self.graph = nx.DiGraph() if not self.is_acyclic(): - raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") + if ignore_cycles: + logger.warning( + "Cycles found. Ignoring them can invalidate causal estimates. Proceed with extreme caution." + ) + else: + raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: """ @@ -164,16 +169,17 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: ) # (iii) Instrument and outcome do not share causes - if any( - ( - cause - for cause in self.graph.nodes - if list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) - and list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) - ) - ): - raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") + for cause in self.graph.nodes: + # Exclude self-cycles due to breaking changes in NetworkX > 3.2 + outcome_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else [] + ) + instrument_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else [] + ) + if len(instrument_paths) > 0 and len(outcome_paths) > 0: + raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") return True def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr): @@ -188,12 +194,18 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr): if not self.is_acyclic(): raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") + def cycle_nodes(self) -> list: + """Get the nodes involved in any cycles. + :return: A list containing all nodes involved in a cycle. + """ + return [node for cycle in nx.simple_cycles(self.graph) for node in cycle] + def is_acyclic(self) -> bool: """Checks if the graph is acyclic. :return: True if acyclic, False otherwise. """ - return not list(nx.simple_cycles(self.graph)) + return not self.cycle_nodes() def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG: """Convert the causal DAG to a proper back-door graph. @@ -267,7 +279,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus gback.graph.remove_edge(v1, v2) return gback - def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]: + def direct_effect_adjustment_sets( + self, treatments: list[str], outcomes: list[str], nodes_to_ignore: list[str] = None + ) -> list[set[str]]: """ Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments and outcomes for DIRECT causal effect. @@ -278,12 +292,17 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st 2019. These works use the algorithm presented by Takata et al. in their work entitled: Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, 2013. - :param list[str] treatments: List of treatment names. - :param list[str] outcomes: List of outcome names. + :param treatments: List of treatment names. + :param outcomes: List of outcome names. + :param nodes_to_ignore: List of nodes to exclude from tests if they appear as treatments, outcomes, or in the + adjustment set. :return: A list of possible adjustment sets. :rtype: list[set[str]] """ + if nodes_to_ignore is None: + nodes_to_ignore = [] + indirect_graph = self.get_indirect_graph(treatments, outcomes) ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes) gam = nx.moral_graph(ancestor_graph.graph) @@ -295,7 +314,7 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes))) if set(outcomes) in min_seps: min_seps.remove(set(outcomes)) - return min_seps + return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps))) def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]: """Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments diff --git a/causal_testing/specification/metamorphic_relation.py b/causal_testing/specification/metamorphic_relation.py index 9d8c8afb..4a7c70c9 100644 --- a/causal_testing/specification/metamorphic_relation.py +++ b/causal_testing/specification/metamorphic_relation.py @@ -10,6 +10,8 @@ import argparse import logging import json +from multiprocessing import Pool + import networkx as nx import pandas as pd import numpy as np @@ -214,46 +216,96 @@ def __str__(self): ) -def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]: +def generate_metamorphic_relation( + node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None +) -> MetamorphicRelation: + """Construct a metamorphic relation for a given node pair implied by the Causal DAG, or None if no such relation can + be constructed (e.g. because every valid adjustment set contains a node to ignore). + + :param node_pair: The pair of nodes to consider. + :param dag: Causal DAG from which the metamorphic relations will be generated. + :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. + + :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. + """ + + if nodes_to_ignore is None: + nodes_to_ignore = set() + + (u, v) = node_pair + metamorphic_relations = [] + + # Create a ShouldNotCause relation for each pair of nodes that are not directly connected + if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges): + # Case 1: U --> ... --> V + if u in nx.ancestors(dag.graph, v): + adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag)) + + # Case 2: V --> ... --> U + elif v in nx.ancestors(dag.graph, u): + adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldNotCause(v, u, list(adj_sets[0]), dag)) + + # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V). + # Only make one MR since V _||_ U == U _||_ V + else: + adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag)) + + # Create a ShouldCause relation for each edge (u, v) or (v, u) + elif (u, v) in dag.graph.edges: + adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldCause(u, v, list(adj_sets[0]), dag)) + else: + adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldCause(v, u, list(adj_sets[0]), dag)) + return metamorphic_relations + + +def generate_metamorphic_relations( + dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0, nodes_to_test: set = None +) -> list[MetamorphicRelation]: """Construct a list of metamorphic relations implied by the Causal DAG. This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause relation for every (minimal) conditional independence relation implied by the structure of the DAG. - :param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated. + :param dag: Causal DAG from which the metamorphic relations will be generated. + :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. + :param threads: Number of threads to use (if generating in parallel). + :param nodes_to_test: Set of nodes to test the relationships between (defaults to all nodes). + :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. """ - metamorphic_relations = [] - for node_pair in combinations(dag.graph.nodes, 2): - (u, v) = node_pair - - # Create a ShouldNotCause relation for each pair of nodes that are not directly connected - if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges): - # Case 1: U --> ... --> V - if u in nx.ancestors(dag.graph, v): - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) - - # Case 2: V --> ... --> U - elif v in nx.ancestors(dag.graph, u): - adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0]) - metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag)) - - # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V). - # Only make one MR since V _||_ U == U _||_ V - else: - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) - # Create a ShouldCause relation for each edge (u, v) or (v, u) - elif (u, v) in dag.graph.edges: - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldCause(u, v, adj_set, dag)) - else: - adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0]) - metamorphic_relations.append(ShouldCause(v, u, adj_set, dag)) + if nodes_to_ignore is None: + nodes_to_ignore = {} - return metamorphic_relations + if nodes_to_test is None: + nodes_to_test = dag.graph.nodes + + if not threads: + metamorphic_relations = [ + generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) + for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2) + ] + else: + with Pool(threads) as pool: + metamorphic_relations = pool.starmap( + generate_metamorphic_relation, + map( + lambda node_pair: (node_pair, dag, nodes_to_ignore), + combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2), + ), + ) + + return [item for items in metamorphic_relations for item in items] if __name__ == "__main__": # pragma: no cover @@ -273,10 +325,32 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]: help="Specify path where tests should be saved, normally a .json file.", required=True, ) + parser.add_argument( + "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 + ) + parser.add_argument("-i", "--ignore-cycles", action="store_true") args = parser.parse_args() - causal_dag = CausalDAG(args.dag_path) - relations = generate_metamorphic_relations(causal_dag) + causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles) + + dag_nodes_to_test = set( + k for k, v in nx.get_node_attributes(causal_dag.graph, "test", default=True).items() if v == "True" + ) + + if not causal_dag.is_acyclic() and args.ignore_cycles: + logger.warning( + "Ignoring cycles by removing causal tests that reference any node within a cycle. " + "Your causal test suite WILL NOT BE COMPLETE!" + ) + relations = generate_metamorphic_relations( + causal_dag, + nodes_to_test=dag_nodes_to_test, + nodes_to_ignore=set(causal_dag.cycle_nodes()), + threads=args.threads, + ) + else: + relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads) + tests = [ relation.to_json_stub(skip=False) for relation in relations diff --git a/dafni/main_dafni.py b/dafni/main_dafni.py index 5fa66b0a..4cf75cb6 100644 --- a/dafni/main_dafni.py +++ b/dafni/main_dafni.py @@ -36,6 +36,10 @@ def get_args(test_args=None) -> argparse.Namespace: "--tests_path", required=True, help="Input configuration file path " "containing the causal tests (.json)" ) + parser.add_argument( + "-i", "--ignore_cycles", action="store_true", help="Whether to ignore cycles in the DAG.", default=False + ) + parser.add_argument( "--variables_path", required=True, @@ -72,17 +76,14 @@ def get_args(test_args=None) -> argparse.Namespace: args.tests_path = Path(args.tests_path) if args.dag_path is not None: - args.dag_path = Path(args.dag_path) if args.output_path is None: - args.output_path = "./data/outputs/causal_tests_results.json" Path(args.output_path).parent.mkdir(exist_ok=True) else: - args.output_path = Path(args.output_path) args.output_path.parent.mkdir(exist_ok=True) @@ -98,13 +99,11 @@ def read_variables(variables_path: Path) -> FileNotFoundError | dict: - dict - A valid dictionary consisting of the causal tests """ if not variables_path.exists() or variables_path.is_dir(): - print(f"JSON file not found at the specified location: {variables_path}") raise FileNotFoundError with variables_path.open("r") as file: - inputs = json.load(file) return inputs @@ -118,7 +117,6 @@ def validate_variables(data_dict: dict) -> tuple: - Tuple containing the inputs, outputs and constraints to pass into the modelling scenario """ if data_dict["variables"]: - variables = data_dict["variables"] inputs = [ @@ -136,12 +134,9 @@ def validate_variables(data_dict: dict) -> tuple: constraints = set() for variable, input_var in zip(variables, inputs): - if "constraint" in variable: - constraints.add(input_var.z3 == variable["constraint"]) else: - raise ValidationError("Cannot find the variables defined by the causal tests.") return inputs, outputs, constraints @@ -154,7 +149,6 @@ def main(): args = get_args() try: - # Step 0: Read in the runtime dataset(s) data_frame = pd.concat([pd.read_csv(d) for d in args.data_path]) @@ -190,7 +184,7 @@ def main(): json_utility.set_paths(args.tests_path, args.dag_path, args.data_path) # Step 6: Sets up all the necessary parts of the json_class needed to execute tests - json_utility.setup(scenario=modelling_scenario, data=data_frame) + json_utility.setup(scenario=modelling_scenario, data=data_frame, ignore_cycles=args.ignore_cycles) # Step 7: Run the causal tests test_outcomes = json_utility.run_json_tests( @@ -200,7 +194,6 @@ def main(): # Step 8: Update, print and save the final outputs for test in test_outcomes: - test.pop("estimator") test["result"] = test["result"].to_dict(json=True) @@ -210,17 +203,14 @@ def main(): test["result"].pop("control_value") with open(args.output_path, "w", encoding="utf-8") as f: - print(json.dumps(test_outcomes, indent=2), file=f) print(json.dumps(test_outcomes, indent=2)) except ValidationError as ve: - print(f"Cannot validate the specified input configurations: {ve}") else: - print(f"Execution successful. " f"Output file saved at {Path(args.output_path).parent.resolve()}") diff --git a/pyproject.toml b/pyproject.toml index 738c2d06..1b35e510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "fitter~=1.7", "lifelines~=0.29.0", "lhsmdu~=1.1", - "networkx~=2.6", + "networkx~=3.4", "numpy~=1.26", "pandas>=2.1", "scikit_learn~=1.4", diff --git a/tests/specification_tests/test_metamorphic_relations.py b/tests/specification_tests/test_metamorphic_relations.py index dc35e071..7998b66c 100644 --- a/tests/specification_tests/test_metamorphic_relations.py +++ b/tests/specification_tests/test_metamorphic_relations.py @@ -10,6 +10,7 @@ ShouldCause, ShouldNotCause, generate_metamorphic_relations, + generate_metamorphic_relation, ) from causal_testing.data_collection.data_collector import ExperimentalDataCollector from causal_testing.specification.variable import Input, Output @@ -73,6 +74,10 @@ def setUp(self) -> None: dag_dot = """digraph DAG { rankdir=LR; X1 -> Z; Z -> M; M -> Y; X2 -> Z; X3 -> M;}""" with open(self.dag_dot_path, "w") as f: f.write(dag_dot) + self.dcg_dot_path = os.path.join(self.temp_dir_path, "dcg.dot") + dcg_dot = """digraph dct { a -> b -> c -> d; d -> c; }""" + with open(self.dcg_dot_path, "w") as f: + f.write(dcg_dot) X1 = Input("X1", float) X2 = Input("X2", float) @@ -248,6 +253,79 @@ def test_all_metamorphic_relations_implied_by_dag(self): self.assertEqual(extra_snc_relations, []) self.assertEqual(missing_snc_relations, []) + def test_all_metamorphic_relations_implied_by_dag_parallel(self): + dag = CausalDAG(self.dag_dot_path) + dag.add_edge("Z", "Y") # Add a direct path from Z to Y so M becomes a mediator + metamorphic_relations = generate_metamorphic_relations(dag, threads=2) + should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)] + should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)] + + # Check all ShouldCause relations are present and no extra + expected_should_cause_relations = [ + ShouldCause("X1", "Z", [], dag), + ShouldCause("Z", "M", [], dag), + ShouldCause("M", "Y", ["Z"], dag), + ShouldCause("Z", "Y", ["M"], dag), + ShouldCause("X2", "Z", [], dag), + ShouldCause("X3", "M", [], dag), + ] + + extra_sc_relations = [scr for scr in should_cause_relations if scr not in expected_should_cause_relations] + missing_sc_relations = [escr for escr in expected_should_cause_relations if escr not in should_cause_relations] + + self.assertEqual(extra_sc_relations, []) + self.assertEqual(missing_sc_relations, []) + + # Check all ShouldNotCause relations are present and no extra + expected_should_not_cause_relations = [ + ShouldNotCause("X1", "X2", [], dag), + ShouldNotCause("X1", "X3", [], dag), + ShouldNotCause("X1", "M", ["Z"], dag), + ShouldNotCause("X1", "Y", ["Z"], dag), + ShouldNotCause("X2", "X3", [], dag), + ShouldNotCause("X2", "M", ["Z"], dag), + ShouldNotCause("X2", "Y", ["Z"], dag), + ShouldNotCause("X3", "Y", ["M", "Z"], dag), + ShouldNotCause("Z", "X3", [], dag), + ] + + extra_snc_relations = [ + sncr for sncr in should_not_cause_relations if sncr not in expected_should_not_cause_relations + ] + missing_snc_relations = [ + esncr for esncr in expected_should_not_cause_relations if esncr not in should_not_cause_relations + ] + + self.assertEqual(extra_snc_relations, []) + self.assertEqual(missing_snc_relations, []) + + def test_all_metamorphic_relations_implied_by_dag_ignore_cycles(self): + dag = CausalDAG(self.dcg_dot_path, ignore_cycles=True) + metamorphic_relations = generate_metamorphic_relations(dag, threads=2, nodes_to_ignore=set(dag.cycle_nodes())) + should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)] + should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)] + + # Check all ShouldCause relations are present and no extra + + self.assertEqual( + should_cause_relations, + [ + ShouldCause("a", "b", [], dag), + ], + ) + self.assertEqual( + should_not_cause_relations, + [], + ) + + def test_generate_metamorphic_relation_(self): + dag = CausalDAG(self.dag_dot_path) + [metamorphic_relation] = generate_metamorphic_relation(("X1", "Z"), dag) + self.assertEqual( + metamorphic_relation, + ShouldCause("X1", "Z", [], dag), + ) + def test_equivalent_metamorphic_relations(self): dag = CausalDAG(self.dag_dot_path) sc_mr_a = ShouldCause("X", "Y", ["A", "B", "C"], dag)