Skip to content

Commit

Permalink
Merge pull request #294 from CITCOM-project/ignore-cycles
Browse files Browse the repository at this point in the history
Ignore cycles
  • Loading branch information
jmafoster1 authored Dec 3, 2024
2 parents de0a676 + 499fc54 commit 4fdf12b
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-drafts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
Expand Down
4 changes: 2 additions & 2 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
data_paths = []
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)

def setup(self, scenario: Scenario, data=None):
def setup(self, scenario: Scenario, data=None, ignore_cycles=False):
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
self.scenario = scenario
self._get_scenario_variables()
self.scenario.setup_treatment_variables()
self.causal_specification = CausalSpecification(
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path, ignore_cycles=ignore_cycles)
)
# Parse the JSON test plan
with open(self.input_paths.json_path, encoding="utf-8") as f:
Expand Down
51 changes: 35 additions & 16 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class CausalDAG(nx.DiGraph):
ensures it is acyclic. A CausalDAG must be specified as a dot file.
"""

def __init__(self, dot_path: str = None, **attr):
def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
super().__init__(**attr)
if dot_path:
with open(dot_path, "r", encoding="utf-8") as file:
Expand All @@ -144,7 +144,12 @@ def __init__(self, dot_path: str = None, **attr):
self.graph = nx.DiGraph()

if not self.is_acyclic():
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
if ignore_cycles:
logger.warning(
"Cycles found. Ignoring them can invalidate causal estimates. Proceed with extreme caution."
)
else:
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")

def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
"""
Expand All @@ -164,16 +169,17 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
)

# (iii) Instrument and outcome do not share causes
if any(
(
cause
for cause in self.graph.nodes
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
)
):
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")

for cause in self.graph.nodes:
# Exclude self-cycles due to breaking changes in NetworkX > 3.2
outcome_paths = (
list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else []
)
instrument_paths = (
list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else []
)
if len(instrument_paths) > 0 and len(outcome_paths) > 0:
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
return True

def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
Expand All @@ -188,12 +194,18 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
if not self.is_acyclic():
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")

def cycle_nodes(self) -> list:
"""Get the nodes involved in any cycles.
:return: A list containing all nodes involved in a cycle.
"""
return [node for cycle in nx.simple_cycles(self.graph) for node in cycle]

def is_acyclic(self) -> bool:
"""Checks if the graph is acyclic.
:return: True if acyclic, False otherwise.
"""
return not list(nx.simple_cycles(self.graph))
return not self.cycle_nodes()

def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
"""Convert the causal DAG to a proper back-door graph.
Expand Down Expand Up @@ -267,7 +279,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
gback.graph.remove_edge(v1, v2)
return gback

def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
def direct_effect_adjustment_sets(
self, treatments: list[str], outcomes: list[str], nodes_to_ignore: list[str] = None
) -> list[set[str]]:
"""
Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
and outcomes for DIRECT causal effect.
Expand All @@ -278,12 +292,17 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
2019. These works use the algorithm presented by Takata et al. in their work entitled: Space-optimal,
backtracking algorithms to list the minimal vertex separators of a graph, 2013.
:param list[str] treatments: List of treatment names.
:param list[str] outcomes: List of outcome names.
:param treatments: List of treatment names.
:param outcomes: List of outcome names.
:param nodes_to_ignore: List of nodes to exclude from tests if they appear as treatments, outcomes, or in the
adjustment set.
:return: A list of possible adjustment sets.
:rtype: list[set[str]]
"""

if nodes_to_ignore is None:
nodes_to_ignore = []

indirect_graph = self.get_indirect_graph(treatments, outcomes)
ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes)
gam = nx.moral_graph(ancestor_graph.graph)
Expand All @@ -295,7 +314,7 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
if set(outcomes) in min_seps:
min_seps.remove(set(outcomes))
return min_seps
return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps)))

def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
"""Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
Expand Down
140 changes: 107 additions & 33 deletions causal_testing/specification/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import argparse
import logging
import json
from multiprocessing import Pool

import networkx as nx
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -214,46 +216,96 @@ def __str__(self):
)


