Skip to content

Commit

Permalink
middle state commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Aug 3, 2023
1 parent 7eea3da commit 2cfd46d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
2 changes: 1 addition & 1 deletion bamt/networks/composite_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def wrapper():
if self.type == "Discrete":
pvals = [str(output[t]) for t in parents]
elif type(node).__name__ in ("CompositeDiscreteNode", "CompositeContinuousNode"):
pvals = [str(output[t]) for t in parents]
pvals = output
else:
pvals = [output[t] for t in parents]

Expand Down
18 changes: 15 additions & 3 deletions bamt/nodes/composite_continuous_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import random
import math

from numpy import array
from .base import BaseNode
from .gaussian_node import GaussianNode
from .schema import GaussianParams, HybcprobParams

from sklearn import linear_model
from pandas import DataFrame
from typing import Optional, List, Union
from typing import Optional, List, Union, Dict

from ..log import logger_nodes
from sklearn.metrics import mean_squared_error as mse
Expand Down Expand Up @@ -69,14 +70,25 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams:
msg="Composite Continuous Node should always have a parent"
)

@staticmethod
def choose(node_info: GaussianParams, pvals: List[float]) -> float:
def choose(self, node_info: GaussianParams, pvals: Dict) -> float:
"""
Return value from Logit node
params:
node_info: nodes info from distributions
pvals: parent values
"""

print("ENCODERS OF", self.name, "\n", self.encoders)

for parent_key in pvals:
if not isinstance(pvals[parent_key], (float, int)):
parent_value_array = np.array(pvals[parent_key])
pvals[parent_key] = self.encoders[parent_key].transform(parent_value_array.reshape(1, -1))

pvals = list(pvals.values())

print("TRANSFORMED PVALS \n", pvals)

if pvals:
for el in pvals:
if str(el) == "nan":
Expand Down
16 changes: 13 additions & 3 deletions bamt/nodes/composite_discrete_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from sklearn import linear_model
from pandas import DataFrame
from typing import Optional, List, Union
from typing import Optional, List, Union, Dict


class CompositeDiscreteNode(BaseNode):
Expand Down Expand Up @@ -58,8 +58,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams:
"serialization": serialization_name,
}

@staticmethod
def choose(node_info: LogitParams, pvals: List[Union[float]]) -> str:
def choose(self, node_info: LogitParams, pvals: Dict) -> str:
"""
Return value from Logit node
params:
Expand All @@ -69,6 +68,17 @@ def choose(node_info: LogitParams, pvals: List[Union[float]]) -> str:

rindex = 0

print("ENCODERS OF", self.name, "\n", self.encoders)

for parent_key in pvals:
if not isinstance(pvals[parent_key], (float, int)):
parent_value_array = np.array(pvals[parent_key])
pvals[parent_key] = self.encoders[parent_key].transform(parent_value_array.reshape(1, -1))

pvals = list(pvals.values())

print("TRANSFORMED PVALS \n", pvals)

if len(node_info["classes"]) > 1:
if node_info["serialization"] == "joblib":
model = joblib.load(node_info["classifier_obj"])
Expand Down

0 comments on commit 2cfd46d

Please sign in to comment.