From 4de42e5db4cab343fb38d57790da58b60baf240e Mon Sep 17 00:00:00 2001 From: Farhad Allian Date: Fri, 29 Nov 2024 09:59:29 +0000 Subject: [PATCH] fix: from comments by @jmafoster1 --- causal_testing/specification/causal_dag.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index ed1a23dc..f00d4ad8 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -171,12 +171,15 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: # (iii) Instrument and outcome do not share causes for cause in self.graph.nodes: - if cause not in (instrument, outcome): # exclude self-cycles due to breaking changes in NetworkX > 3.2 - instrument_paths = list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) - outcome_paths = list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) - if len(instrument_paths) > 0 and len(outcome_paths) > 0: - print(cause, instrument, instrument_paths, outcome, outcome_paths) - raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") + # Exclude self-cycles due to breaking changes in NetworkX > 3.2 + outcome_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) if cause != outcome else [] + ) + instrument_paths = ( + list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) if cause != instrument else [] + ) + if len(instrument_paths) > 0 and len(outcome_paths) > 0: + raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes") return True def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):