diff --git a/causal_testing/specification/causal_dag.py b/causal_testing/specification/causal_dag.py index 867d13a3..7fb1a3f7 100644 --- a/causal_testing/specification/causal_dag.py +++ b/causal_testing/specification/causal_dag.py @@ -173,8 +173,8 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool: ( cause for cause in self.graph.nodes - if list(nx.all_simple_paths(self.graph, source=cause, target=instrument)) - and list(nx.all_simple_paths(self.graph, source=cause, target=outcome)) + if len(list(nx.all_simple_paths(self.graph, source=cause, target=instrument))) > 0 + and len(list(nx.all_simple_paths(self.graph, source=cause, target=outcome))) > 0 ) ): raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")