From 1315151932020128a17d89fdfff443e2df5bede0 Mon Sep 17 00:00:00 2001 From: Roman Netrogolov <68499591+Roman223@users.noreply.github.com> Date: Thu, 30 Nov 2023 16:58:54 +0300 Subject: [PATCH] Rollback cond gaus (#92) * serialization fix * min_shape changed --- bamt/nodes/conditional_gaussian_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bamt/nodes/conditional_gaussian_node.py b/bamt/nodes/conditional_gaussian_node.py index 67c4582..58801b6 100644 --- a/bamt/nodes/conditional_gaussian_node.py +++ b/bamt/nodes/conditional_gaussian_node.py @@ -44,7 +44,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, CondGaussParams mask = (mask) & (data[col] == val) new_data = data[mask] key_comb = [str(x) for x in comb] - if new_data.shape[0] > 1: + if new_data.shape[0] > 0: if self.cont_parents: model = clone(self.regressor) model.fit(