diff --git a/bamt/nodes/conditional_mixture_gaussian_node.py b/bamt/nodes/conditional_mixture_gaussian_node.py index ed2e32e..11d1df2 100644 --- a/bamt/nodes/conditional_mixture_gaussian_node.py +++ b/bamt/nodes/conditional_mixture_gaussian_node.py @@ -163,8 +163,9 @@ def choose( pvals: parent values """ mean, covariance, w = self.get_dist(node_info, pvals) - - if np.isnan(w): + + # check if w is nan or list of weights + if not isinstance(w, list): return np.nan n_comp = len(w)