Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore cycles #294

Merged
merged 14 commits into from
Dec 3, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-drafts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "windows-latest", "macos-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
Expand Down
18 changes: 18 additions & 0 deletions causal_testing/estimation/linear_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,24 @@
ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"])
return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]

def estimate_prediction(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
    """Estimate the predicted outcome together with its confidence interval.

    The prediction is delegated to ``self._predict``; the row at position 1 of the returned frame is
    reported. Unlike the ATE/risk-ratio estimators, no difference or ratio between rows is computed here.

    :param adjustment_config: The configuration of the adjustment set as a dict mapping variable names to
                              their values. It is forwarded unchanged to ``self._predict``.

    :return: The predicted mean and ``[ci_low, ci_high]``, each wrapped in a single-element ``pd.Series``.
    """
    prediction = self._predict(adjustment_config=adjustment_config)
    # NOTE(review): row 1 is presumably the treatment row (row 0 being control) - confirm against _predict.
    outcome = prediction.iloc[1]
    # The lower bound must come from the lower column; the two were previously swapped.
    ci_low = pd.Series(outcome["mean_ci_lower"])
    ci_high = pd.Series(outcome["mean_ci_upper"])
    return pd.Series(outcome["mean"]), [ci_low, ci_high]

Check warning on line 196 in causal_testing/estimation/linear_regression_estimator.py

View check run for this annotation

Codecov / codecov/patch

causal_testing/estimation/linear_regression_estimator.py#L192-L196

Added lines #L192 - L196 were not covered by tests

def _get_confidence_intervals(self, model, treatment):
confidence_intervals = model.conf_int(alpha=self.alpha, cols=None)
ci_low, ci_high = (
Expand Down
4 changes: 2 additions & 2 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None)
data_paths = []
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)

def setup(self, scenario: Scenario, data=None):
def setup(self, scenario: Scenario, data=None, ignore_cycles=False):
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
self.scenario = scenario
self._get_scenario_variables()
self.scenario.setup_treatment_variables()
self.causal_specification = CausalSpecification(
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path, ignore_cycles=ignore_cycles)
)
# Parse the JSON test plan
with open(self.input_paths.json_path, encoding="utf-8") as f:
Expand Down
48 changes: 32 additions & 16 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
ensures it is acyclic. A CausalDAG must be specified as a dot file.
"""

def __init__(self, dot_path: str = None, **attr):
def __init__(self, dot_path: str = None, ignore_cycles: bool = False, **attr):
super().__init__(**attr)
if dot_path:
with open(dot_path, "r", encoding="utf-8") as file:
Expand All @@ -144,7 +144,12 @@
self.graph = nx.DiGraph()

if not self.is_acyclic():
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
if ignore_cycles:
logger.warning(

Check warning on line 148 in causal_testing/specification/causal_dag.py

View check run for this annotation

Codecov / codecov/patch

causal_testing/specification/causal_dag.py#L148

Added line #L148 was not covered by tests
"Cycles found. Ignoring them can invalidate causal estimates. Proceed with extreme caution."
)
else:
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")

def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
"""
Expand All @@ -164,16 +169,14 @@
)

# (iii) Instrument and outcome do not share causes
if any(
(
cause
for cause in self.graph.nodes
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
)
):
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")

for cause in self.graph.nodes:
if cause not in (instrument, outcome): # exclude self-cycles due to breaking changes in NetworkX > 3.2
instrument_paths = list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
outcome_paths = list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
if len(instrument_paths) > 0 and len(outcome_paths) > 0:
print(cause, instrument, instrument_paths, outcome, outcome_paths)
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
return True

def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
Expand All @@ -188,12 +191,18 @@
if not self.is_acyclic():
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")

def cycle_nodes(self) -> list:
    """Collect every node that takes part in a cycle of the graph.

    Nodes are gathered cycle by cycle, so a node that lies on several cycles appears once per cycle.

    :return: A flat list of the nodes of all simple cycles found in ``self.graph``.
    """
    members = []
    for loop in nx.simple_cycles(self.graph):
        members.extend(loop)
    return members

def is_acyclic(self) -> bool:
    """Checks if the graph is acyclic.

    Only the first cycle is requested from the (lazy) cycle enumeration, so a graph with many cycles
    is rejected without listing all of them.

    :return: True if acyclic, False otherwise.
    """
    # A single return replaces the former pair, of which the second was unreachable.
    return next(iter(nx.simple_cycles(self.graph)), None) is None

def get_proper_backdoor_graph(self, treatments: list[str], outcomes: list[str]) -> CausalDAG:
"""Convert the causal DAG to a proper back-door graph.
Expand Down Expand Up @@ -267,7 +276,9 @@
gback.graph.remove_edge(v1, v2)
return gback

def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
def direct_effect_adjustment_sets(
self, treatments: list[str], outcomes: list[str], nodes_to_ignore: list[str] = None
) -> list[set[str]]:
"""
Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
and outcomes for DIRECT causal effect.
Expand All @@ -278,12 +289,17 @@
2019. These works use the algorithm presented by Takata et al. in their work entitled: Space-optimal,
backtracking algorithms to list the minimal vertex separators of a graph, 2013.

