From 2cfd46dd9ad9ba40eabf73d16399fc5faba6ef14 Mon Sep 17 00:00:00 2001 From: jrzkaminski <86363785+jrzkaminski@users.noreply.github.com> Date: Thu, 3 Aug 2023 16:58:30 +0300 Subject: [PATCH] middle state commit --- bamt/networks/composite_bn.py | 2 +- bamt/nodes/composite_continuous_node.py | 18 +++++++++++++++--- bamt/nodes/composite_discrete_node.py | 16 +++++++++++++--- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/bamt/networks/composite_bn.py b/bamt/networks/composite_bn.py index 4a31b45..6c52201 100644 --- a/bamt/networks/composite_bn.py +++ b/bamt/networks/composite_bn.py @@ -159,7 +159,7 @@ def wrapper(): if self.type == "Discrete": pvals = [str(output[t]) for t in parents] elif type(node).__name__ in ("CompositeDiscreteNode", "CompositeContinuousNode"): - pvals = [str(output[t]) for t in parents] + pvals = output else: pvals = [output[t] for t in parents] diff --git a/bamt/nodes/composite_continuous_node.py b/bamt/nodes/composite_continuous_node.py index 95b62b0..944be2e 100644 --- a/bamt/nodes/composite_continuous_node.py +++ b/bamt/nodes/composite_continuous_node.py @@ -4,13 +4,14 @@ import random import math +from numpy import array from .base import BaseNode from .gaussian_node import GaussianNode from .schema import GaussianParams, HybcprobParams from sklearn import linear_model from pandas import DataFrame -from typing import Optional, List, Union +from typing import Optional, List, Union, Dict from ..log import logger_nodes from sklearn.metrics import mean_squared_error as mse @@ -69,14 +70,25 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams: msg="Composite Continuous Node should always have a parent" ) - @staticmethod - def choose(node_info: GaussianParams, pvals: List[float]) -> float: + def choose(self, node_info: GaussianParams, pvals: Dict) -> float: """ Return value from Logit node params: node_info: nodes info from distributions pvals: parent values """ + + print("ENCODERS OF", self.name, "\n", self.encoders) + + for parent_key in pvals: + if not isinstance(pvals[parent_key], (float, int)): + parent_value_array = np.array(pvals[parent_key]) + pvals[parent_key] = self.encoders[parent_key].transform(parent_value_array.reshape(1, -1)) + + pvals = list(pvals.values()) + + print("TRANSFORMED PVALS \n", pvals) + if pvals: for el in pvals: if str(el) == "nan": diff --git a/bamt/nodes/composite_discrete_node.py b/bamt/nodes/composite_discrete_node.py index d28152b..3a30ccd 100644 --- a/bamt/nodes/composite_discrete_node.py +++ b/bamt/nodes/composite_discrete_node.py @@ -10,7 +10,7 @@ from sklearn import linear_model from pandas import DataFrame -from typing import Optional, List, Union +from typing import Optional, List, Union, Dict class CompositeDiscreteNode(BaseNode): @@ -58,8 +58,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams: "serialization": serialization_name, } - @staticmethod - def choose(node_info: LogitParams, pvals: List[Union[float]]) -> str: + def choose(self, node_info: LogitParams, pvals: Dict) -> str: """ Return value from Logit node params: @@ -69,6 +68,17 @@ def choose(node_info: LogitParams, pvals: List[Union[float]]) -> str: rindex = 0 + print("ENCODERS OF", self.name, "\n", self.encoders) + + for parent_key in pvals: + if not isinstance(pvals[parent_key], (float, int)): + parent_value_array = np.array(pvals[parent_key]) + pvals[parent_key] = self.encoders[parent_key].transform(parent_value_array.reshape(1, -1)) + + pvals = list(pvals.values()) + + print("TRANSFORMED PVALS \n", pvals) + if len(node_info["classes"]) > 1: if node_info["serialization"] == "joblib": model = joblib.load(node_info["classifier_obj"])