Skip to content

Commit

Permalink
Hotfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman223 committed Aug 15, 2023
1 parent aba2926 commit b2cbc74
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
3 changes: 1 addition & 2 deletions bamt/nodes/gaussian_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams:
"serialization": None,
}

@staticmethod
def get_dist(node_info, pvals):
def get_dist(self, node_info, pvals):
var = node_info["variance"]
if pvals:
for el in pvals:
Expand Down
7 changes: 3 additions & 4 deletions bamt/nodes/logit_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,18 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams:
"serialization": serialization_name,
}

@staticmethod
def get_dist(node_info, pvals):
def get_dist(self, node_info, pvals):
if len(node_info["classes"]) > 1:
if node_info["serialization"] == "joblib":
model = joblib.load(node_info["classifier_obj"])
else:
# str_model = node_info["classifier_obj"].decode('latin1').replace('\'', '\"')
a = node_info["classifier_obj"].encode("latin1")
model = pickle.loads(a)

if type(self).__name__ == "CompositeDiscreteNode":
pvals = [int(item) if isinstance(item, str) else item for item in pvals]

return model.predict_proba(np.array(pvals).reshape(1, -1))[0]
else:
return np.array([1.0])
Expand Down

0 comments on commit b2cbc74

Please sign in to comment.