:param list[str] treatments: List of treatment names.
:param list[str] outcomes: List of outcome names.
:param treatments: List of treatment names.
:param outcomes: List of outcome names.
:param nodes_to_ignore: List of nodes to exclude from tests if they appear as treatments, outcomes, or in the
adjustment set.
:return: A list of possible adjustment sets.
:rtype: list[set[str]]
"""

if nodes_to_ignore is None:
nodes_to_ignore = []

indirect_graph = self.get_indirect_graph(treatments, outcomes)
ancestor_graph = indirect_graph.get_ancestor_graph(treatments, outcomes)
gam = nx.moral_graph(ancestor_graph.graph)
Expand All @@ -295,7 +311,7 @@
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
if set(outcomes) in min_seps:
min_seps.remove(set(outcomes))
return min_seps
return sorted(list(filter(lambda sep: not sep.intersection(nodes_to_ignore), min_seps)))

def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
"""Get the smallest possible set of variables that blocks all back-door paths between all pairs of treatments
Expand Down
140 changes: 107 additions & 33 deletions causal_testing/specification/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import argparse
import logging
import json
from multiprocessing import Pool

import networkx as nx
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -214,46 +216,96 @@
)


def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
def generate_metamorphic_relation(
    node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None
) -> list[MetamorphicRelation]:
    """Construct the metamorphic relation implied by the Causal DAG for a single pair of nodes.

    A ShouldCause relation is built when the two nodes share an edge, a ShouldNotCause relation otherwise.
    The relation uses the first adjustment set returned by ``dag.direct_effect_adjustment_sets``.

    :param node_pair: The pair of nodes to consider.
    :param dag: Causal DAG from which the metamorphic relation will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests. Defaults to no nodes.

    :return: A list holding the single relation for this pair, or an empty list when no adjustment set is
             returned (e.g. because every valid adjustment set contains a node to ignore).
    """
    if nodes_to_ignore is None:
        nodes_to_ignore = set()

    (u, v) = node_pair
    edges = dag.graph.edges

    # Decide the relation type and its direction first; the adjustment-set lookup is then shared.
    if (u, v) in edges:
        # Direct edge U --> V
        relation_type, cause, effect = ShouldCause, u, v
    elif (v, u) in edges:
        # Direct edge V --> U
        relation_type, cause, effect = ShouldCause, v, u
    elif u not in nx.ancestors(dag.graph, v) and v in nx.ancestors(dag.graph, u):
        # V --> ... --> U (and no directed walk from U to V)
        relation_type, cause, effect = ShouldNotCause, v, u
    else:
        # Either U --> ... --> V, or V _||_ U (no directed walk, but there may be a back-door path
        # e.g. U <-- Z --> V). Only one relation is made for the latter since V _||_ U == U _||_ V.
        relation_type, cause, effect = ShouldNotCause, u, v

    adj_sets = dag.direct_effect_adjustment_sets([cause], [effect], nodes_to_ignore=nodes_to_ignore)
    if not adj_sets:
        return []
    return [relation_type(cause, effect, list(adj_sets[0]), dag)]


def generate_metamorphic_relations(
    dag: CausalDAG, nodes_to_ignore: set = None, threads: int = 0, nodes_to_test: set = None
) -> list[MetamorphicRelation]:
    """Construct a list of metamorphic relations implied by the Causal DAG.

    Every unordered pair of tested nodes is handed to ``generate_metamorphic_relation`` and the per-pair
    results are flattened into one list.

    :param dag: Causal DAG from which the metamorphic relations will be generated.
    :param nodes_to_ignore: Set of nodes which will be excluded from causal tests (defaults to none).
    :param threads: Number of parallel workers; 0 generates sequentially. The workers are
                    ``multiprocessing.Pool`` processes.
    :param nodes_to_test: Set of nodes to test the relationships between (defaults to all nodes).

    :return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
    """
    if nodes_to_ignore is None:
        # An empty set (not a dict) keeps the default consistent with the documented type.
        nodes_to_ignore = set()

    if nodes_to_test is None:
        nodes_to_test = dag.graph.nodes

    # Ignored nodes are dropped before pairing so they never appear as treatment or outcome.
    node_pairs = combinations((node for node in nodes_to_test if node not in nodes_to_ignore), 2)

    if not threads:
        relations_per_pair = [
            generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) for node_pair in node_pairs
        ]
    else:
        with Pool(threads) as pool:
            relations_per_pair = pool.starmap(
                generate_metamorphic_relation,
                ((node_pair, dag, nodes_to_ignore) for node_pair in node_pairs),
            )

    return [relation for relations in relations_per_pair for relation in relations]


if __name__ == "__main__": # pragma: no cover
Expand All @@ -273,10 +325,32 @@
help="Specify path where tests should be saved, normally a .json file.",
required=True,
)
parser.add_argument(
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
)
parser.add_argument("-i", "--ignore-cycles", action="store_true")
args = parser.parse_args()

causal_dag = CausalDAG(args.dag_path)
relations = generate_metamorphic_relations(causal_dag)
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)

dag_nodes_to_test = set(
k for k, v in nx.get_node_attributes(causal_dag.graph, "test", default=True).items() if v == "True"
)

if not causal_dag.is_acyclic() and args.ignore_cycles:
logger.warning(
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
"Your causal test suite WILL NOT BE COMPLETE!"
)
relations = generate_metamorphic_relations(
causal_dag,
nodes_to_test=dag_nodes_to_test,
nodes_to_ignore=set(causal_dag.cycle_nodes()),
threads=args.threads,
)
else:
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads)

tests = [
relation.to_json_stub(skip=False)
for relation in relations
Expand Down
Loading