diff --git a/bamt/networks/base.py b/bamt/networks/base.py
index c7c4d13..1d8dadf 100644
--- a/bamt/networks/base.py
+++ b/bamt/networks/base.py
@@ -816,7 +816,26 @@ def find_family(
         plot_(
             plot_to, [self[name] for name in structure["nodes"]], structure["edges"]
         )
-        return structure
+
+    def get_dist(self, node_name: str, pvals: Optional[dict] = None):
+        """
+        Return the distribution of a node, conditioned on known parent values.
+
+        :param node_name: name of the node
+        :param pvals: parent values, keyed by parent name
+        """
+        if not self.distributions:
+            logger_network.error("Empty parameters. Call fit_params first.")
+            return
+        node = self[node_name]
+
+        parents = node.cont_parents + node.disc_parents
+        if not parents:
+            return self.distributions[node_name]
+
+        pvals = [pvals[parent] for parent in parents]
+
+        return node.get_dist(node_info=self.distributions[node_name], pvals=pvals)
 
     def _encode_categorical_data(self, data):
         for column in data.select_dtypes(include=["object", "string"]).columns:
diff --git a/bamt/nodes/base.py b/bamt/nodes/base.py
index c559d92..9824f09 100644
--- a/bamt/nodes/base.py
+++ b/bamt/nodes/base.py
@@ -85,3 +85,7 @@ def get_path_joblib(node_name: str, specific: str = "") -> str:
             os.path.join(path_to_check, f"{specific}.joblib.compressed")
         )
         return path
+
+    @staticmethod
+    def get_dist(node_info, pvals):
+        pass
diff --git a/bamt/nodes/conditional_gaussian_node.py b/bamt/nodes/conditional_gaussian_node.py
index f36bd5b..fe97927 100644
--- a/bamt/nodes/conditional_gaussian_node.py
+++ b/bamt/nodes/conditional_gaussian_node.py
@@ -107,18 +107,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, CondGaussParams
             }
         return {"hybcprob": hycprob}
 
-    def choose(
-        self,
-        node_info: Dict[str, Dict[str, CondGaussParams]],
-        pvals: List[Union[str, float]],
-    ) -> float:
-        """
-        Return value from ConditionalLogit node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
-
+    def get_dist(self, node_info, pvals):
         dispvals = []
         lgpvals = []
         for pval in pvals:
@@ -138,7 +127,7 @@ def choose(
                     flag = True
                     break
             if flag:
-                return np.nan
+                return np.nan, np.nan
             else:
                 if lgdistribution["regressor"]:
                     if lgdistribution["serialization"] == "joblib":
@@ -150,14 +139,30 @@ def choose(
                     cond_mean = model.predict(np.array(lgpvals).reshape(1, -1))[0]
                     variance = lgdistribution["variance"]
-                    return random.gauss(cond_mean, variance)
+                    return cond_mean, variance
                 else:
-                    return np.nan
+                    return np.nan, np.nan
         else:
-            return random.gauss(
-                lgdistribution["mean"], math.sqrt(lgdistribution["variance"])
-            )
+            return lgdistribution["mean"], math.sqrt(lgdistribution["variance"])
+
+    def choose(
+        self,
+        node_info: Dict[str, Dict[str, CondGaussParams]],
+        pvals: List[Union[str, float]],
+    ) -> float:
+        """
+        Return value from ConditionalGaussian node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+
+        cond_mean, variance = self.get_dist(node_info, pvals)
+        if math.isnan(cond_mean) or math.isnan(variance):
+            return np.nan
+
+        return random.gauss(cond_mean, variance)
 
     def predict(
         self,
diff --git a/bamt/nodes/conditional_logit_node.py b/bamt/nodes/conditional_logit_node.py
index 8ce2e8b..6c988f4 100644
--- a/bamt/nodes/conditional_logit_node.py
+++ b/bamt/nodes/conditional_logit_node.py
@@ -106,16 +106,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, LogitParams]]:
         return {"hybcprob": hycprob}
 
     @staticmethod
-    def choose(
-        node_info: Dict[str, Dict[str, LogitParams]], pvals: List[Union[str, float]]
-    ) -> str:
-        """
-        Return value from ConditionalLogit node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
-
+    def get_dist(node_info, pvals, **kwargs):
         dispvals = []
         lgpvals = []
         for pval in pvals:
@@ -140,6 +131,32 @@ def choose(
 
             distribution = model.predict_proba(np.array(lgpvals).reshape(1, -1))[0]
 
+            if not kwargs.get("inner", False):
+                return distribution
+            else:
+                return distribution, lgdistribution
+        else:
+            if not kwargs.get("inner", False):
+                return np.array([1.0])
+            else:
+                return np.array([1.0]), lgdistribution
+
+    def choose(
+        self,
+        node_info: Dict[str, Dict[str, LogitParams]],
+        pvals: List[Union[str, float]],
+    ) -> str:
+        """
+        Return value from ConditionalLogit node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+
+        distribution, lgdistribution = self.get_dist(node_info, pvals, inner=True)
+
+        # sample a class index from the predicted distribution
+        if len(lgdistribution["classes"]) > 1:
             rand = random.random()
             rindex = 0
             lbound = 0
@@ -152,7 +169,6 @@ def choose(
                 else:
                     lbound = ubound
             return str(lgdistribution["classes"][rindex])
-
         else:
             return str(lgdistribution["classes"][0])
 
diff --git a/bamt/nodes/conditional_mixture_gaussian_node.py b/bamt/nodes/conditional_mixture_gaussian_node.py
index c04c823..26a2991 100644
--- a/bamt/nodes/conditional_mixture_gaussian_node.py
+++ b/bamt/nodes/conditional_mixture_gaussian_node.py
@@ -112,27 +112,21 @@ def fit_parameters(
         return {"hybcprob": hycprob}
 
     @staticmethod
-    def choose(
-        node_info: Dict[str, Dict[str, CondMixtureGaussParams]],
-        pvals: List[Union[str, float]],
-    ) -> Optional[float]:
-        """
-        Function to get value from ConditionalMixtureGaussian node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
-        dispvals = []
+    def get_dist(node_info, pvals):
         lgpvals = []
+        dispvals = []
+
         for pval in pvals:
             if (isinstance(pval, str)) | (isinstance(pval, int)):
                 dispvals.append(pval)
             else:
                 lgpvals.append(pval)
+
         lgdistribution = node_info["hybcprob"][str(dispvals)]
         mean = lgdistribution["mean"]
         covariance = lgdistribution["covars"]
         w = lgdistribution["coef"]
+
         if len(w) != 0:
             if len(lgpvals) != 0:
                 indexes = [i for i in range(1, (len(lgpvals) + 1), 1)]
@@ -145,17 +139,43 @@ def choose(
                         covariances=covariance,
                     )
                     cond_gmm = gmm.condition(indexes, [lgpvals])
-                    sample = cond_gmm.sample(1)[0][0]
+                    return cond_gmm.means, cond_gmm.covariances, cond_gmm.priors
                 else:
-                    sample = np.nan
+                    return np.nan, np.nan, np.nan
             else:
                 n_comp = len(w)
                 gmm = GMM(
                     n_components=n_comp, priors=w, means=mean, covariances=covariance
                 )
-                sample = gmm.sample(1)[0][0]
+                return gmm.means, gmm.covariances, gmm.priors
         else:
-            sample = np.nan
+            return np.nan, np.nan, np.nan
+
+    def choose(
+        self,
+        node_info: Dict[str, Dict[str, CondMixtureGaussParams]],
+        pvals: List[Union[str, float]],
+    ) -> Optional[float]:
+        """
+        Function to get value from ConditionalMixtureGaussian node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+        mean, covariance, w = self.get_dist(node_info, pvals)
+        if not isinstance(w, (list, np.ndarray)):
+            # get_dist signals an undefined distribution with (nan, nan, nan)
+            return np.nan
+
+        n_comp = len(w)
+
+        gmm = GMM(
+            n_components=n_comp,
+            priors=w,
+            means=mean,
+            covariances=covariance,
+        )
+        sample = gmm.sample(1)[0][0]
         return sample
 
     @staticmethod
diff --git a/bamt/nodes/discrete_node.py b/bamt/nodes/discrete_node.py
index 9eb136c..976d60d 100644
--- a/bamt/nodes/discrete_node.py
+++ b/bamt/nodes/discrete_node.py
@@ -48,7 +48,10 @@ def worker(node: Type[BaseNode]) -> DiscreteParams:
             tight_form = conditional_dist.to_dict("tight")
 
             for comb, probs in zip(tight_form["index"], tight_form["data"]):
-                cprob[str([str(i) for i in comb])] = probs
+                if len(parents) > 1:
+                    cprob[str([str(i) for i in comb])] = probs
+                else:
+                    cprob[f"['{comb}']"] = probs
         return {"cprob": cprob, "vals": vals}
 
     pool = ThreadPoolExecutor(num_workers)
@@ -56,7 +59,14 @@ def worker(node: Type[BaseNode]) -> DiscreteParams:
         return future.result()
 
     @staticmethod
-    def choose(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
+    def get_dist(node_info, pvals):
+        if not pvals:
+            return node_info["cprob"]
+        else:
+            # noinspection PyTypeChecker
+            return node_info["cprob"][str(pvals)]
+
+    def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
         """
         Return value from discrete node
         params:
@@ -66,11 +76,9 @@ def choose(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
         rindex = 0
         random.seed()
         vals = node_info["vals"]
-        if not pvals:
-            dist = node_info["cprob"]
-        else:
-            # noinspection PyTypeChecker
-            dist = node_info["cprob"][str(pvals)]
+
+        dist = self.get_dist(node_info, pvals)
+
         lbound = 0
         ubound = 0
         rand = random.random()
diff --git a/bamt/nodes/gaussian_node.py b/bamt/nodes/gaussian_node.py
index 483b3e2..38698a8 100644
--- a/bamt/nodes/gaussian_node.py
+++ b/bamt/nodes/gaussian_node.py
@@ -76,13 +76,8 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams:
             "serialization": None,
         }
 
-    def choose(self, node_info: GaussianParams, pvals: List[float]) -> float:
-        """
-        Return value from Logit node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
+    def get_dist(self, node_info, pvals):
+        var = node_info["variance"]
         if pvals:
             for el in pvals:
                 if str(el) == "nan":
@@ -97,10 +92,20 @@ def choose(self, node_info: GaussianParams, pvals: List[float]) -> float:
                 pvals = [int(item) if isinstance(item, str) else item for item in pvals]
 
             cond_mean = model.predict(np.array(pvals).reshape(1, -1))[0]
-            var = node_info["variance"]
-            return random.gauss(cond_mean, var)
+            return cond_mean, var
         else:
-            return random.gauss(node_info["mean"], math.sqrt(node_info["variance"]))
+            return node_info["mean"], math.sqrt(var)
+
+    def choose(self, node_info: GaussianParams, pvals: List[float]) -> float:
+        """
+        Return value from Gaussian node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+
+        cond_mean, var = self.get_dist(node_info, pvals)
+        return random.gauss(cond_mean, var)
 
     @staticmethod
     def predict(node_info: GaussianParams, pvals: List[float]) -> float:
diff --git a/bamt/nodes/logit_node.py b/bamt/nodes/logit_node.py
index b41f306..2c4bbcf 100644
--- a/bamt/nodes/logit_node.py
+++ b/bamt/nodes/logit_node.py
@@ -56,16 +56,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams:
             "serialization": serialization_name,
         }
 
-    def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
-        """
-        Return value from Logit node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
-
-        rindex = 0
-
+    def get_dist(self, node_info, pvals):
         if len(node_info["classes"]) > 1:
             if node_info["serialization"] == "joblib":
                 model = joblib.load(node_info["classifier_obj"])
@@ -76,9 +67,24 @@ def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
             if type(self).__name__ == "CompositeDiscreteNode":
                 pvals = [int(item) if isinstance(item, str) else item for item in pvals]
 
-            distribution = model.predict_proba(np.array(pvals).reshape(1, -1))[0]
-            # choose
+            return model.predict_proba(np.array(pvals).reshape(1, -1))[0]
+        else:
+            return np.array([1.0])
+
+    def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
+        """
+        Return value from Logit node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+
+        rindex = 0
+
+        distribution = self.get_dist(node_info, pvals)
+
+        if len(node_info["classes"]) > 1:
             rand = random.random()
             lbound = 0
             ubound = 0
@@ -91,7 +97,6 @@ def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
                     lbound = ubound
 
             return str(node_info["classes"][rindex])
-
         else:
             return str(node_info["classes"][0])
 
diff --git a/bamt/nodes/mixture_gaussian_node.py b/bamt/nodes/mixture_gaussian_node.py
index 6ac87c8..22cd2b3 100644
--- a/bamt/nodes/mixture_gaussian_node.py
+++ b/bamt/nodes/mixture_gaussian_node.py
@@ -69,15 +69,7 @@ def fit_parameters(self, data: DataFrame) -> MixtureGaussianParams:
         return {"mean": means, "coef": w, "covars": cov}
 
     @staticmethod
-    def choose(
-        node_info: MixtureGaussianParams, pvals: List[Union[str, float]]
-    ) -> Optional[float]:
-        """
-        Func to get value from current node
-        node_info: nodes info from distributions
-        pvals: parent values
-        Return value from MixtureGaussian node
-        """
+    def get_dist(node_info, pvals):
         mean = node_info["mean"]
         covariance = node_info["covars"]
         w = node_info["coef"]
@@ -93,17 +85,40 @@ def choose(
                        covariances=covariance,
                    )
                    cond_gmm = gmm.condition(indexes, [pvals])
-                    sample = cond_gmm.sample(1)[0][0]
+                    return cond_gmm.means, cond_gmm.covariances, cond_gmm.priors
                else:
-                    sample = np.nan
+                    return np.nan, np.nan, np.nan
            else:
                gmm = GMM(
                    n_components=n_comp, priors=w, means=mean, covariances=covariance
                )
-                sample = gmm.sample(1)[0][0]
+                return gmm.means, gmm.covariances, gmm.priors
        else:
-            sample = np.nan
-        return sample
+            return np.nan, np.nan, np.nan
+
+    def choose(
+        self, node_info: MixtureGaussianParams, pvals: List[Union[str, float]]
+    ) -> Optional[float]:
+        """
+        Func to get value from current node
+        node_info: nodes info from distributions
+        pvals: parent values
+        Return value from MixtureGaussian node
+        """
+        mean, covariance, w = self.get_dist(node_info, pvals)
+        if not isinstance(w, (list, np.ndarray)):
+            # get_dist signals an undefined distribution with (nan, nan, nan)
+            return np.nan
+
+        n_comp = len(w)
+
+        gmm = GMM(
+            n_components=n_comp,
+            priors=w,
+            means=mean,
+            covariances=covariance,
+        )
+        return gmm.sample(1)[0][0]
 
     @staticmethod
     def predict(
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index 810a252..aabd2d7 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -5,11 +5,85 @@
 import pandas as pd
 
 from bamt.nodes import *
+from bamt.networks.hybrid_bn import HybridBN
 
 logging.getLogger("nodes").setLevel(logging.CRITICAL)
 
 
-class TestBaseNode(unittest.TestCase):
+class MyTest(unittest.TestCase):
+    """Shared distribution assertions for node tests."""
+
+    def assertDist(self, dist, node_type):
+        if node_type in ("disc", "disc_num"):
+            return self.assertAlmostEqual(sum(dist), 1)
+        else:
+            return self.assertEqual(len(dist), 2, msg=f"Error on {dist}")
+
+    def assertDistMixture(self, dist):
+        return self.assertEqual(len(dist), 3, msg=f"Error on {dist}")
+
+
+class TestBaseNode(MyTest):
+    def setUp(self):
+        np.random.seed(510)
+
+        hybrid_bn = HybridBN(has_logit=True)
+
+        info = {
+            "types": {
+                "Node0": "cont",
+                "Node1": "cont",
+                "Node2": "cont",
+                "Node3": "cont",
+                "Node4": "disc",
+                "Node5": "disc",
+                "Node6": "disc_num",
+                "Node7": "disc_num",
+            },
+            "signs": {"Node0": "pos", "Node1": "neg", "Node2": "neg", "Node3": "neg"},
+        }
+
+        data = pd.DataFrame(
+            {
+                "Node0": np.random.normal(0, 4, 30),
+                "Node1": np.random.normal(0, 0.1, 30),
+                "Node2": np.random.normal(0, 0.3, 30),
+                "Node3": np.random.normal(0, 0.3, 30),
+                "Node4": np.random.choice(["cat1", "cat2", "cat3"], 30),
+                "Node5": np.random.choice(["cat4", "cat5", "cat6"], 30),
+                "Node6": np.random.choice(["cat7", "cat8", "cat9"], 30),
+                "Node7": np.random.choice(["cat7", "cat8", "cat9"], 30),
+            }
+        )
+
+        nodes = [
+            gaussian_node.GaussianNode(name="Node0"),
+            gaussian_node.GaussianNode(name="Node1"),
+            gaussian_node.GaussianNode(name="Node2"),
+            gaussian_node.GaussianNode(name="Node3"),
+            discrete_node.DiscreteNode(name="Node4"),
+            discrete_node.DiscreteNode(name="Node5"),
+            discrete_node.DiscreteNode(name="Node6"),
+            discrete_node.DiscreteNode(name="Node7"),
+        ]
+
+        edges = [
+            ("Node0", "Node7"),
+            ("Node0", "Node1"),
+            ("Node0", "Node2"),
+            ("Node0", "Node5"),
+            ("Node4", "Node2"),
+            ("Node4", "Node5"),
+            ("Node4", "Node6"),
+            ("Node4", "Node3"),
+        ]
+
+        hybrid_bn.set_structure(info, nodes=nodes, edges=edges)
+        hybrid_bn.fit_parameters(data)
+        self.bn = hybrid_bn
+        self.data = data
+        self.info = info
+
     def test_equality(self):
         test = base.BaseNode(name="node0")
         first = base.BaseNode(name="node1")
@@ -53,6 +127,57 @@ def test_equality(self):
         self.assertFalse(test == first)
         self.assertTrue(test == test_clone)
 
+    def test_get_dist_mixture(self):
+        hybrid_bn = HybridBN(use_mixture=True, has_logit=True)
+
+        hybrid_bn.set_structure(self.info, self.bn.nodes, self.bn.edges)
+        hybrid_bn.fit_parameters(self.data)
+
+        mixture_gauss = hybrid_bn["Node1"]
+        cond_mixture_gauss = hybrid_bn["Node2"]
+
+        for i in range(-2, 0, 2):
+            dist = hybrid_bn.get_dist(mixture_gauss.name, pvals={"Node0": i})
+            self.assertDistMixture(dist)
+
+        for i in range(-2, 0, 2):
+            for j in self.data[cond_mixture_gauss.disc_parents[0]].unique().tolist():
+                dist = hybrid_bn.get_dist(
+                    cond_mixture_gauss.name, pvals={"Node0": float(i), "Node4": j}
+                )
+                self.assertDistMixture(dist)
+
+    def test_get_dist(self):
+        for node in self.bn.nodes:
+            if not node.cont_parents + node.disc_parents:
+                dist = self.bn.get_dist(node.name)
+                if "mean" in dist.keys():
+                    self.assertTrue(isinstance(dist["mean"], float))
+                    self.assertTrue(isinstance(dist["variance"], float))
+                else:
+                    self.assertAlmostEqual(sum(dist["cprob"]), 1)
+                continue
+
+            if len(node.cont_parents + node.disc_parents) == 1:
+                if node.disc_parents:
+                    pvals = self.data[node.disc_parents[0]].unique().tolist()
+                    parent = node.disc_parents[0]
+                else:
+                    pvals = range(-5, 5, 1)
+                    parent = node.cont_parents[0]
+
+                for pval in pvals:
+                    dist = self.bn.get_dist(node.name, {parent: pval})
+                    self.assertDist(dist, self.info["types"][node.name])
+            else:
+                for i in self.data[node.disc_parents[0]].unique().tolist():
+                    for j in range(-5, 5, 1):
+                        dist = self.bn.get_dist(
+                            node.name,
+                            {node.cont_parents[0]: float(j), node.disc_parents[0]: i},
+                        )
+                        self.assertDist(dist, self.info["types"][node.name])
+
     # ???
     def test_choose_serialization(self):
         pass
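
Usage sketch of the get_dist API added above. This is illustrative only, not part of the patch: it reuses only names and signatures that appear in this diff and its tests (HybridBN, set_structure, fit_parameters, get_dist) on a two-node structure mirroring the Node0 -> Node5 edge from the test fixture; the exact contents and keys of the returned values depend on the fitted parameters.

    import numpy as np
    import pandas as pd

    from bamt.networks.hybrid_bn import HybridBN
    from bamt.nodes.discrete_node import DiscreteNode
    from bamt.nodes.gaussian_node import GaussianNode

    # Toy data: a continuous root ("Node0") with a discrete child ("Node5").
    data = pd.DataFrame(
        {
            "Node0": np.random.normal(0, 4, 30),
            "Node5": np.random.choice(["cat4", "cat5", "cat6"], 30),
        }
    )

    bn = HybridBN(has_logit=True)
    bn.set_structure(
        {"types": {"Node0": "cont", "Node5": "disc"}, "signs": {"Node0": "pos"}},
        nodes=[GaussianNode(name="Node0"), DiscreteNode(name="Node5")],
        edges=[("Node0", "Node5")],
    )
    bn.fit_parameters(data)

    # A node without parents: get_dist returns its fitted parameters as stored
    # in bn.distributions (for a Gaussian node, a dict with "mean", "variance", ...).
    print(bn.get_dist("Node0"))

    # A node with parents: get_dist returns the conditional distribution for the
    # given parent values, keyed by parent name -- here, the class probabilities
    # produced by the logit classifier.
    print(bn.get_dist("Node5", pvals={"Node0": 1.0}))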