diff --git a/bamt/nodes/conditional_gaussian_node.py b/bamt/nodes/conditional_gaussian_node.py index 67c4582..58801b6 100644 --- a/bamt/nodes/conditional_gaussian_node.py +++ b/bamt/nodes/conditional_gaussian_node.py @@ -44,7 +44,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, CondGaussParams mask = (mask) & (data[col] == val) new_data = data[mask] key_comb = [str(x) for x in comb] - if new_data.shape[0] > 1: + if new_data.shape[0] > 0: if self.cont_parents: model = clone(self.regressor) model.fit(