From f1a8756006c9bebf3929ad38707124c348c7d49c Mon Sep 17 00:00:00 2001
From: Roman223
Date: Thu, 10 Aug 2023 14:51:47 +0300
Subject: [PATCH 1/5] feature: add get_dist to obtain a conditional
 distribution from a node

---
 bamt/networks/base.py                    |  8 ++++
 bamt/nodes/base.py                       |  4 ++
 bamt/nodes/conditional_gaussian_node.py  | 41 +++++++++-------
 bamt/nodes/conditional_logit_node.py     | 38 ++++++++++-----
 .../conditional_mixture_gaussian_node.py | 47 +++++++++++++------
 bamt/nodes/discrete_node.py              | 22 ++++++---
 bamt/nodes/gaussian_node.py              | 25 ++++++----
 bamt/nodes/logit_node.py                 | 27 ++++++-----
 bamt/nodes/mixture_gaussian_node.py      | 40 ++++++++------
 9 files changed, 166 insertions(+), 86 deletions(-)

diff --git a/bamt/networks/base.py b/bamt/networks/base.py
index da4e73f..8459fe9 100644
--- a/bamt/networks/base.py
+++ b/bamt/networks/base.py
@@ -862,3 +862,11 @@ def plot(self, output: str):
             os.mkdir("visualization_result")

         return network.show(f"visualization_result/" + output)
+
+    def get_dist(self, node_name: str, pvals: dict):
+        node = self[node_name]
+
+        parents = node.cont_parents + node.disc_parents
+        pvals = [pvals[parent] for parent in parents]
+
+        return node.get_dist(node_info=self.distributions[node_name], pvals=pvals)
diff --git a/bamt/nodes/base.py b/bamt/nodes/base.py
index 1c5c2c8..d00b253 100644
--- a/bamt/nodes/base.py
+++ b/bamt/nodes/base.py
@@ -86,3 +86,7 @@ def get_path_joblib(node_name: str, specific: str = "") -> str:
             os.path.join(path_to_check, f"{specific}.joblib.compressed")
         )
         return path
+
+    @staticmethod
+    def get_dist(node_info, pvals):
+        pass
diff --git a/bamt/nodes/conditional_gaussian_node.py b/bamt/nodes/conditional_gaussian_node.py
index dcdd166..322b8f5 100644
--- a/bamt/nodes/conditional_gaussian_node.py
+++ b/bamt/nodes/conditional_gaussian_node.py
@@ -109,18 +109,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, CondGaussParams
             }
         return {"hybcprob": hycprob}

-    def choose(
-        self,
-        node_info: Dict[str, Dict[str, CondGaussParams]],
-        pvals: List[Union[str, float]],
-    ) -> float:
-        """
-        Return value from ConditionalLogit node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
-
+    def get_dist(self, node_info, pvals):
        dispvals = []
         lgpvals = []
         for pval in pvals:
@@ -140,7 +129,7 @@ def choose(
                     flag = True
                     break
             if flag:
-                return np.nan
+                return np.nan, np.nan
             else:
                 if lgdistribution["regressor"]:
                     if lgdistribution["serialization"] == "joblib":
@@ -152,14 +141,30 @@ def choose(

                     cond_mean = model.predict(np.array(lgpvals).reshape(1, -1))[0]
                     variance = lgdistribution["variance"]
-                    return random.gauss(cond_mean, variance)
+                    return cond_mean, math.sqrt(variance)  # std, as in the branch below
                 else:
-                    return np.nan
+                    return np.nan, np.nan

         else:
-            return random.gauss(
-                lgdistribution["mean"], math.sqrt(lgdistribution["variance"])
-            )
+            return lgdistribution["mean"], math.sqrt(lgdistribution["variance"])
+
+    def choose(
+        self,
+        node_info: Dict[str, Dict[str, CondGaussParams]],
+        pvals: List[Union[str, float]],
+    ) -> float:
+        """
+        Return value from ConditionalGaussian node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+
+        cond_mean, std = self.get_dist(node_info, pvals)
+        if math.isnan(cond_mean) or math.isnan(std):  # NaNs flag an unmet parent combination
+            return np.nan
+
+        return random.gauss(cond_mean, std)

     def predict(
         self,
diff --git a/bamt/nodes/conditional_logit_node.py b/bamt/nodes/conditional_logit_node.py
index 3a1b1a0..70f97f1 100644
--- a/bamt/nodes/conditional_logit_node.py
+++ b/bamt/nodes/conditional_logit_node.py
@@ -106,16 +106,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, LogitParams]]:
         return {"hybcprob": hycprob}

     @staticmethod
-    def choose(
-        node_info: Dict[str, Dict[str, LogitParams]], pvals: List[Union[str, float]]
-    ) -> str:
-        """
-        Return value from ConditionalLogit node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
-
+    def get_dist(node_info, pvals, **kwargs):
         dispvals = []
         lgpvals = []
         for pval in pvals:
@@ -140,6 +131,32 @@ def choose(

             distribution = model.predict_proba(np.array(lgpvals).reshape(1, -1))[0]

+            if not kwargs.get("inner", False):
+                return distribution
+            else:
+                return distribution, lgdistribution
+        else:
+            if not kwargs.get("inner", False):
+                return np.array([100.0])
+            else:
+                return np.array([100.0]), lgdistribution
+
+    def choose(
+        self,
+        node_info: Dict[str, Dict[str, LogitParams]],
+        pvals: List[Union[str, float]],
+    ) -> str:
+        """
+        Return value from ConditionalLogit node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+
+        distribution, lgdistribution = self.get_dist(node_info, pvals, inner=True)
+
+        # sample a class index from the predicted probabilities
+
+        if len(lgdistribution["classes"]) > 1:
             rand = random.random()
             rindex = 0
             lbound = 0
@@ -152,7 +169,6 @@ def choose(
                 else:
                     lbound = ubound
             return str(lgdistribution["classes"][rindex])
-
         else:
             return str(lgdistribution["classes"][0])
diff --git a/bamt/nodes/conditional_mixture_gaussian_node.py b/bamt/nodes/conditional_mixture_gaussian_node.py
index 280aada..30ee52b 100644
--- a/bamt/nodes/conditional_mixture_gaussian_node.py
+++ b/bamt/nodes/conditional_mixture_gaussian_node.py
@@ -113,27 +113,21 @@ def fit_parameters(
         return {"hybcprob": hycprob}

     @staticmethod
-    def choose(
-        node_info: Dict[str, Dict[str, CondMixtureGaussParams]],
-        pvals: List[Union[str, float]],
-    ) -> Optional[float]:
-        """
-        Function to get value from ConditionalMixtureGaussian node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
-        dispvals = []
+    def get_dist(node_info, pvals):
         lgpvals = []
+        dispvals = []
+
         for pval in pvals:
             if (isinstance(pval, str)) | (isinstance(pval, int)):
                 dispvals.append(pval)
             else:
                 lgpvals.append(pval)
+
         lgdistribution = node_info["hybcprob"][str(dispvals)]
         mean = lgdistribution["mean"]
         covariance = lgdistribution["covars"]
         w = lgdistribution["coef"]
+
         if len(w) != 0:
             if len(lgpvals) != 0:
                 indexes = [i for i in range(1, (len(lgpvals) + 1), 1)]
@@ -146,17 +140,40 @@ def choose(
                     covariances=covariance,
                 )
                 cond_gmm = gmm.condition(indexes, [lgpvals])
-                sample = cond_gmm.sample(1)[0][0]
+                return cond_gmm.means, cond_gmm.covariances, cond_gmm.priors
             else:
-                sample = np.nan
+                return np.nan, np.nan, np.nan
         else:
             n_comp = len(w)
             gmm = GMM(
                 n_components=n_comp, priors=w, means=mean, covariances=covariance
             )
-            sample = gmm.sample(1)[0][0]
+            return gmm.means, gmm.covariances, gmm.priors
         else:
-            sample = np.nan
+            return np.nan, np.nan, np.nan
+
+    def choose(
+        self,
+        node_info: Dict[str, Dict[str, CondMixtureGaussParams]],
+        pvals: List[Union[str, float]],
+    ) -> Optional[float]:
+        """
+        Function to get value from ConditionalMixtureGaussian node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+        mean, covariance, w = self.get_dist(node_info, pvals)
+        if not isinstance(w, np.ndarray):  # NaNs signal a degenerate distribution
+            return np.nan
+        n_comp = len(w)
+        gmm = GMM(
+            n_components=n_comp,
+            priors=w,
+            means=mean,
+            covariances=covariance,
+        )
+        sample = gmm.sample(1)[0][0]
         return sample

     @staticmethod
diff --git a/bamt/nodes/discrete_node.py b/bamt/nodes/discrete_node.py
index f2df41c..5e7455f 100644
--- a/bamt/nodes/discrete_node.py
+++ b/bamt/nodes/discrete_node.py
@@ -49,7 +49,10 @@ def worker(node: Type[BaseNode]) -> DiscreteParams:
             tight_form = conditional_dist.to_dict("tight")

             for comb, probs in zip(tight_form["index"], tight_form["data"]):
-                cprob[str([str(i) for i in comb])] = probs
+                if len(parents) > 1:
+                    cprob[str([str(i) for i in comb])] = probs
+                else:
+                    cprob[f"['{comb}']"] = probs
         return {"cprob": cprob, "vals": vals}

     pool = ThreadPoolExecutor(num_workers)
@@ -57,7 +60,14 @@ def worker(node: Type[BaseNode]) -> DiscreteParams:
         return future.result()

     @staticmethod
-    def choose(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
+    def get_dist(node_info, pvals):
+        if not pvals:
+            return node_info["cprob"]
+        else:
+            # noinspection PyTypeChecker
+            return node_info["cprob"][str(pvals)]
+
+    def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
         """
         Return value from discrete node
         params:
@@ -67,11 +77,9 @@ def choose(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
         rindex = 0
         random.seed()
         vals = node_info["vals"]
-        if not pvals:
-            dist = node_info["cprob"]
-        else:
-            # noinspection PyTypeChecker
-            dist = node_info["cprob"][str(pvals)]
+
+        dist = self.get_dist(node_info, pvals)
+
         lbound = 0
         ubound = 0
         rand = random.random()
diff --git a/bamt/nodes/gaussian_node.py b/bamt/nodes/gaussian_node.py
index f72dcf3..92f7680 100644
--- a/bamt/nodes/gaussian_node.py
+++ b/bamt/nodes/gaussian_node.py
@@ -78,13 +78,8 @@ def fit_parameters(self, data: DataFrame) -> GaussianParams:
         }

     @staticmethod
-    def choose(node_info: GaussianParams, pvals: List[float]) -> float:
-        """
-        Return value from Logit node
-        params:
-        node_info: nodes info from distributions
-        pvals: parent values
-        """
+    def get_dist(node_info, pvals):
+        var = node_info["variance"]
         if pvals:
             for el in pvals:
                 if str(el) == "nan":
@@ -96,10 +91,20 @@ def choose(node_info: GaussianParams, pvals: List[float]) -> float:
                     model = pickle.loads(a)

             cond_mean = model.predict(np.array(pvals).reshape(1, -1))[0]
-            var = node_info["variance"]
-            return random.gauss(cond_mean, var)
+
+            return cond_mean, math.sqrt(var)  # std, as in the branch below
         else:
-            return random.gauss(node_info["mean"], math.sqrt(node_info["variance"]))
+            return node_info["mean"], math.sqrt(var)
+
+    def choose(self, node_info: GaussianParams, pvals: List[float]) -> float:
+        """
+        Return value from Gaussian node
+        params:
+        node_info: nodes info from distributions
+        pvals: parent values
+        """
+
+        cond_mean, std = self.get_dist(node_info, pvals)
+
+        return random.gauss(cond_mean, std)

     @staticmethod
     def predict(node_info: GaussianParams, pvals: List[float]) -> float:
diff --git a/bamt/nodes/logit_node.py b/bamt/nodes/logit_node.py
index 5f770eb..a06ab87 100644
--- a/bamt/nodes/logit_node.py
+++ b/bamt/nodes/logit_node.py
@@ -57,6 +57,19 @@ def fit_parameters(self, data: DataFrame) -> LogitParams:
             "serialization": serialization_name,
         }

+    @staticmethod
+    def get_dist(node_info, pvals):
+        if len(node_info["classes"]) > 1:
+            if node_info["serialization"] == "joblib":
+                model = joblib.load(node_info["classifier_obj"])
+            else:
+                # str_model = node_info["classifier_obj"].decode('latin1').replace('\'', '\"')
+                a = node_info["classifier_obj"].encode("latin1")
+                model = pickle.loads(a)
+            return model.predict_proba(np.array(pvals).reshape(1, -1))[0]
+        else:
+            return np.array([100.])
+
     def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
         """
         Return value from Logit node
@@ -67,29 +80,21 @@ def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:

         rindex = 0

-        if len(node_info["classes"]) > 1:
-            if node_info["serialization"] == "joblib":
-                model = joblib.load(node_info["classifier_obj"])
-            else:
-                # str_model = node_info["classifier_obj"].decode('latin1').replace('\'', '\"')
-                a = node_info["classifier_obj"].encode("latin1")
-                model = pickle.loads(a)
-            distribution = model.predict_proba(np.array(pvals).reshape(1, -1))[0]
+        distribution = self.get_dist(node_info, pvals)

-            # choose
+        if len(node_info["classes"]) > 1:
             rand = random.random()
             lbound = 0
             ubound = 0
             for interval in range(len(node_info["classes"])):
                 ubound += distribution[interval]
-                if lbound <= rand and rand < ubound:
+                if lbound <= rand < ubound:
                     rindex = interval
                     break
                 else:
                     lbound = ubound

             return str(node_info["classes"][rindex])
-
         else:
             return str(node_info["classes"][0])
diff --git a/bamt/nodes/mixture_gaussian_node.py b/bamt/nodes/mixture_gaussian_node.py
index 93dc724..7653b11 100644
--- a/bamt/nodes/mixture_gaussian_node.py
+++ b/bamt/nodes/mixture_gaussian_node.py
@@ -70,15 +70,7 @@ def fit_parameters(self, data: DataFrame) -> MixtureGaussianParams:
         return {"mean": means, "coef": w, "covars": cov}

     @staticmethod
-    def choose(
-        node_info: MixtureGaussianParams, pvals: List[Union[str, float]]
-    ) -> Optional[float]:
-        """
-        Func to get value from current node
-        node_info: nodes info from distributions
-        pvals: parent values
-        Return value from MixtureGaussian node
-        """
+    def get_dist(node_info, pvals):
         mean = node_info["mean"]
         covariance = node_info["covars"]
         w = node_info["coef"]
@@ -94,17 +86,37 @@ def choose(
                         covariances=covariance,
                     )
                     cond_gmm = gmm.condition(indexes, [pvals])
-                    sample = cond_gmm.sample(1)[0][0]
+                    return cond_gmm.means, cond_gmm.covariances, cond_gmm.priors
                 else:
-                    sample = np.nan
+                    return np.nan, np.nan, np.nan
             else:
                 gmm = GMM(
                     n_components=n_comp, priors=w, means=mean, covariances=covariance
                 )
-                sample = gmm.sample(1)[0][0]
+                return gmm.means, gmm.covariances, gmm.priors
         else:
-            sample = np.nan
-        return sample
+            return np.nan, np.nan, np.nan
+
+    def choose(
+        self, node_info: MixtureGaussianParams, pvals: List[Union[str, float]]
+    ) -> Optional[float]:
+        """
+        Func to get value from current node
+        node_info: nodes info from distributions
+        pvals: parent values
+        Return value from MixtureGaussian node
+        """
+        mean, covariance, w = self.get_dist(node_info, pvals)
+        if not isinstance(w, np.ndarray):  # NaNs signal a degenerate distribution
+            return np.nan
+        n_comp = len(w)
+        gmm = GMM(
+            n_components=n_comp,
+            priors=w,
+            means=mean,
+            covariances=covariance,
+        )
+        return gmm.sample(1)[0][0]

     @staticmethod
     def predict(

From 3006616b1ab409d0a465e0f0b08209892e46a3fe Mon Sep 17 00:00:00 2001
From: Roman223
Date: Fri, 11 Aug 2023 16:58:31 +0300
Subject: [PATCH 2/5] tests and minor improvements

---
 bamt/networks/base.py    |  13 +++++
 bamt/nodes/logit_node.py |   2 +-
 tests/test_nodes.py      | 122 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 136 insertions(+), 1 deletion(-)

diff --git a/bamt/networks/base.py b/bamt/networks/base.py
index 8459fe9..3f390e4 100644
--- a/bamt/networks/base.py
+++ b/bamt/networks/base.py
@@ -864,9 +864,22 @@ def plot(self, output: str):

         return network.show(f"visualization_result/" + output)

     def get_dist(self, node_name: str, pvals: dict):
+        """
+        Get a distribution from a node with known parent values (conditional distribution).
+
+        :param node_name: name of node
+        :param pvals: parent values
+        """
+        if not self.distributions:
+            logger_network.error("Empty parameters. Call fit_parameters first.")
+            return
         node = self[node_name]

         parents = node.cont_parents + node.disc_parents
+        if not parents:
+            logger_network.error("No parents")
+            return
+
         pvals = [pvals[parent] for parent in parents]

         return node.get_dist(node_info=self.distributions[node_name], pvals=pvals)
diff --git a/bamt/nodes/logit_node.py b/bamt/nodes/logit_node.py
index a06ab87..2eeb622 100644
--- a/bamt/nodes/logit_node.py
+++ b/bamt/nodes/logit_node.py
@@ -68,7 +68,7 @@ def get_dist(node_info, pvals):
                 model = pickle.loads(a)
             return model.predict_proba(np.array(pvals).reshape(1, -1))[0]
         else:
-            return np.array([100.])
+            return np.array([100.0])
diff --git a/tests/test_nodes.py b/tests/test_nodes.py
index 87d2695..c48e4a0 100644
--- a/tests/test_nodes.py
+++ b/tests/test_nodes.py
@@ -4,11 +4,85 @@
 import numpy as np

 from bamt.nodes import *
+from bamt.networks.hybrid_bn import HybridBN

 logging.getLogger("nodes").setLevel(logging.CRITICAL)


+class DistributionAssertion(unittest.TestCase):
+    unittest.skip("This is an assertion.")
+
+    def assertDist(self, dist, node_type):
+        if node_type in ("disc", "disc_num"):
+            return self.assertAlmostEqual(sum(dist), 1)
+        else:
+            return self.assertEqual(len(dist), 2, msg=f"Error on {dist}")
+
+    def assertDistMixture(self, dist):
+        return self.assertEqual(len(dist), 3, msg=f"Error on {dist}")
+
+
 class TestBaseNode(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(510)
+
+        hybrid_bn = HybridBN(has_logit=True)
+
+        info = {
+            "types": {
+                "Node0": "cont",
+                "Node1": "cont",
+                "Node2": "cont",
+                "Node3": "cont",
+                "Node4": "disc",
+                "Node5": "disc",
+                "Node6": "disc_num",
+                "Node7": "disc_num",
+            },
+            "signs": {"Node0": "pos", "Node1": "neg", "Node2": "neg", "Node3": "neg"},
+        }
+
+        data = pd.DataFrame(
+            {
+                "Node0": np.random.normal(0, 4, 30),
+                "Node1": np.random.normal(0, 0.1, 30),
+                "Node2": np.random.normal(0, 0.3, 30),
+                "Node3": np.random.normal(0, 0.3, 30),
+                "Node4": np.random.choice(["cat1", "cat2", "cat3"], 30),
+                "Node5": np.random.choice(["cat4", "cat5", "cat6"], 30),
+                "Node6": np.random.choice(["cat7", "cat8", "cat9"], 30),
+                "Node7": np.random.choice(["cat7", "cat8", "cat9"], 30),
+            }
+        )
+
+        nodes = [
+            gaussian_node.GaussianNode(name="Node0"),
+            gaussian_node.GaussianNode(name="Node1"),
+            gaussian_node.GaussianNode(name="Node2"),
+            gaussian_node.GaussianNode(name="Node3"),
+            discrete_node.DiscreteNode(name="Node4"),
+            discrete_node.DiscreteNode(name="Node5"),
+            discrete_node.DiscreteNode(name="Node6"),
+            discrete_node.DiscreteNode(name="Node7"),
+        ]
+
+        edges = [
+            ("Node0", "Node7"),
+            ("Node0", "Node1"),
+            ("Node0", "Node2"),
+            ("Node0", "Node5"),
+            ("Node4", "Node2"),
+            ("Node4", "Node5"),
+            ("Node4", "Node6"),
+            ("Node4", "Node3"),
+        ]
+
+        hybrid_bn.set_structure(info, nodes=nodes, edges=edges)
+        hybrid_bn.fit_parameters(data)
+        self.bn = hybrid_bn
+        self.data = data
+        self.info = info
+
     def test_equality(self):
         test = base.BaseNode(name="node0")
         first = base.BaseNode(name="node1")
@@ -52,6 +126,54 @@ def test_equality(self):
         self.assertFalse(test == first)
         self.assertTrue(test == test_clone)

+    def test_get_dist_mixture(self):
+        hybrid_bn = HybridBN(use_mixture=True, has_logit=True)
+
+        hybrid_bn.set_structure(self.info, self.bn.nodes, self.bn.edges)
+        hybrid_bn.fit_parameters(self.data)
+
+        mixture_gauss = hybrid_bn["Node1"]
+        cond_mixture_gauss = hybrid_bn["Node2"]
+
+        for i in range(-2, 0, 2):
+            dist = hybrid_bn.get_dist(mixture_gauss.name, pvals={"Node0": i})
+                DistributionAssertion().assertDistMixture(dist)
+
+        for i in range(-2, 0, 2):
+            for j in self.data[cond_mixture_gauss.disc_parents[0]].unique().tolist():
+                dist = hybrid_bn.get_dist(
+                    cond_mixture_gauss.name, pvals={"Node0": float(i), "Node4": j}
+                )
+                DistributionAssertion().assertDistMixture(dist)
+
+    def test_get_dist(self):
+        for node in self.bn.nodes:
+            if not node.cont_parents + node.disc_parents:
+                continue
+            if len(node.cont_parents + node.disc_parents) == 1:
+                if node.disc_parents:
+                    pvals = self.data[node.disc_parents[0]].unique().tolist()
+                    parent = node.disc_parents[0]
+                else:
+                    pvals = range(-5, 5, 1)
+                    parent = node.cont_parents[0]
+
+                for pval in pvals:
+                    dist = self.bn.get_dist(node.name, {parent: pval})
+                    DistributionAssertion().assertDist(
+                        dist, self.info["types"][node.name]
+                    )
+            else:
+                for i in self.data[node.disc_parents[0]].unique().tolist():
+                    for j in range(-5, 5, 1):
+                        dist = self.bn.get_dist(
+                            node.name,
+                            {node.cont_parents[0]: float(j), node.disc_parents[0]: i},
+                        )
+                        DistributionAssertion().assertDist(
+                            dist, self.info["types"][node.name]
+                        )
+
     # ???

     def test_choose_serialization(self):
         pass

From a36156265420bdc6e5da9b7295e890a799c67118 Mon Sep 17 00:00:00 2001
From: Roman223
Date: Fri, 11 Aug 2023 17:34:49 +0300
Subject: [PATCH 3/5] when pvals are not set, return marginal dist

---
 bamt/networks/base.py |  5 ++---
 tests/test_nodes.py   | 23 +++++++++++++----------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/bamt/networks/base.py b/bamt/networks/base.py
index 3f390e4..d44d974 100644
--- a/bamt/networks/base.py
+++ b/bamt/networks/base.py
@@ -863,7 +863,7 @@ def plot(self, output: str):

         return network.show(f"visualization_result/" + output)

-    def get_dist(self, node_name: str, pvals: dict):
+    def get_dist(self, node_name: str, pvals: Optional[dict] = None):
         """
         Get a distribution from a node with known parent values (conditional distribution).
@@ -877,8 +877,7 @@ def get_dist(self, node_name: str, pvals: dict): parents = node.cont_parents + node.disc_parents if not parents: - logger_network.error("No parents") - return + return self.distributions[node_name] pvals = [pvals[parent] for parent in parents] diff --git a/tests/test_nodes.py b/tests/test_nodes.py index c48e4a0..6a70264 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -9,7 +9,7 @@ logging.getLogger("nodes").setLevel(logging.CRITICAL) -class DistributionAssertion(unittest.TestCase): +class MyTest(unittest.TestCase): unittest.skip("This is an assertion.") def assertDist(self, dist, node_type): @@ -22,7 +22,7 @@ def assertDistMixture(self, dist): return self.assertEqual(len(dist), 3, msg=f"Error on {dist}") -class TestBaseNode(unittest.TestCase): +class TestBaseNode(MyTest): def setUp(self): np.random.seed(510) @@ -137,19 +137,26 @@ def test_get_dist_mixture(self): for i in range(-2, 0, 2): dist = hybrid_bn.get_dist(mixture_gauss.name, pvals={"Node0": i}) - DistributionAssertion().assertDistMixture(dist) + self.assertDistMixture(dist) for i in range(-2, 0, 2): for j in self.data[cond_mixture_gauss.disc_parents[0]].unique().tolist(): dist = hybrid_bn.get_dist( cond_mixture_gauss.name, pvals={"Node0": float(i), "Node4": j} ) - DistributionAssertion().assertDistMixture(dist) + self.assertDistMixture(dist) def test_get_dist(self): for node in self.bn.nodes: if not node.cont_parents + node.disc_parents: + dist = self.bn.get_dist(node.name) + if "mean" in dist.keys(): + self.assertTrue(isinstance(dist["mean"], float)) + self.assertTrue(isinstance(dist["variance"], float)) + else: + self.assertAlmostEqual(sum(dist["cprob"]), 1) continue + if len(node.cont_parents + node.disc_parents) == 1: if node.disc_parents: pvals = self.data[node.disc_parents[0]].unique().tolist() @@ -160,9 +167,7 @@ def test_get_dist(self): for pval in pvals: dist = self.bn.get_dist(node.name, {parent: pval}) - DistributionAssertion().assertDist( - dist, self.info["types"][node.name] - ) + self.assertDist(dist, self.info["types"][node.name]) else: for i in self.data[node.disc_parents[0]].unique().tolist(): for j in range(-5, 5, 1): @@ -170,9 +175,7 @@ def test_get_dist(self): node.name, {node.cont_parents[0]: float(j), node.disc_parents[0]: i}, ) - DistributionAssertion().assertDist( - dist, self.info["types"][node.name] - ) + self.assertDist(dist, self.info["types"][node.name]) # ??? 
def test_choose_serialization(self): From a25b4ef47e1aa05b047948ebbafaf81847b3e210 Mon Sep 17 00:00:00 2001 From: Roman223 Date: Mon, 14 Aug 2023 12:39:12 +0300 Subject: [PATCH 4/5] probas in fractions --- bamt/nodes/conditional_logit_node.py | 4 ++-- bamt/nodes/logit_node.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bamt/nodes/conditional_logit_node.py b/bamt/nodes/conditional_logit_node.py index 70f97f1..7888fed 100644 --- a/bamt/nodes/conditional_logit_node.py +++ b/bamt/nodes/conditional_logit_node.py @@ -137,9 +137,9 @@ def get_dist(node_info, pvals, **kwargs): return distribution, lgdistribution else: if not kwargs.get("inner", False): - return np.array([100.0]) + return np.array([1.0]) else: - return np.array([100.0]), lgdistribution + return np.array([1.0]), lgdistribution def choose( self, diff --git a/bamt/nodes/logit_node.py b/bamt/nodes/logit_node.py index 2eeb622..af5dc0f 100644 --- a/bamt/nodes/logit_node.py +++ b/bamt/nodes/logit_node.py @@ -68,7 +68,7 @@ def get_dist(node_info, pvals): model = pickle.loads(a) return model.predict_proba(np.array(pvals).reshape(1, -1))[0] else: - return np.array([100.0]) + return np.array([1.0]) def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str: """ From b2cbc744c842c95ac224f8a52c97099fff47ffeb Mon Sep 17 00:00:00 2001 From: Roman223 Date: Tue, 15 Aug 2023 18:03:28 +0300 Subject: [PATCH 5/5] Hotfix --- bamt/nodes/gaussian_node.py | 3 +-- bamt/nodes/logit_node.py | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/bamt/nodes/gaussian_node.py b/bamt/nodes/gaussian_node.py index 8ab4f29..38698a8 100644 --- a/bamt/nodes/gaussian_node.py +++ b/bamt/nodes/gaussian_node.py @@ -76,8 +76,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams: "serialization": None, } - @staticmethod - def get_dist(node_info, pvals): + def get_dist(self, node_info, pvals): var = node_info["variance"] if pvals: for el in pvals: diff --git a/bamt/nodes/logit_node.py b/bamt/nodes/logit_node.py index 3071466..2c4bbcf 100644 --- a/bamt/nodes/logit_node.py +++ b/bamt/nodes/logit_node.py @@ -56,8 +56,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams: "serialization": serialization_name, } - @staticmethod - def get_dist(node_info, pvals): + def get_dist(self, node_info, pvals): if len(node_info["classes"]) > 1: if node_info["serialization"] == "joblib": model = joblib.load(node_info["classifier_obj"]) @@ -65,10 +64,10 @@ def get_dist(node_info, pvals): # str_model = node_info["classifier_obj"].decode('latin1').replace('\'', '\"') a = node_info["classifier_obj"].encode("latin1") model = pickle.loads(a) - + if type(self).__name__ == "CompositeDiscreteNode": pvals = [int(item) if isinstance(item, str) else item for item in pvals] - + return model.predict_proba(np.array(pvals).reshape(1, -1))[0] else: return np.array([1.0])
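
Reviewer note (editorial addition, not part of any patch above): a minimal usage
sketch of the get_dist API this series introduces. It trims the fixture pattern
from tests/test_nodes.py down to two Gaussian nodes; the node names, seed, and
values are illustrative only, and the return shapes follow what the tests assert
(a two-element pair for a conditional Gaussian node, and — after PATCH 3 — the
raw parameter dict for a node without parents).

    import numpy as np
    import pandas as pd

    from bamt.networks.hybrid_bn import HybridBN
    from bamt.nodes import gaussian_node

    np.random.seed(510)

    # Two continuous nodes with a single edge Node0 -> Node1, mirroring a
    # trimmed version of the fixtures in tests/test_nodes.py.
    data = pd.DataFrame(
        {
            "Node0": np.random.normal(0, 4, 30),
            "Node1": np.random.normal(0, 0.1, 30),
        }
    )
    info = {
        "types": {"Node0": "cont", "Node1": "cont"},
        "signs": {"Node0": "pos", "Node1": "neg"},
    }
    nodes = [
        gaussian_node.GaussianNode(name="Node0"),
        gaussian_node.GaussianNode(name="Node1"),
    ]

    bn = HybridBN(has_logit=True)
    bn.set_structure(info, nodes=nodes, edges=[("Node0", "Node1")])
    bn.fit_parameters(data)

    # Conditional distribution of Node1 given its parent's value:
    # a two-element (mean, standard deviation) pair for a Gaussian node.
    dist = bn.get_dist("Node1", pvals={"Node0": 1.0})

    # After PATCH 3, omitting pvals on a parentless node returns its marginal
    # parameters straight from bn.distributions, e.g. {"mean": ..., "variance": ...}.
    marginal = bn.get_dist("Node0")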