Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore cycles #294

Merged
merged 14 commits into from
Dec 3, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-drafts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
Expand Down
4 changes: 2 additions & 2 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
data_paths = []
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)

def setup(self, scenario: Scenario, data=None):
def setup(self, scenario: Scenario, data=None, ignore_cycles=False):
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
self.scenario = scenario
self._get_scenario_variables()
self.scenario.setup_treatment_variables()
self.causal_specification = CausalSpecification(
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path, ignore_cycles=ignore_cycles)
)
# Parse the JSON test plan
with open(self.input_paths.json_path, encoding="utf-8") as f:
Expand Down
51 changes: 35 additions & 16 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class CausalDAG(nx.DiGraph):
ensures it is acyclic. A CausalDAG must be specified as a dot file.
"""

def __init__(self, dot_path: str = None, **attr):
def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
super().__init__(**attr)
if dot_path:
with open(dot_path, "r", encoding="utf-8") as file:
Expand All @@ -144,7 +144,12 @@ def __init__(self, dot_path: str = None, **attr):
self.graph = nx.DiGraph()

if not self.is_acyclic():
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
if ignore_cycles:
logger.warning(
"Cycles found. Ignoring them can invalidate causal estimates. Proceed with extreme caution."
)
else:
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")

def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
"""
Expand All @@ -164,16 +169,17 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
)

# (iii) Instrument and outcome do not share causes
if any(
(
cause
for cause in self.graph.nodes
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
)
):
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")

for cause in self.graph.nodes:
# Exclude self-cycles due to breaking changes in NetworkX > 3.2
outcome_paths = (
list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else []
)
instrument_paths = (
list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else []
)
if len(instrument_paths) > 0 and len(outcome_paths) > 0:
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
return True

def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
Expand All @@ -188,12 +194,18 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
if not self.is_acyclic():
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")

def cycle_nodes(self) -> list:
    """Collect the nodes that take part in any cycle of the underlying graph.

    :return: The nodes of every simple cycle found in ``self.graph``, concatenated in the order the cycles are
        enumerated. A node that lies on several cycles appears once per cycle; the list is empty when the graph
        has no cycles.
    """
    nodes_in_cycles = []
    for cycle in nx.simple_cycles(self.graph):
        nodes_in_cycles.extend(cycle)
    return nodes_in_cycles

def is_acyclic(self) -> bool:
"""Checks if the graph is acyclic.

:return: True if acyclic, False otherwise.
"""
return not list(nx.simple_cycles(self.graph))
return not self.cycle_nodes()

def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
"""Convert the causal DAG to a proper back-door graph.
Expand Down Expand Up @@ -267,7 +279,9 @@ def get_indirect_graph(self, treatments: list[str], outcomes: list[str]) -> Caus
gback.graph.remove_edge(v1, v2)
return gback

def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
def direct_effect_adjustment_sets(
self, treatments: list[str], outcomes: list[str], nodes_to_ignore: list[str] = None
) -> list[set[str]]:
"""
Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
and outcomes for DIRECT causal effect.
Expand All @@ -278,12 +292,17 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
2019. These works use the algorithm presented by Takata et al. in their work entitled: Space-optimal,
backtracking algorithms to list the minimal vertex separators of a graph, 2013.

:param list[str] treatments: List of treatment names.
:param list[str] outcomes: List of outcome names.
:param treatments: List of treatment names.
:param outcomes: List of outcome names.
:param nodes_to_ignore: List of nodes to exclude from tests if they appear as treatments, outcomes, or in the
adjustment set.
:return: A list of possible adjustment sets.
:rtype: list[set[str]]
"""

if nodes_to_ignore is None:
nodes_to_ignore = []

indirect_graph = self.get_indirect_graph(treatments, outcomes)
ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes)
gam = nx.moral_graph(ancestor_graph.graph)
Expand All @@ -295,7 +314,7 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
if set(outcomes) in min_seps:
min_seps.remove(set(outcomes))
return min_seps
return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps)))

def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
"""Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
Expand Down
140 changes: 107 additions & 33 deletions causal_testing/specification/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import argparse
import logging
import json
from multiprocessing import Pool

import networkx as nx
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -214,46 +216,96 @@ def __str__(self):
)


