Skip to content

Commit

Permalink
Feature: get distribution from node (#76)
Browse files Browse the repository at this point in the history
* feature: get_dist to get conditional distribution from node

* tests and minor improvements

* when pvals are not set, return the marginal distribution

* probas in fractions

* Hotfix

---------

Co-authored-by: Yury Kaminsky <[email protected]>
  • Loading branch information
Roman223 and jrzkaminski authored Aug 15, 2023
1 parent 2e1b1d8 commit fc3362c
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 90 deletions.
21 changes: 20 additions & 1 deletion bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,26 @@ def find_family(
plot_(
plot_to, [self[name] for name in structure["nodes"]], structure["edges"]
)
return structure

def get_dist(self, node_name: str, pvals: Optional[dict] = None):
    """
    Get a distribution from node with known parent values (conditional distribution).

    If the node has no parents, its marginal distribution is returned and
    ``pvals`` is ignored.

    :param node_name: name of node
    :param pvals: mapping of parent name -> value; required when the node has parents
    """
    if not self.distributions:
        logger_network.error("Empty parameters. Call fit_params first.")
        return
    node = self[node_name]

    parents = node.cont_parents + node.disc_parents
    if not parents:
        return self.distributions[node_name]

    # Fail with a clear log message instead of a raw TypeError/KeyError when
    # the caller did not supply a value for every parent of a conditional node.
    if pvals is None:
        logger_network.error(
            f"Node {node_name} has parents; pvals must be provided."
        )
        return
    missing = [parent for parent in parents if parent not in pvals]
    if missing:
        logger_network.error(f"Missing parent values: {missing}")
        return

    pvals = [pvals[parent] for parent in parents]

    return node.get_dist(node_info=self.distributions[node_name], pvals=pvals)

def _encode_categorical_data(self, data):
for column in data.select_dtypes(include=["object", "string"]).columns:
Expand Down
4 changes: 4 additions & 0 deletions bamt/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,7 @@ def get_path_joblib(node_name: str, specific: str = "") -> str:
os.path.join(path_to_check, f"{specific}.joblib.compressed")
)
return path

@staticmethod
def get_dist(node_info, pvals):
    """Interface stub: concrete node classes override this to return the
    node's (conditional) distribution for the given parent values."""
41 changes: 23 additions & 18 deletions bamt/nodes/conditional_gaussian_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, CondGaussParams
}
return {"hybcprob": hycprob}

def choose(
self,
node_info: Dict[str, Dict[str, CondGaussParams]],
pvals: List[Union[str, float]],
) -> float:
"""
Return value from ConditionalLogit node
params:
node_info: nodes info from distributions
pvals: parent values
"""

