diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index ed1a23dc..f00d4ad8 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -171,12 +171,15 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: # (iii) Instrument and outcome do not share causes for cause in self.graph.nodes: - if cause not in (instrument, outcome): # exclude self-cycles due to breaking changes in NetworkX > 3.2 - instrument_paths = list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) - outcome_paths = list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) - if len(instrument_paths) > 0 and len(outcome_paths) > 0: - print(cause, instrument, instrument_paths, outcome, outcome_paths) - raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") + # Exclude self-cycles due to breaking changes in NetworkX > 3.2 + outcome_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else [] + ) + instrument_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else [] + ) + if len(instrument_paths) > 0 and len(outcome_paths) > 0: + raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") return True def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):