encoding moved to network level
jrzkaminski committed Aug 4, 2023
1 parent 2cfd46d commit 52003d0
Showing 4 changed files with 87 additions and 61 deletions.
69 changes: 57 additions & 12 deletions bamt/networks/composite_bn.py
@@ -1,11 +1,13 @@
import os
import re
import random

import numpy as np
from sklearn.preprocessing import LabelEncoder

from tqdm import tqdm
from bamt.log import logger_network
from bamt.networks.base import BaseNetwork
from bamt.networks.base import BaseNetwork, STORAGE
import pandas as pd
from typing import Optional, Dict, Union, List
from bamt.builders.composite_builder import CompositeStructureBuilder, CompositeDefiner
@@ -22,6 +24,7 @@ def __init__(self):
self._allowed_dtypes = ["cont", "disc", "disc_num"]
self.type = "Composite"
self.parent_models = {}
self.encoders = {}

def add_nodes(self, descriptor: Dict[str, Dict[str, str]]):
"""
@@ -158,17 +161,14 @@ def wrapper():
else:
if self.type == "Discrete":
pvals = [str(output[t]) for t in parents]
elif type(node).__name__ in ("CompositeDiscreteNode", "CompositeContinuousNode"):
pvals = output
else:
pvals = [output[t] for t in parents]

# If any nan from parents, sampling from node blocked.
if any(pd.isnull(pvalue) for pvalue in pvals):
output[node.name] = np.nan
continue
print("NODE \n:", node.name)
print("PVALS \n:", pvals)

node_data = self.distributions[node.name]
if models_dir and ("hybcprob" in node_data.keys()):
for obj, obj_data in node_data["hybcprob"].items():
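(hunk truncated here)

Worth noting in the sampling hunk above: the `pd.isnull` guard makes missingness propagate through the topological sampling order, since a node whose parent sampled to NaN is itself recorded as NaN. A tiny self-contained illustration with hypothetical nodes A -> B -> C (names are illustrative, not from this diff):

    import numpy as np
    import pandas as pd

    # Once a parent samples to NaN, every descendant is blanked as well,
    # mirroring the pd.isnull guard in the sampling loop above.
    output = {"A": np.nan}
    for name, parents in [("B", ["A"]), ("C", ["B"])]:
        pvals = [output[p] for p in parents]
        if any(pd.isnull(v) for v in pvals):
            output[name] = np.nan
            continue
    print(output)  # {'A': nan, 'B': nan, 'C': nan}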
@@ -219,14 +219,59 @@ def wrapper():
sample_output = pd.DataFrame.from_dict(seq, orient="columns")

if as_df:
if self.has_logit or self.type == "Composite":
for node in self.nodes:
for feature_key, encoder in node.encoders:
sample_output[feature_key] = encoder[
feature_key
].inverse_transform(sample_output[feature_key])
pass
sample_output = self.decode_categorical_data(sample_output)

return sample_output
else:
return seq

def fit_parameters(self, data: pd.DataFrame, dropna: bool = True, n_jobs: int = -1):
"""
Base function for parameter learning
"""
if dropna:
data = data.dropna()
data.reset_index(inplace=True, drop=True)

if not os.path.isdir(STORAGE):
os.makedirs(STORAGE)

# init folder
if not os.listdir(STORAGE):
os.makedirs(os.path.join(STORAGE, "0"))

index = sorted([int(id) for id in os.listdir(STORAGE)])[-1] + 1
os.makedirs(os.path.join(STORAGE, str(index)))

data = self.encode_categorical_data(data)

# Turn all discrete values to str for learning algorithm
if "disc_num" in self.descriptor["types"].values():
columns_names = [
name
for name, t in self.descriptor["types"].items()
if t in ["disc_num"]
]
data[columns_names] = data.loc[:, columns_names].astype("str")

def worker(node):
return node.fit_parameters(data)

# results = Parallel(n_jobs=n_jobs)(delayed(worker)(node) for node in self.nodes)

results = [worker(node) for node in self.nodes]

for result, node in zip(results, self.nodes):
self.distributions[node.name] = result

def encode_categorical_data(self, data):
for column in data.select_dtypes(include=['object', 'string']).columns:
encoder = LabelEncoder()
data[column] = encoder.fit_transform(data[column])
self.encoders[column] = encoder
return data

def decode_categorical_data(self, data):
for column, encoder in self.encoders.items():
data[column] = encoder.inverse_transform(data[column])
return data
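
The new `encode_categorical_data` / `decode_categorical_data` pair implements a per-column LabelEncoder round trip at the network level. A minimal sketch of the same pattern outside the class (the DataFrame contents are illustrative):

    import pandas as pd
    from sklearn.preprocessing import LabelEncoder

    df = pd.DataFrame({"color": ["red", "blue", "red"], "size": [1.0, 2.0, 3.0]})
    encoders = {}

    # Encode: each object/string column is mapped to integer codes and the
    # fitted encoder is stored so the mapping can be reversed after sampling.
    for column in df.select_dtypes(include=["object", "string"]).columns:
        enc = LabelEncoder()
        df[column] = enc.fit_transform(df[column])
        encoders[column] = enc

    # Decode: inverse_transform restores the original labels.
    for column, enc in encoders.items():
        df[column] = enc.inverse_transform(df[column])

    print(df["color"].tolist())  # ['red', 'blue', 'red']

One consequence of this design is that `sample()` would otherwise emit integer codes; the `as_df` branch above now reverses them by calling `decode_categorical_data` on the sampled frame.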
34 changes: 15 additions & 19 deletions bamt/nodes/base.py
@@ -1,4 +1,5 @@
from bamt.config import config
import numpy as np

from typing import Union
from sklearn.preprocessing import LabelEncoder
@@ -91,22 +92,17 @@ def get_path_joblib(node_name: str, specific: str = "") -> str:
)
return path

@staticmethod
def encode_categorical_data_if_any(func):
@wraps(func)
def wrapper(self, data, *args, **kwargs):
for column in self.disc_parents + [self.name]:
if data[column].dtype in ("object", "str"):
encoder = LabelEncoder()
data[column] = encoder.fit_transform(data[column])
self.encoders[column] = encoder
elif data[column].dtype in ("float64", "int64"):
continue
else:
logger_nodes.warning(
msg="Wrong datatype passed to categorical data encoder"
)
result = func(self, data, *args, **kwargs)
return result

return wrapper
def encode_categorical_data_if_any(self, data):
for column in self.disc_parents + [self.name]:
if data[column].dtype in ("object", "str", "string"):
encoder = LabelEncoder()
data[column] = encoder.fit_transform(data[column])
self.encoders[column] = encoder
elif np.issubdtype(data[column].dtype, np.number):
continue
else:
logger_nodes.warning(
msg="Wrong datatype passed to categorical data encoder"
)
return data
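
Besides turning the decorator into a plain method, the rewrite broadens the numeric check from a fixed tuple of dtype names to `np.issubdtype`, so narrower numeric columns no longer trigger the warning. A quick illustration:

    import numpy as np
    import pandas as pd

    s = pd.Series([1, 2, 3], dtype="int32")

    print(s.dtype in ("float64", "int64"))    # False: the old tuple test misses int32
    print(np.issubdtype(s.dtype, np.number))  # True: any numeric dtype passes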

28 changes: 11 additions & 17 deletions bamt/nodes/composite_continuous_node.py
@@ -27,7 +27,6 @@ def __init__(self, name, regressor: Optional[object] = None):
self.regressor = regressor
self.type = "CompositeContinuous" + f" ({type(self.regressor).__name__})"

@BaseNode.encode_categorical_data_if_any
def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams:
parents = self.cont_parents + self.disc_parents
if parents:
@@ -66,29 +65,24 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams:
"serialization": "joblib",
}
else:
logger_nodes.warning(
msg="Composite Continuous Node should always have a parent"
)
mean_base = np.mean(data[self.name].values)
variance = np.var(data[self.name].values)
return {
"mean": mean_base,
"regressor_obj": None,
"regressor": None,
"variance": variance,
"serialization": None,
}

def choose(self, node_info: GaussianParams, pvals: Dict) -> float:
@staticmethod
def choose(node_info: GaussianParams, pvals: List[float]) -> float:
"""
Return value from Logit node
params:
node_info: nodes info from distributions
pvals: parent values
"""

print("ENCODERS OF", self.name, "\n", self.encoders)

for parent_key in pvals:
if not isinstance(pvals[parent_key], (float, int)):
parent_value_array = np.array(pvals[parent_key])
pvals[parent_key] = self.encoders[parent_key].transform(parent_value_array.reshape(1, -1))

pvals = list(pvals.values())

print("TRANSFORMED PVALS \n", pvals)

if pvals:
for el in pvals:
if str(el) == "nan":
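(hunk truncated here)

With the encoder lookups and debug prints removed, `choose` becomes a staticmethod that expects already-encoded parent values. The tail of the method is truncated in this view; below is a hedged sketch of the conditional-Gaussian draw such a node typically performs, assuming `node_info` carries an already-deserialized regressor and a variance (neither is shown in this diff):

    import math
    import random
    import numpy as np

    def choose_sketch(node_info, pvals):
        # If any parent value is missing, the sample is missing too.
        if any(str(v) == "nan" for v in pvals):
            return np.nan
        # Conditional mean from the parent values, then a Gaussian draw
        # around it; the sketch treats the stored variance as sigma**2.
        model = node_info["regressor_obj"]  # assumed already deserialized
        cond_mean = model.predict(np.array(pvals).reshape(1, -1))[0]
        return random.gauss(cond_mean, math.sqrt(node_info["variance"]))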
17 changes: 4 additions & 13 deletions bamt/nodes/composite_discrete_node.py
@@ -27,7 +27,6 @@ def __init__(self, name, classifier: Optional[object] = None):
self.classifier = classifier
self.type = "CompositeDiscrete" + f" ({type(self.classifier).__name__})"

@BaseNode.encode_categorical_data_if_any
def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams:
model_ser = None
path = None
@@ -58,7 +57,8 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams:
"serialization": serialization_name,
}

def choose(self, node_info: LogitParams, pvals: Dict) -> str:
@staticmethod
def choose(node_info: LogitParams, pvals: List[Union[float]]) -> str:
"""
Return value from Logit node
params:
@@ -68,25 +68,16 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams:

rindex = 0

print("ENCODERS OF", self.name, "\n", self.encoders)

for parent_key in pvals:
if not isinstance(pvals[parent_key], (float, int)):
parent_value_array = np.array(pvals[parent_key])
pvals[parent_key] = self.encoders[parent_key].transform(parent_value_array.reshape(1, -1))

pvals = list(pvals.values())

print("TRANSFORMED PVALS \n", pvals)

if len(node_info["classes"]) > 1:
if node_info["serialization"] == "joblib":
model = joblib.load(node_info["classifier_obj"])
else:
# str_model = node_info["classifier_obj"].decode('latin1').replace('\'', '\"')
a = node_info["classifier_obj"].encode("latin1")
model = pickle.loads(a)
distribution = model.predict_proba(np.array(pvals).reshape(1, -1))[0]

# choose
rand = random.random()
lbound = 0
ubound = 0
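(hunk truncated here)

The truncated tail of `choose` above is inverse-CDF sampling over the classifier's class probabilities: the uniform draw `rand` is compared against a running cumulative sum built from `lbound` and `ubound`. A self-contained sketch of that pattern (the function name is illustrative):

    import random

    def sample_class(distribution, classes):
        # Walk the cumulative distribution until the uniform draw lands
        # inside a class's probability interval.
        rand = random.random()
        lbound, ubound = 0.0, 0.0
        for i in range(len(classes)):
            ubound += distribution[i]
            if lbound <= rand < ubound:
                return str(classes[i])
            lbound = ubound
        return str(classes[-1])  # guard against floating-point round-off

    print(sample_class([0.2, 0.5, 0.3], ["a", "b", "c"]))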