def get_dist(self, node_info, pvals):
dispvals = []
lgpvals = []
for pval in pvals:
Expand All @@ -138,7 +127,7 @@ def choose(
flag = True
break
if flag:
return np.nan
return np.nan, np.nan
else:
if lgdistribution["regressor"]:
if lgdistribution["serialization"] == "joblib":
Expand All @@ -150,14 +139,30 @@ def choose(

cond_mean = model.predict(np.array(lgpvals).reshape(1, -1))[0]
variance = lgdistribution["variance"]
return random.gauss(cond_mean, variance)
return cond_mean, variance
else:
return np.nan
return np.nan, np.nan

else:
return random.gauss(
lgdistribution["mean"], math.sqrt(lgdistribution["variance"])
)
return lgdistribution["mean"], math.sqrt(lgdistribution["variance"])

def choose(
    self,
    node_info: Dict[str, Dict[str, "CondGaussParams"]],
    pvals: List[Union[str, float]],
) -> float:
    """
    Sample a value from a conditional Gaussian node.

    params:
        node_info: nodes info from distributions
        pvals: parent values

    Returns np.nan when get_dist signals an undefined conditional
    distribution (NaN mean/variance).
    """
    cond_mean, variance = self.get_dist(node_info, pvals)
    # Test for NaN explicitly, not for falsiness: a legitimate zero mean or
    # zero variance must not be mistaken for a missing distribution.
    if np.isnan(cond_mean) or np.isnan(variance):
        return np.nan

    return random.gauss(cond_mean, variance)

def predict(
self,
Expand Down
38 changes: 27 additions & 11 deletions bamt/nodes/conditional_logit_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,7 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, LogitParams]]:
return {"hybcprob": hycprob}

@staticmethod
def choose(
node_info: Dict[str, Dict[str, LogitParams]], pvals: List[Union[str, float]]
) -> str:
"""
Return value from ConditionalLogit node
params:
node_info: nodes info from distributions
pvals: parent values
"""

def get_dist(node_info, pvals, **kwargs):
dispvals = []
lgpvals = []
for pval in pvals:
Expand All @@ -140,6 +131,32 @@ def choose(

distribution = model.predict_proba(np.array(lgpvals).reshape(1, -1))[0]

if not kwargs.get("inner", False):
return distribution
else:
return distribution, lgdistribution
else:
if not kwargs.get("inner", False):
return np.array([1.0])
else:
return np.array([1.0]), lgdistribution

def choose(
self,
node_info: Dict[str, Dict[str, LogitParams]],
pvals: List[Union[str, float]],
) -> str:
"""
Return value from ConditionalLogit node
params:
node_info: nodes info from distributions
pvals: parent values
"""

distribution, lgdistribution = self.get_dist(node_info, pvals, inner=True)

# JOBLIB
if len(lgdistribution["classes"]) > 1:
rand = random.random()
rindex = 0
lbound = 0
Expand All @@ -152,7 +169,6 @@ def choose(
else:
lbound = ubound
return str(lgdistribution["classes"][rindex])

else:
return str(lgdistribution["classes"][0])

Expand Down
47 changes: 32 additions & 15 deletions bamt/nodes/conditional_mixture_gaussian_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,27 +112,21 @@ def fit_parameters(
return {"hybcprob": hycprob}

@staticmethod
def choose(
node_info: Dict[str, Dict[str, CondMixtureGaussParams]],
pvals: List[Union[str, float]],
) -> Optional[float]:
"""
Function to get value from ConditionalMixtureGaussian node
params:
node_info: nodes info from distributions
pvals: parent values
"""
dispvals = []
def get_dist(node_info, pvals):
lgpvals = []
dispvals = []

for pval in pvals:
if (isinstance(pval, str)) | (isinstance(pval, int)):
dispvals.append(pval)
else:
lgpvals.append(pval)

lgdistribution = node_info["hybcprob"][str(dispvals)]
mean = lgdistribution["mean"]
covariance = lgdistribution["covars"]
w = lgdistribution["coef"]

if len(w) != 0:
if len(lgpvals) != 0:
indexes = [i for i in range(1, (len(lgpvals) + 1), 1)]
Expand All @@ -145,17 +139,40 @@ def choose(
covariances=covariance,
)
cond_gmm = gmm.condition(indexes, [lgpvals])
sample = cond_gmm.sample(1)[0][0]
return cond_gmm.means, cond_gmm.covariances, cond_gmm.priors
else:
sample = np.nan
return np.nan, np.nan, np.nan
else:
n_comp = len(w)
gmm = GMM(
n_components=n_comp, priors=w, means=mean, covariances=covariance
)
sample = gmm.sample(1)[0][0]
return gmm.means, gmm.covariances, gmm.priors
else:
sample = np.nan
return np.nan, np.nan, np.nan

def choose(
    self,
    node_info: Dict[str, Dict[str, "CondMixtureGaussParams"]],
    pvals: List[Union[str, float]],
) -> Optional[float]:
    """
    Sample a value from a ConditionalMixtureGaussian node.

    params:
        node_info: nodes info from distributions
        pvals: parent values

    Returns np.nan when get_dist signals an undefined conditional
    distribution (NaN means/covariances/priors).
    """
    mean, covariance, w = self.get_dist(node_info, pvals)

    # get_dist returns (nan, nan, nan) for an undefined distribution
    # (e.g. empty mixture); propagate np.nan instead of crashing on len(w).
    if isinstance(w, float) and np.isnan(w):
        return np.nan

    n_comp = len(w)

    gmm = GMM(
        n_components=n_comp,
        priors=w,
        means=mean,
        covariances=covariance,
    )
    sample = gmm.sample(1)[0][0]
    return sample

@staticmethod
Expand Down
22 changes: 15 additions & 7 deletions bamt/nodes/discrete_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,25 @@ def worker(node: Type[BaseNode]) -> DiscreteParams:
tight_form = conditional_dist.to_dict("tight")

for comb, probs in zip(tight_form["index"], tight_form["data"]):
cprob[str([str(i) for i in comb])] = probs
if len(parents) > 1:
cprob[str([str(i) for i in comb])] = probs
else:
cprob[f"['{comb}']"] = probs
return {"cprob": cprob, "vals": vals}

pool = ThreadPoolExecutor(num_workers)
future = pool.submit(worker, self)
return future.result()

@staticmethod
def choose(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
def get_dist(node_info, pvals):
    """Return this node's distribution: the full (marginal) cprob table when
    no parent values are given, otherwise the conditional row keyed by pvals."""
    cprob = node_info["cprob"]
    if pvals:
        # noinspection PyTypeChecker
        return cprob[str(pvals)]
    return cprob

def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
"""
Return value from discrete node
params:
Expand All @@ -66,11 +76,9 @@ def choose(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
rindex = 0
random.seed()
vals = node_info["vals"]
if not pvals:
dist = node_info["cprob"]
else:
# noinspection PyTypeChecker
dist = node_info["cprob"][str(pvals)]

dist = self.get_dist(node_info, pvals)

lbound = 0
ubound = 0
rand = random.random()
Expand Down
25 changes: 15 additions & 10 deletions bamt/nodes/gaussian_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,8 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams:
"serialization": None,
}

def choose(self, node_info: GaussianParams, pvals: List[float]) -> float:
"""
Return value from Logit node
params:
node_info: nodes info from distributions
pvals: parent values
"""
def get_dist(self, node_info, pvals):
var = node_info["variance"]
if pvals:
for el in pvals:
if str(el) == "nan":
Expand All @@ -97,10 +92,20 @@ def choose(self, node_info: GaussianParams, pvals: List[float]) -> float:
pvals = [int(item) if isinstance(item, str) else item for item in pvals]

cond_mean = model.predict(np.array(pvals).reshape(1, -1))[0]
var = node_info["variance"]
return random.gauss(cond_mean, var)
return cond_mean, var
else:
return random.gauss(node_info["mean"], math.sqrt(node_info["variance"]))
return node_info["mean"], math.sqrt(var)

def choose(self, node_info: GaussianParams, pvals: List[float]) -> float:
    """
    Draw a random sample from this Gaussian node.

    params:
        node_info: nodes info from distributions
        pvals: parent values
    """
    # NOTE(review): get_dist appears to return the raw variance on the
    # conditional branch but sqrt(variance) on the marginal one — confirm
    # which scale random.gauss should receive.
    mu, sigma = self.get_dist(node_info, pvals)
    return random.gauss(mu, sigma)

@staticmethod
def predict(node_info: GaussianParams, pvals: List[float]) -> float:
Expand Down
31 changes: 18 additions & 13 deletions bamt/nodes/logit_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams:
"serialization": serialization_name,
}

def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
"""
Return value from Logit node
params:
node_info: nodes info from distributions
pvals: parent values
"""

rindex = 0

def get_dist(self, node_info, pvals):
if len(node_info["classes"]) > 1:
if node_info["serialization"] == "joblib":
model = joblib.load(node_info["classifier_obj"])
Expand All @@ -76,9 +67,24 @@ def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:

if type(self).__name__ == "CompositeDiscreteNode":
pvals = [int(item) if isinstance(item, str) else item for item in pvals]
distribution = model.predict_proba(np.array(pvals).reshape(1, -1))[0]

# choose
return model.predict_proba(np.array(pvals).reshape(1, -1))[0]
else:
return np.array([1.0])

def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
"""
Return value from Logit node
params:
node_info: nodes info from distributions
pvals: parent values
"""

rindex = 0

distribution = self.get_dist(node_info, pvals)

if len(node_info["classes"]) > 1:
rand = random.random()
lbound = 0
ubound = 0
Expand All @@ -91,7 +97,6 @@ def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str:
lbound = ubound

return str(node_info["classes"][rindex])

else:
return str(node_info["classes"][0])

Expand Down
Loading

0 comments on commit fc3362c

Please sign in to comment.