From cfdd51605be502e76499decee73f8b1e65702272 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 1 Nov 2024 16:54:19 +0000 Subject: [PATCH 01/14] Option to ignore cycles when generating metamorphic relations --- causal_testing/specification/causal_dag.py | 26 +++- .../specification/metamorphic_relation.py | 118 +++++++++++++----- 2 files changed, 106 insertions(+), 38 deletions(-) diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index 08f8e91a..45fd5327 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -130,7 +130,7 @@ class CausalDAG(nx.DiGraph): ensures it is acyclic. A CausalDAG must be specified as a dot file. """ - def __init__(self, dot_path: str = None, **attr): + def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr): super().__init__(**attr) if dot_path: with open(dot_path, "r", encoding="utf-8") as file: @@ -144,7 +144,12 @@ def __init__(self, dot_path: str = None, **attr): self.graph = nx.DiGraph() if not self.is_acyclic(): - raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") + if ignore_cycles: + logger.warning( + "Cycles found. Ignoring them can invalidate causal estimates. Proceed with extreme caution." + ) + else: + raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: """ @@ -188,12 +193,18 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr): if not self.is_acyclic(): raise nx.HasACycle("Invalid Causal DAG: contains a cycle.") + def cycle_nodes(self) -> list: + """Get the nodes involved in any cycles. + :return: A list containing all nodes involved in a cycle. + """ + return [node for cycle in nx.simple_cycles(self.graph) for node in cycle] + def is_acyclic(self) -> bool: """Checks if the graph is acyclic. :return: True if acyclic, False otherwise. """ - return not list(nx.simple_cycles(self.graph)) + return not self.cycle_nodes() def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG: """Convert the causal DAG to a proper back-door graph. @@ -267,7 +278,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus gback.graph.remove_edge(v1, v2) return gback - def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]: + def direct_effect_adjustment_sets( + self, treatments: list[str], outcomes: list[str], nodes_to_ignore: list[str] = None + ) -> list[set[str]]: """ Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments and outcomes for DIRECT causal effect. @@ -284,6 +297,9 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st :rtype: list[set[str]] """ + if nodes_to_ignore is None: + nodes_to_ignore = [] + indirect_graph = self.get_indirect_graph(treatments, outcomes) ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes) gam = nx.moral_graph(ancestor_graph.graph) @@ -295,7 +311,7 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes))) if set(outcomes) in min_seps: min_seps.remove(set(outcomes)) - return min_seps + return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps))) def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]: """Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments diff --git a/causal_testing/specification/metamorphic_relation.py b/causal_testing/specification/metamorphic_relation.py index 9d8c8afb..f3f611cc 100644 --- a/causal_testing/specification/metamorphic_relation.py +++ b/causal_testing/specification/metamorphic_relation.py @@ -13,6 +13,7 @@ import networkx as nx import pandas as pd import numpy as np +from multiprocessing import Pool from causal_testing.specification.causal_specification import CausalDAG, Node from causal_testing.data_collection.data_collector import ExperimentalDataCollector @@ -214,46 +215,89 @@ def __str__(self): ) -def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]: +def generate_metamorphic_relation( + node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None +) -> MetamorphicRelation: + """Construct a metamorphic relation for a given node pair implied by the Causal DAG, or None if no such relation can + be constructed (e.g. because every valid adjustment set contains a node to ignore). + + :param node_pair: The pair of nodes to consider. + :param dag: Causal DAG from which the metamorphic relations will be generated. + :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. + + :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. + """ + + if nodes_to_ignore is None: + nodes_to_ignore = set() + + (u, v) = node_pair + metamorphic_relations = [] + + # Create a ShouldNotCause relation for each pair of nodes that are not directly connected + if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges): + # Case 1: U --> ... --> V + if u in nx.ancestors(dag.graph, v): + adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag)) + + # Case 2: V --> ... --> U + elif v in nx.ancestors(dag.graph, u): + adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldNotCause(v, u, list(adj_sets[0]), dag)) + + # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V). + # Only make one MR since V _||_ U == U _||_ V + else: + adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag)) + + # Create a ShouldCause relation for each edge (u, v) or (v, u) + elif (u, v) in dag.graph.edges: + adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldCause(u, v, list(adj_sets[0]), dag)) + else: + adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore) + if adj_sets: + metamorphic_relations.append(ShouldCause(v, u, list(adj_sets[0]), dag)) + return metamorphic_relations + + +def generate_metamorphic_relations( + dag: CausalDAG, nodes_to_ignore: set = {}, threads: int = 0 +) -> list[MetamorphicRelation]: """Construct a list of metamorphic relations implied by the Causal DAG. This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause relation for every (minimal) conditional independence relation implied by the structure of the DAG. - :param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated. + :param dag: Causal DAG from which the metamorphic relations will be generated. + :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. + :param threads: Number of threads to use (if generating in parallel). + :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. """ - metamorphic_relations = [] - for node_pair in combinations(dag.graph.nodes, 2): - (u, v) = node_pair - - # Create a ShouldNotCause relation for each pair of nodes that are not directly connected - if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges): - # Case 1: U --> ... --> V - if u in nx.ancestors(dag.graph, v): - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) - - # Case 2: V --> ... --> U - elif v in nx.ancestors(dag.graph, u): - adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0]) - metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag)) - - # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V). - # Only make one MR since V _||_ U == U _||_ V - else: - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag)) - # Create a ShouldCause relation for each edge (u, v) or (v, u) - elif (u, v) in dag.graph.edges: - adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0]) - metamorphic_relations.append(ShouldCause(u, v, adj_set, dag)) - else: - adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0]) - metamorphic_relations.append(ShouldCause(v, u, adj_set, dag)) + if not threads: + metamorphic_relations = [ + generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) + for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, dag.graph.nodes), 2) + ] + else: + with Pool(threads) as pool: + pool.starmap( + generate_metamorphic_relation, + map( + lambda node_pair: (node_pair, dag, nodes_to_ignore), + combinations(filter(lambda node: node not in nodes_to_ignore, dag.graph.nodes), 2), + ), + ) - return metamorphic_relations + return [item for items in metamorphic_relations for item in items] if __name__ == "__main__": # pragma: no cover @@ -273,10 +317,18 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]: help="Specify path where tests should be saved, normally a .json file.", required=True, ) + parser.add_argument("-i", "--ignore-cycles", action="store_true") args = parser.parse_args() - causal_dag = CausalDAG(args.dag_path) - relations = generate_metamorphic_relations(causal_dag) + causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles) + + if not causal_dag.is_acyclic() and args.ignore_cycles: + logger.warning( + "Ignoring cycles by removing causal tests that reference any node within a cycle. " + "Your causal test suite WILL NOT BE COMPLETE!" + ) + relations = generate_metamorphic_relations(causal_dag, nodes_to_ignore=set(causal_dag.cycle_nodes()), threads=20) + tests = [ relation.to_json_stub(skip=False) for relation in relations From da3fb4d1bc706c0271a3dfe6f19e44416240a813 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Mon, 4 Nov 2024 15:44:51 +0000 Subject: [PATCH 02/14] Capability to specify nodes to test in causal dag --- .../specification/metamorphic_relation.py | 27 ++++++++++++++----- pyproject.toml | 2 +- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/causal_testing/specification/metamorphic_relation.py b/causal_testing/specification/metamorphic_relation.py index f3f611cc..87633df5 100644 --- a/causal_testing/specification/metamorphic_relation.py +++ b/causal_testing/specification/metamorphic_relation.py @@ -10,10 +10,11 @@ import argparse import logging import json +from multiprocessing import Pool + import networkx as nx import pandas as pd import numpy as np -from multiprocessing import Pool from causal_testing.specification.causal_specification import CausalDAG, Node from causal_testing.data_collection.data_collector import ExperimentalDataCollector @@ -268,7 +269,7 @@ def generate_metamorphic_relation( def generate_metamorphic_relations( - dag: CausalDAG, nodes_to_ignore: set = {}, threads: int = 0 + dag: CausalDAG, nodes_to_ignore: set = {}, threads: int = 0, nodes_to_test: set = None ) -> list[MetamorphicRelation]: """Construct a list of metamorphic relations implied by the Causal DAG. @@ -282,18 +283,21 @@ def generate_metamorphic_relations( :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. """ + if nodes_to_test is None: + nodes_to_test = dag.graph.nodes + if not threads: metamorphic_relations = [ generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) - for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, dag.graph.nodes), 2) + for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2) ] else: with Pool(threads) as pool: - pool.starmap( + metamorphic_relations = pool.starmap( generate_metamorphic_relation, map( lambda node_pair: (node_pair, dag, nodes_to_ignore), - combinations(filter(lambda node: node not in nodes_to_ignore, dag.graph.nodes), 2), + combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2), ), ) @@ -317,17 +321,28 @@ def generate_metamorphic_relations( help="Specify path where tests should be saved, normally a .json file.", required=True, ) + parser.add_argument( + "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 + ) parser.add_argument("-i", "--ignore-cycles", action="store_true") args = parser.parse_args() causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles) + nodes_to_test = set( + k for k, v in nx.get_node_attributes(causal_dag.graph, "test", default=True).items() if v == "True" + ) + if not causal_dag.is_acyclic() and args.ignore_cycles: logger.warning( "Ignoring cycles by removing causal tests that reference any node within a cycle. " "Your causal test suite WILL NOT BE COMPLETE!" ) - relations = generate_metamorphic_relations(causal_dag, nodes_to_ignore=set(causal_dag.cycle_nodes()), threads=20) + relations = generate_metamorphic_relations( + causal_dag, nodes_to_test=nodes_to_test, nodes_to_ignore=set(causal_dag.cycle_nodes()), threads=args.threads + ) + else: + relations = generate_metamorphic_relations(causal_dag, nodes_to_test=nodes_to_test, threads=args.threads) tests = [ relation.to_json_stub(skip=False) diff --git a/pyproject.toml b/pyproject.toml index 738c2d06..1b35e510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "fitter~=1.7", "lifelines~=0.29.0", "lhsmdu~=1.1", - "networkx~=2.6", + "networkx~=3.4", "numpy~=1.26", "pandas>=2.1", "scikit_learn~=1.4", From 9263747dd30f5fb2ff4525727544ad66b6a128a8 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 15 Nov 2024 16:29:14 +0000 Subject: [PATCH 03/14] main dafni ignore cycle option --- causal_testing/json_front/json_class.py | 4 ++-- dafni/main_dafni.py | 20 +++++--------------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/causal_testing/json_front/json_class.py b/causal_testing/json_front/json_class.py index 6be7fa68..fd617228 100644 --- a/causal_testing/json_front/json_class.py +++ b/causal_testing/json_front/json_class.py @@ -70,13 +70,13 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None) data_paths = [] self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths) - def setup(self, scenario: Scenario, data=None): + def setup(self, scenario: Scenario, data=None, ignore_cycles=False): """Function to populate all the necessary parts of the json_class needed to execute tests""" self.scenario = scenario self._get_scenario_variables() self.scenario.setup_treatment_variables() self.causal_specification = CausalSpecification( - scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path) + scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path, ignore_cycles=ignore_cycles) ) # Parse the JSON test plan with open(self.input_paths.json_path, encoding="utf-8") as f: diff --git a/dafni/main_dafni.py b/dafni/main_dafni.py index 5fa66b0a..4cf75cb6 100644 --- a/dafni/main_dafni.py +++ b/dafni/main_dafni.py @@ -36,6 +36,10 @@ def get_args(test_args=None) -> argparse.Namespace: "--tests_path", required=True, help="Input configuration file path " "containing the causal tests (.json)" ) + parser.add_argument( + "-i", "--ignore_cycles", action="store_true", help="Whether to ignore cycles in the DAG.", default=False + ) + parser.add_argument( "--variables_path", required=True, @@ -72,17 +76,14 @@ def get_args(test_args=None) -> argparse.Namespace: args.tests_path = Path(args.tests_path) if args.dag_path is not None: - args.dag_path = Path(args.dag_path) if args.output_path is None: - args.output_path = "./data/outputs/causal_tests_results.json" Path(args.output_path).parent.mkdir(exist_ok=True) else: - args.output_path = Path(args.output_path) args.output_path.parent.mkdir(exist_ok=True) @@ -98,13 +99,11 @@ def read_variables(variables_path: Path) -> FileNotFoundError | dict: - dict - A valid dictionary consisting of the causal tests """ if not variables_path.exists() or variables_path.is_dir(): - print(f"JSON file not found at the specified location: {variables_path}") raise FileNotFoundError with variables_path.open("r") as file: - inputs = json.load(file) return inputs @@ -118,7 +117,6 @@ def validate_variables(data_dict: dict) -> tuple: - Tuple containing the inputs, outputs and constraints to pass into the modelling scenario """ if data_dict["variables"]: - variables = data_dict["variables"] inputs = [ @@ -136,12 +134,9 @@ def validate_variables(data_dict: dict) -> tuple: constraints = set() for variable, input_var in zip(variables, inputs): - if "constraint" in variable: - constraints.add(input_var.z3 == variable["constraint"]) else: - raise ValidationError("Cannot find the variables defined by the causal tests.") return inputs, outputs, constraints @@ -154,7 +149,6 @@ def main(): args = get_args() try: - # Step 0: Read in the runtime dataset(s) data_frame = pd.concat([pd.read_csv(d) for d in args.data_path]) @@ -190,7 +184,7 @@ def main(): json_utility.set_paths(args.tests_path, args.dag_path, args.data_path) # Step 6: Sets up all the necessary parts of the json_class needed to execute tests - json_utility.setup(scenario=modelling_scenario, data=data_frame) + json_utility.setup(scenario=modelling_scenario, data=data_frame, ignore_cycles=args.ignore_cycles) # Step 7: Run the causal tests test_outcomes = json_utility.run_json_tests( @@ -200,7 +194,6 @@ def main(): # Step 8: Update, print and save the final outputs for test in test_outcomes: - test.pop("estimator") test["result"] = test["result"].to_dict(json=True) @@ -210,17 +203,14 @@ def main(): test["result"].pop("control_value") with open(args.output_path, "w", encoding="utf-8") as f: - print(json.dumps(test_outcomes, indent=2), file=f) print(json.dumps(test_outcomes, indent=2)) except ValidationError as ve: - print(f"Cannot validate the specified input configurations: {ve}") else: - print(f"Execution successful. " f"Output file saved at {Path(args.output_path).parent.resolve()}") From 61ab78db64ee6a374e7c4776745202f03f01a650 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 15 Nov 2024 16:33:10 +0000 Subject: [PATCH 04/14] Removed testing for python 3.9 since networkx 3.4 requires >=3.10 --- .github/workflows/ci-tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-tests.yaml b/.github/workflows/ci-tests.yaml index a218ab19..49d1fab3 100644 --- a/.github/workflows/ci-tests.yaml +++ b/.github/workflows/ci-tests.yaml @@ -18,7 +18,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest", "windows-latest", "macos-latest"] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python From 550ad1cf2771fbd1eeb7daace30971b2fdde7be8 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 15 Nov 2024 16:37:44 +0000 Subject: [PATCH 05/14] pylint --- causal_testing/specification/causal_dag.py | 6 ++++-- .../specification/metamorphic_relation.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index 45fd5327..867d13a3 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -291,8 +291,10 @@ def direct_effect_adjustment_sets( 2019. These works use the algorithm presented by Takata et al. in their work entitled: Space-optimal, backtracking algorithms to list the minimal vertex separators of a graph, 2013. - :param list[str] treatments: List of treatment names. - :param list[str] outcomes: List of outcome names. + :param treatments: List of treatment names. + :param outcomes: List of outcome names. + :param nodes_to_ignore: List of nodes to exclude from tests if they appear as treatments, outcomes, or in the + adjustment set. :return: A list of possible adjustment sets. :rtype: list[set[str]] """ diff --git a/causal_testing/specification/metamorphic_relation.py b/causal_testing/specification/metamorphic_relation.py index 87633df5..6c9c0b9f 100644 --- a/causal_testing/specification/metamorphic_relation.py +++ b/causal_testing/specification/metamorphic_relation.py @@ -269,7 +269,7 @@ def generate_metamorphic_relation( def generate_metamorphic_relations( - dag: CausalDAG, nodes_to_ignore: set = {}, threads: int = 0, nodes_to_test: set = None + dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0, nodes_to_test: set = None ) -> list[MetamorphicRelation]: """Construct a list of metamorphic relations implied by the Causal DAG. @@ -279,10 +279,14 @@ def generate_metamorphic_relations( :param dag: Causal DAG from which the metamorphic relations will be generated. :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. :param threads: Number of threads to use (if generating in parallel). + :param nodes_to_ignore: Set of nodes to test the relationships between (defaults to all nodes). :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. """ + if nodes_to_ignore is None: + nodes_to_ignore = {} + if nodes_to_test is None: nodes_to_test = dag.graph.nodes @@ -329,7 +333,7 @@ def generate_metamorphic_relations( causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles) - nodes_to_test = set( + dag_nodes_to_test = set( k for k, v in nx.get_node_attributes(causal_dag.graph, "test", default=True).items() if v == "True" ) @@ -339,10 +343,13 @@ def generate_metamorphic_relations( "Your causal test suite WILL NOT BE COMPLETE!" ) relations = generate_metamorphic_relations( - causal_dag, nodes_to_test=nodes_to_test, nodes_to_ignore=set(causal_dag.cycle_nodes()), threads=args.threads + causal_dag, + nodes_to_test=dag_nodes_to_test, + nodes_to_ignore=set(causal_dag.cycle_nodes()), + threads=args.threads, ) else: - relations = generate_metamorphic_relations(causal_dag, nodes_to_test=nodes_to_test, threads=args.threads) + relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads) tests = [ relation.to_json_stub(skip=False) From d0c9ee273830da6e1c7475c911ebfacd991c5314 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 15 Nov 2024 16:52:36 +0000 Subject: [PATCH 06/14] pylint --- causal_testing/specification/metamorphic_relation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causal_testing/specification/metamorphic_relation.py b/causal_testing/specification/metamorphic_relation.py index 6c9c0b9f..4a7c70c9 100644 --- a/causal_testing/specification/metamorphic_relation.py +++ b/causal_testing/specification/metamorphic_relation.py @@ -279,7 +279,7 @@ def generate_metamorphic_relations( :param dag: Causal DAG from which the metamorphic relations will be generated. :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. :param threads: Number of threads to use (if generating in parallel). - :param nodes_to_ignore: Set of nodes to test the relationships between (defaults to all nodes). + :param nodes_to_test: Set of nodes to test the relationships between (defaults to all nodes). :return: A list containing ShouldCause and ShouldNotCause metamorphic relations. """ From ef4df173f0e6f56743f3a80abf6b9c4b7a59323d Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Tue, 26 Nov 2024 13:06:11 +0000 Subject: [PATCH 07/14] LR Estimate prediction --- .../estimation/linear_regression_estimator.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/causal_testing/estimation/linear_regression_estimator.py b/causal_testing/estimation/linear_regression_estimator.py index 85a4b178..9c7fdb02 100644 --- a/causal_testing/estimation/linear_regression_estimator.py +++ b/causal_testing/estimation/linear_regression_estimator.py @@ -177,6 +177,24 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Se ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]) return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high] + def estimate_prediction(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]: + """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused + by changing the treatment variable from the control value to the treatment value. Here, we actually + calculate the expected outcomes under control and treatment and divide one by the other. This + allows for custom terms to be put in such as squares, inverses, products, etc. + + :param: adjustment_config: The configuration of the adjustment set as a dict mapping variable names to + their values. N.B. Every variable in the adjustment set MUST have a value in + order to estimate the outcome under control and treatment. + + :return: The average treatment effect and the 95% Wald confidence intervals. + """ + prediction = self._predict(adjustment_config=adjustment_config) + outcome = prediction.iloc[1] + ci_low = pd.Series(outcome["mean_ci_upper"]) + ci_high = pd.Series(outcome["mean_ci_lower"]) + return pd.Series(outcome["mean"]), [ci_low, ci_high] + def _get_confidence_intervals(self, model, treatment): confidence_intervals = model.conf_int(alpha=self.alpha, cols=None) ci_low, ci_high = ( From f5400ee2bc22a40d204fd223e76422c962304332 Mon Sep 17 00:00:00 2001 From: Farhad Allian <39086289+f-allian@users.noreply.github.com> Date: Wed, 27 Nov 2024 10:06:06 +0000 Subject: [PATCH 08/14] Update ci-tests-drafts.yaml for ignoring cycles we need latest version of networkx --- .github/workflows/ci-tests-drafts.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-tests-drafts.yaml b/.github/workflows/ci-tests-drafts.yaml index 8575f760..cc504b58 100644 --- a/.github/workflows/ci-tests-drafts.yaml +++ b/.github/workflows/ci-tests-drafts.yaml @@ -13,7 +13,7 @@ jobs: strategy: matrix: os: ["ubuntu-latest", "windows-latest", "macos-latest"] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python From 75efc80b99cc2515f6363cac1e36f512ef5a97eb Mon Sep 17 00:00:00 2001 From: Michael Foster <13611658+jmafoster1@users.noreply.github.com> Date: Wed, 27 Nov 2024 10:55:40 +0000 Subject: [PATCH 09/14] Updated IV check to explicitly use len --- causal_testing/specification/causal_dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index 867d13a3..7fb1a3f7 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -173,8 +173,8 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: ( cause for cause in self.graph.nodes - if list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) - and list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) + if len(list(nx.all_simple_paths(self.graph, source=cause, target=instrument))) > 0 + and len(list(nx.all_simple_paths(self.graph, source=cause, target=outcome))) > 0 ) ): raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") From adbfb726a00e926933c239ab43bc7b80754d7a05 Mon Sep 17 00:00:00 2001 From: Farhad Allian Date: Thu, 28 Nov 2024 13:18:03 +0000 Subject: [PATCH 10/14] fix: exlude self-cycles in newer NetworkX versions --- causal_testing/specification/causal_dag.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index 7fb1a3f7..ed1a23dc 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -169,16 +169,14 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: ) # (iii) Instrument and outcome do not share causes - if any( - ( - cause - for cause in self.graph.nodes - if len(list(nx.all_simple_paths(self.graph, source=cause, target=instrument))) > 0 - and len(list(nx.all_simple_paths(self.graph, source=cause, target=outcome))) > 0 - ) - ): - raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") + for cause in self.graph.nodes: + if cause not in (instrument, outcome): # exclude self-cycles due to breaking changes in NetworkX > 3.2 + instrument_paths = list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) + outcome_paths = list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) + if len(instrument_paths) > 0 and len(outcome_paths) > 0: + print(cause, instrument, instrument_paths, outcome, outcome_paths) + raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") return True def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr): From 4de42e5db4cab343fb38d57790da58b60baf240e Mon Sep 17 00:00:00 2001 From: Farhad Allian Date: Fri, 29 Nov 2024 09:59:29 +0000 Subject: [PATCH 11/14] fix: from comments by @jmafoster1 --- causal_testing/specification/causal_dag.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index ed1a23dc..f00d4ad8 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -171,12 +171,15 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: # (iii) Instrument and outcome do not share causes for cause in self.graph.nodes: - if cause not in (instrument, outcome): # exclude self-cycles due to breaking changes in NetworkX > 3.2 - instrument_paths = list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) - outcome_paths = list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) - if len(instrument_paths) > 0 and len(outcome_paths) > 0: - print(cause, instrument, instrument_paths, outcome, outcome_paths) - raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") + # Exclude self-cycles due to breaking changes in NetworkX > 3.2 + outcome_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else [] + ) + instrument_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else [] + ) + if len(instrument_paths) > 0 and len(outcome_paths) > 0: + raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") return True def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr): From 7b4c4d80e4959372738462c09aa4786322bae1d3 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 29 Nov 2024 15:46:12 +0000 Subject: [PATCH 12/14] removed estimate_prediction --- .../estimation/linear_regression_estimator.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/causal_testing/estimation/linear_regression_estimator.py b/causal_testing/estimation/linear_regression_estimator.py index 9c7fdb02..85a4b178 100644 --- a/causal_testing/estimation/linear_regression_estimator.py +++ b/causal_testing/estimation/linear_regression_estimator.py @@ -177,24 +177,6 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Se ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]) return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high] - def estimate_prediction(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]: - """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused - by changing the treatment variable from the control value to the treatment value. Here, we actually - calculate the expected outcomes under control and treatment and divide one by the other. This - allows for custom terms to be put in such as squares, inverses, products, etc. - - :param: adjustment_config: The configuration of the adjustment set as a dict mapping variable names to - their values. N.B. Every variable in the adjustment set MUST have a value in - order to estimate the outcome under control and treatment. - - :return: The average treatment effect and the 95% Wald confidence intervals. - """ - prediction = self._predict(adjustment_config=adjustment_config) - outcome = prediction.iloc[1] - ci_low = pd.Series(outcome["mean_ci_upper"]) - ci_high = pd.Series(outcome["mean_ci_lower"]) - return pd.Series(outcome["mean"]), [ci_low, ci_high] - def _get_confidence_intervals(self, model, treatment): confidence_intervals = model.conf_int(alpha=self.alpha, cols=None) ci_low, ci_high = ( From 0357c85d3f09b367cd027314e93d743e4d2cfcf0 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 29 Nov 2024 16:07:14 +0000 Subject: [PATCH 13/14] tests pass and coverage increased --- .../test_metamorphic_relations.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/specification_tests/test_metamorphic_relations.py b/tests/specification_tests/test_metamorphic_relations.py index dc35e071..cbae7fa2 100644 --- a/tests/specification_tests/test_metamorphic_relations.py +++ b/tests/specification_tests/test_metamorphic_relations.py @@ -73,6 +73,10 @@ def setUp(self) -> None: dag_dot = """digraph DAG { rankdir=LR; X1 -> Z; Z -> M; M -> Y; X2 -> Z; X3 -> M;}""" with open(self.dag_dot_path, "w") as f: f.write(dag_dot) + self.dcg_dot_path = os.path.join(self.temp_dir_path, "dcg.dot") + dcg_dot = """digraph dct { a -> b -> c -> d; d -> c; }""" + with open(self.dcg_dot_path, "w") as f: + f.write(dcg_dot) X1 = Input("X1", float) X2 = Input("X2", float) @@ -248,6 +252,71 @@ def test_all_metamorphic_relations_implied_by_dag(self): self.assertEqual(extra_snc_relations, []) self.assertEqual(missing_snc_relations, []) + def test_all_metamorphic_relations_implied_by_dag_parallel(self): + dag = CausalDAG(self.dag_dot_path) + dag.add_edge("Z", "Y") # Add a direct path from Z to Y so M becomes a mediator + metamorphic_relations = generate_metamorphic_relations(dag, threads=2) + should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)] + should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)] + + # Check all ShouldCause relations are present and no extra + expected_should_cause_relations = [ + ShouldCause("X1", "Z", [], dag), + ShouldCause("Z", "M", [], dag), + ShouldCause("M", "Y", ["Z"], dag), + ShouldCause("Z", "Y", ["M"], dag), + ShouldCause("X2", "Z", [], dag), + ShouldCause("X3", "M", [], dag), + ] + + extra_sc_relations = [scr for scr in should_cause_relations if scr not in expected_should_cause_relations] + missing_sc_relations = [escr for escr in expected_should_cause_relations if escr not in should_cause_relations] + + self.assertEqual(extra_sc_relations, []) + self.assertEqual(missing_sc_relations, []) + + # Check all ShouldNotCause relations are present and no extra + expected_should_not_cause_relations = [ + ShouldNotCause("X1", "X2", [], dag), + ShouldNotCause("X1", "X3", [], dag), + ShouldNotCause("X1", "M", ["Z"], dag), + ShouldNotCause("X1", "Y", ["Z"], dag), + ShouldNotCause("X2", "X3", [], dag), + ShouldNotCause("X2", "M", ["Z"], dag), + ShouldNotCause("X2", "Y", ["Z"], dag), + ShouldNotCause("X3", "Y", ["M", "Z"], dag), + ShouldNotCause("Z", "X3", [], dag), + ] + + extra_snc_relations = [ + sncr for sncr in should_not_cause_relations if sncr not in expected_should_not_cause_relations + ] + missing_snc_relations = [ + esncr for esncr in expected_should_not_cause_relations if esncr not in should_not_cause_relations + ] + + self.assertEqual(extra_snc_relations, []) + self.assertEqual(missing_snc_relations, []) + + def test_all_metamorphic_relations_implied_by_dag_ignore_cycles(self): + dag = CausalDAG(self.dcg_dot_path, ignore_cycles=True) + metamorphic_relations = generate_metamorphic_relations(dag, threads=2, nodes_to_ignore=set(dag.cycle_nodes())) + should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)] + should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)] + + # Check all ShouldCause relations are present and no extra + + self.assertEqual( + should_cause_relations, + [ + ShouldCause("a", "b", [], dag), + ], + ) + self.assertEqual( + should_not_cause_relations, + [], + ) + def test_equivalent_metamorphic_relations(self): dag = CausalDAG(self.dag_dot_path) sc_mr_a = ShouldCause("X", "Y", ["A", "B", "C"], dag) From 499fc5421bd46e813983f09d352f3b2bdb0c8c63 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 29 Nov 2024 16:16:05 +0000 Subject: [PATCH 14/14] 100% coverage --- tests/specification_tests/test_metamorphic_relations.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/specification_tests/test_metamorphic_relations.py b/tests/specification_tests/test_metamorphic_relations.py index cbae7fa2..7998b66c 100644 --- a/tests/specification_tests/test_metamorphic_relations.py +++ b/tests/specification_tests/test_metamorphic_relations.py @@ -10,6 +10,7 @@ ShouldCause, ShouldNotCause, generate_metamorphic_relations, + generate_metamorphic_relation, ) from causal_testing.data_collection.data_collector import ExperimentalDataCollector from causal_testing.specification.variable import Input, Output @@ -317,6 +318,14 @@ def test_all_metamorphic_relations_implied_by_dag_ignore_cycles(self): [], ) + def test_generate_metamorphic_relation_(self): + dag = CausalDAG(self.dag_dot_path) + [metamorphic_relation] = generate_metamorphic_relation(("X1", "Z"), dag) + self.assertEqual( + metamorphic_relation, + ShouldCause("X1", "Z", [], dag), + ) + def test_equivalent_metamorphic_relations(self): dag = CausalDAG(self.dag_dot_path) sc_mr_a = ShouldCause("X", "Y", ["A", "B", "C"], dag)