Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decomoposing and code improvement #50

Merged
merged 10 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions bamt/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from pgmpy.estimators import HillClimbSearch
from bamt.redef_HC import hc as hc_method

from bamt import nodes
from bamt.nodes import *
Roman223 marked this conversation as resolved.
Show resolved Hide resolved
from bamt.log import logger_builder
from pandas import DataFrame
from bamt.utils import GraphUtils as gru

from typing import Dict, List, Optional, Tuple, Callable, TypedDict, Union
from typing import Dict, List, Optional, Tuple, Callable, TypedDict, Union, Sequence


class ParamDict(TypedDict, total=False):
init_edges: Optional[Tuple[str, str]]
init_edges: Optional[Sequence[str]]
init_nodes: Optional[List[str]]
remove_init_edges: bool
white_list: Optional[Tuple[str, str]]
Expand All @@ -30,11 +30,10 @@ def __init__(self, descriptor: Dict[str, Dict[str, str]]):
"""
:param descriptor: a dict with types and signs of nodes
Attributes:
Skeleton: dict;
black_list: a list with restricted connections;
"""
self.skeleton = {'V': None,
'E': None}
self.skeleton = {'V': [],
'E': []}
self.descriptor = descriptor

self.has_logit = bool
Expand All @@ -54,7 +53,6 @@ def restrict(self, data: DataFrame,
datacol = data.columns.to_list()

if not self.has_logit:

# Has_logit flag allows BN building edges between cont and disc
RESTRICTIONS = [('cont', 'disc'), ('cont', 'disc_num')]
for x, y in itertools.product(datacol, repeat=2):
Expand Down Expand Up @@ -124,20 +122,20 @@ def __init__(self, descriptor: Dict[str, Dict[str, str]],
# Notice that vertices are used only by Builders
self.vertices = []

Node = None
node = None
# LEVEL 1: Define a general type of node: Discrete or Gaussian
for vertex, type in self.descriptor['types'].items():
if type in ['disc_num', 'disc']:
Node = nodes.DiscreteNode(name=vertex)
node = discrete_node.DiscreteNode(name=vertex)
elif type == 'cont':
Node = nodes.GaussianNode(name=vertex, regressor=regressor)
node = gaussian_node.GaussianNode(name=vertex, regressor=regressor)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Возможно, вот эту часть стоит переработать в виде Factory.

else:
msg = f"""First stage of automatic vertex detection failed on {vertex} due TypeError ({type}).
Set vertex manually (by calling set_nodes()) or investigate the error."""
logger_builder.error(msg)
continue

self.vertices.append(Node)
self.vertices.append(node)

def overwrite_vertex(
self,
Expand All @@ -158,27 +156,27 @@ def overwrite_vertex(
if 'Discrete' in node_instance.type:
if node_instance.cont_parents:
if not node_instance.disc_parents:
Node = nodes.LogitNode(
Node = logit_node.LogitNode(
name=node_instance.name, classifier=classifier)

elif node_instance.disc_parents:
Node = nodes.ConditionalLogitNode(
Node = conditional_logit_node.ConditionalLogitNode(
name=node_instance.name, classifier=classifier)

if use_mixture:
if 'Gaussian' in node_instance.type:
if not node_instance.disc_parents:
Node = nodes.MixtureGaussianNode(
Node = mixture_gaussian_node.MixtureGaussianNode(
name=node_instance.name)
elif node_instance.disc_parents:
Node = nodes.ConditionalMixtureGaussianNode(
Node = conditional_mixture_gaussian_node.ConditionalMixtureGaussianNode(
name=node_instance.name)
else:
continue
else:
if 'Gaussian' in node_instance.type:
if node_instance.disc_parents:
Node = nodes.ConditionalGaussianNode(
Node = conditional_gaussian_node.ConditionalGaussianNode(
name=node_instance.name, regressor=regressor)
else:
continue
Expand Down Expand Up @@ -222,20 +220,21 @@ def __init__(self, data: DataFrame, descriptor: Dict[str, Dict[str, str]],
def apply_K2(self,
data: DataFrame,
init_edges: Optional[List[Tuple[str,
str]]],
str]]],
progress_bar: bool,
remove_init_edges: bool,
white_list: Optional[List[Tuple[str,
str]]]):
str]]]):
"""
:param init_edges: list of tuples, a graph to start learning with
:param remove_init_edges: allows changes in model defined by user
:param remove_init_edges: allows changes in a model defined by user
:param data: user's data
:param progress_bar: verbose regime
:param white_list: list of allowed edges
"""
import bamt.utils.GraphUtils as gru
if not all([i in ['disc', 'disc_num']
for i in gru.nodes_types(data).values()]):
for i in gru.nodes_types(data).values()]):
logger_builder.error(
f"K2 deals only with discrete data. Continuous data: {[col for col, type in gru.nodes_types(data).items() if type not in ['disc', 'disc_num']]}")
return None
Expand Down Expand Up @@ -279,13 +278,13 @@ def apply_group1(self,
data: DataFrame,
progress_bar: bool,
init_edges: Optional[List[Tuple[str,
str]]],
str]]],
remove_init_edges: bool,
white_list: Optional[List[Tuple[str,
str]]]):

str]]]):
# (score == "MI") | (score == "LL") | (score == "BIC") | (score == "AIC")
Roman223 marked this conversation as resolved.
Show resolved Hide resolved
column_name_dict = dict([(n.name, i)
for i, n in enumerate(self.vertices)])
for i, n in enumerate(self.vertices)])
blacklist_new = []
for pair in self.black_list:
blacklist_new.append(
Expand Down
4 changes: 4 additions & 0 deletions bamt/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
__all__ = ["base", "hybrid_bn",
"continuous_bn", "discrete_bn",
"big_brave_bn",
]
Loading