def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
def generate_metamorphic_relation(
node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None
) -> MetamorphicRelation:
"""Construct a metamorphic relation for a given node pair implied by the Causal DAG, or None if no such relation can
be constructed (e.g. because every valid adjustment set contains a node to ignore).
:param node_pair: The pair of nodes to consider.
:param dag: Causal DAG from which the metamorphic relations will be generated.
:param nodes_to_ignore: Set of nodes which will be excluded from causal tests.
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
"""

if nodes_to_ignore is None:
nodes_to_ignore = set()

(u, v) = node_pair
metamorphic_relations = []

# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
# Case 1: U --> ... --> V
if u in nx.ancestors(dag.graph, v):
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
if adj_sets:
metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag))

# Case 2: V --> ... --> U
elif v in nx.ancestors(dag.graph, u):
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
if adj_sets:
metamorphic_relations.append(ShouldNotCause(v, u, list(adj_sets[0]), dag))

# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
# Only make one MR since V _||_ U == U _||_ V
else:
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
if adj_sets:
metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag))

# Create a ShouldCause relation for each edge (u, v) or (v, u)
elif (u, v) in dag.graph.edges:
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
if adj_sets:
metamorphic_relations.append(ShouldCause(u, v, list(adj_sets[0]), dag))
else:
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
if adj_sets:
metamorphic_relations.append(ShouldCause(v, u, list(adj_sets[0]), dag))
return metamorphic_relations


def generate_metamorphic_relations(
dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0, nodes_to_test: set = None
) -> list[MetamorphicRelation]:
"""Construct a list of metamorphic relations implied by the Causal DAG.
This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
relation for every (minimal) conditional independence relation implied by the structure of the DAG.
:param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
:param dag: Causal DAG from which the metamorphic relations will be generated.
:param nodes_to_ignore: Set of nodes which will be excluded from causal tests.
:param threads: Number of threads to use (if generating in parallel).
:param nodes_to_test: Set of nodes to test the relationships between (defaults to all nodes).
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
"""
metamorphic_relations = []
for node_pair in combinations(dag.graph.nodes, 2):
(u, v) = node_pair

# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
# Case 1: U --> ... --> V
if u in nx.ancestors(dag.graph, v):
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Case 2: V --> ... --> U
elif v in nx.ancestors(dag.graph, u):
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))

# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
# Only make one MR since V _||_ U == U _||_ V
else:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Create a ShouldCause relation for each edge (u, v) or (v, u)
elif (u, v) in dag.graph.edges:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
else:
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
if nodes_to_ignore is None:
nodes_to_ignore = {}

return metamorphic_relations
if nodes_to_test is None:
nodes_to_test = dag.graph.nodes

if not threads:
metamorphic_relations = [
generate_metamorphic_relation(node_pair, dag, nodes_to_ignore)
for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2)
]
else:
with Pool(threads) as pool:
metamorphic_relations = pool.starmap(
generate_metamorphic_relation,
map(
lambda node_pair: (node_pair, dag, nodes_to_ignore),
combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2),
),
)

return [item for items in metamorphic_relations for item in items]


if __name__ == "__main__": # pragma: no cover
Expand All @@ -273,10 +325,32 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
help="Specify path where tests should be saved, normally a .json file.",
required=True,
)
parser.add_argument(
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
)
parser.add_argument("-i", "--ignore-cycles", action="store_true")
args = parser.parse_args()

causal_dag = CausalDAG(args.dag_path)
relations = generate_metamorphic_relations(causal_dag)
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)

dag_nodes_to_test = set(
k for k, v in nx.get_node_attributes(causal_dag.graph, "test", default=True).items() if v == "True"
)

if not causal_dag.is_acyclic() and args.ignore_cycles:
logger.warning(
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
"Your causal test suite WILL NOT BE COMPLETE!"
)
relations = generate_metamorphic_relations(
causal_dag,
nodes_to_test=dag_nodes_to_test,
nodes_to_ignore=set(causal_dag.cycle_nodes()),
threads=args.threads,
)
else:
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads)

tests = [
relation.to_json_stub(skip=False)
for relation in relations
Expand Down
Loading

0 comments on commit 4fdf12b

Please sign in to comment.