def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
def generate_metamorphic_relation(
    node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None
) -> list[MetamorphicRelation]:
    """Construct the metamorphic relation implied by the Causal DAG for a given node pair.

    Adjacent nodes yield a ShouldCause relation directed along the edge; non-adjacent nodes yield a ShouldNotCause
    relation. The result is always a list so that callers can flatten the output of many node pairs.

    :param node_pair: The pair of nodes to consider.
    :param dag: Causal DAG from which the metamorphic relation will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests.

    :return: A list holding the single ShouldCause or ShouldNotCause relation for the pair, or an empty list if no
             adjustment set is available (e.g. because every valid adjustment set contains a node to ignore).
    """

    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    (u, v) = node_pair
    edges = dag.graph.edges

    # Decide the relation type and its direction first, so the adjustment-set lookup is written only once.
    if (u, v) in edges:
        # Edge U --> V (takes precedence if both directions are present).
        relation, treatment, outcome = ShouldCause, u, v
    elif (v, u) in edges:
        # Edge V --> U
        relation, treatment, outcome = ShouldCause, v, u
    elif u in nx.ancestors(dag.graph, v):
        # Case 1: U --> ... --> V
        relation, treatment, outcome = ShouldNotCause, u, v
    elif v in nx.ancestors(dag.graph, u):
        # Case 2: V --> ... --> U
        relation, treatment, outcome = ShouldNotCause, v, u
    else:
        # Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
        # Only make one MR since V _||_ U == U _||_ V
        relation, treatment, outcome = ShouldNotCause, u, v

    adj_sets = dag.direct_effect_adjustment_sets([treatment], [outcome], nodes_to_ignore=nodes_to_ignore)
    if not adj_sets:
        return []
    # Only the first adjustment set returned is used for the relation.
    return [relation(treatment, outcome, list(adj_sets[0]), dag)]


def generate_metamorphic_relations(
dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0, nodes_to_test: set = None
) -> list[MetamorphicRelation]:
"""Construct a list of metamorphic relations implied by the Causal DAG.

This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
relation for every (minimal) conditional independence relation implied by the structure of the DAG.

:param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
:param dag: Causal DAG from which the metamorphic relations will be generated.
:param nodes_to_ignore: Set of nodes which will be excluded from causal tests.
:param threads: Number of threads to use (if generating in parallel).
:param nodes_to_test: Set of nodes to test the relationships between (defaults to all nodes).

:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
"""
metamorphic_relations = []
for node_pair in combinations(dag.graph.nodes, 2):
(u, v) = node_pair

# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
# Case 1: U --> ... --> V
if u in nx.ancestors(dag.graph, v):
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Case 2: V --> ... --> U
elif v in nx.ancestors(dag.graph, u):
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))

# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
# Only make one MR since V _||_ U == U _||_ V
else:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))

# Create a ShouldCause relation for each edge (u, v) or (v, u)
elif (u, v) in dag.graph.edges:
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
else:
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
if nodes_to_ignore is None:
nodes_to_ignore = {}

return metamorphic_relations
if nodes_to_test is None:
nodes_to_test = dag.graph.nodes

if not threads:
metamorphic_relations = [
generate_metamorphic_relation(node_pair, dag, nodes_to_ignore)
for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2)
]
else:
with Pool(threads) as pool:
metamorphic_relations = pool.starmap(
generate_metamorphic_relation,
map(
lambda node_pair: (node_pair, dag, nodes_to_ignore),
combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2),
),
)

return [item for items in metamorphic_relations for item in items]


if __name__ == "__main__": # pragma: no cover
Expand All @@ -273,10 +325,32 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
help="Specify path where tests should be saved, normally a .json file.",
required=True,
)
parser.add_argument(
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
)
parser.add_argument("-i", "--ignore-cycles", action="store_true")
args = parser.parse_args()

causal_dag = CausalDAG(args.dag_path)
relations = generate_metamorphic_relations(causal_dag)
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)

# Nodes lacking a "test" attribute receive the boolean default True, while the explicit values are compared
# against the string "True" (presumably because attributes parsed from the DOT file are strings - confirm).
# Comparing on the string form keeps un-annotated nodes in the set instead of silently dropping them.
dag_nodes_to_test = {
    node
    for node, flag in nx.get_node_attributes(causal_dag.graph, "test", default=True).items()
    if str(flag) == "True"
}

if not causal_dag.is_acyclic() and args.ignore_cycles:
logger.warning(
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
"Your causal test suite WILL NOT BE COMPLETE!"
)
relations = generate_metamorphic_relations(
causal_dag,
nodes_to_test=dag_nodes_to_test,
nodes_to_ignore=set(causal_dag.cycle_nodes()),
threads=args.threads,
)
else:
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads)

tests = [
relation.to_json_stub(skip=False)
for relation in relations
Expand Down
Loading