diff --git a/causal_testing/data_collection/data_collector.py b/causal_testing/data_collection/data_collector.py index 9a0a152f..718ed222 100644 --- a/causal_testing/data_collection/data_collector.py +++ b/causal_testing/data_collection/data_collector.py @@ -48,6 +48,10 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da f"Missing columns: missing data for variables {missing_variables}. Should they be marked as hidden?" ) + # Quick out if we don't have any constraints + if len(self.scenario.constraints) == 0: + return data + # For each row, does it satisfy the constraints? solver = z3.Solver() for c in self.scenario.constraints: @@ -56,6 +60,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da unsat_core = None for _, row in data.iterrows(): solver.push() + # Check that the row does not violate any scenario constraints # Need to explicitly cast variables to their specified type. Z3 will not take e.g. np.int64 to be an int. model = [ self.scenario.variables[var].z3