Skip to content

Commit

Permalink
Code reformatting to black style
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Jul 11, 2023
1 parent 45c33e0 commit 33ad9be
Show file tree
Hide file tree
Showing 51 changed files with 3,279 additions and 1,884 deletions.
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
- | |license|
* - stats
- | |downloads_stats| |downloads_monthly| |downloads_weekly|
* - style
- | |Black|

Repository of a data modeling and analysis tool based on Bayesian networks

Expand Down Expand Up @@ -241,3 +243,6 @@ Citation

.. |coverage| image:: https://codecov.io/github/aimclub/BAMT/branch/master/graph/badge.svg?token=fA4qsxGqTC
:target: https://codecov.io/github/aimclub/BAMT

.. |Black| image:: https://img.shields.io/badge/code%20style-black-000000.svg
.. _Black: https://github.com/psf/black
5 changes: 1 addition & 4 deletions bamt/builders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
__all__ = ["builders_base",
"evo_builder",
"hc_builder"
]
__all__ = ["builders_base", "evo_builder", "hc_builder"]
113 changes: 61 additions & 52 deletions bamt/builders/builders_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,38 +35,39 @@ def __init__(self, descriptor: Dict[str, Dict[str, str]]):
Attributes:
black_list: a list with restricted connections;
"""
self.skeleton = {'V': [],
'E': []}
self.skeleton = {"V": [], "E": []}
self.descriptor = descriptor

self.has_logit = bool

self.black_list = None

def restrict(self, data: DataFrame,
init_nodes: Optional[List[str]],
bl_add: Optional[List[str]]):
def restrict(
self,
data: DataFrame,
init_nodes: Optional[List[str]],
bl_add: Optional[List[str]],
):
"""
:param data: data to deal with
:param init_nodes: nodes to begin with (thus they have no parents)
:param bl_add: additional vertices
"""
node_type = self.descriptor['types']
node_type = self.descriptor["types"]
blacklist = []
datacol = data.columns.to_list()

if not self.has_logit:
# Has_logit flag allows BN building edges between cont and disc
RESTRICTIONS = [('cont', 'disc'), ('cont', 'disc_num')]
RESTRICTIONS = [("cont", "disc"), ("cont", "disc_num")]
for x, y in itertools.product(datacol, repeat=2):
if x != y:
if (node_type[x], node_type[y]) in RESTRICTIONS:
blacklist.append((x, y))
else:
self.black_list = []
if init_nodes:
blacklist += [(x, y)
for x in datacol for y in init_nodes if x != y]
blacklist += [(x, y) for x in datacol for y in init_nodes if x != y]
if bl_add:
blacklist = blacklist + bl_add
self.black_list = blacklist
Expand All @@ -75,17 +76,17 @@ def get_family(self):
"""
A function that updates a skeleton;
"""
if not self.skeleton['V']:
if not self.skeleton["V"]:
logger_builder.error("Vertex list is None")
return None
if not self.skeleton['E']:
if not self.skeleton["E"]:
logger_builder.error("Edges list is None")
return None
for node_instance in self.skeleton['V']:
for node_instance in self.skeleton["V"]:
node = node_instance.name
children = []
parents = []
for edge in self.skeleton['E']:
for edge in self.skeleton["E"]:
if node in edge:
if edge.index(node) == 0:
children.append(edge[1])
Expand All @@ -95,29 +96,30 @@ def get_family(self):
disc_parents = []
cont_parents = []
for parent in parents:
if self.descriptor['types'][parent] in ['disc', 'disc_num']:
if self.descriptor["types"][parent] in ["disc", "disc_num"]:
disc_parents.append(parent)
else:
cont_parents.append(parent)

id = self.skeleton['V'].index(node_instance)
self.skeleton['V'][id].disc_parents = disc_parents
self.skeleton['V'][id].cont_parents = cont_parents
self.skeleton['V'][id].children = children
id = self.skeleton["V"].index(node_instance)
self.skeleton["V"][id].disc_parents = disc_parents
self.skeleton["V"][id].cont_parents = cont_parents
self.skeleton["V"][id].children = children

ordered = gru.toporder(self.skeleton['V'], self.skeleton['E'])
not_ordered = [node.name for node in self.skeleton['V']]
ordered = gru.toporder(self.skeleton["V"], self.skeleton["E"])
not_ordered = [node.name for node in self.skeleton["V"]]
mask = [not_ordered.index(name) for name in ordered]
self.skeleton['V'] = [self.skeleton['V'][i] for i in mask]
self.skeleton["V"] = [self.skeleton["V"][i] for i in mask]


class VerticesDefiner(StructureBuilder):
"""
Main class for defining vertices
"""

def __init__(self, descriptor: Dict[str, Dict[str, str]],
regressor: Optional[object]):
def __init__(
self, descriptor: Dict[str, Dict[str, str]], regressor: Optional[object]
):
"""
Automatically creates a list of nodes
"""
Expand All @@ -127,10 +129,10 @@ def __init__(self, descriptor: Dict[str, Dict[str, str]],

node = None
# LEVEL 1: Define a general type of node: Discrete or Gaussian
for vertex, type in self.descriptor['types'].items():
if type in ['disc_num', 'disc']:
for vertex, type in self.descriptor["types"].items():
if type in ["disc_num", "disc"]:
node = DiscreteNode(name=vertex)
elif type == 'cont':
elif type == "cont":
node = GaussianNode(name=vertex, regressor=regressor)
else:
msg = f"""First stage of automatic vertex detection failed on {vertex} due TypeError ({type}).
Expand All @@ -141,11 +143,12 @@ def __init__(self, descriptor: Dict[str, Dict[str, str]],
self.vertices.append(node)

def overwrite_vertex(
self,
has_logit: bool,
use_mixture: bool,
classifier: Optional[Callable],
regressor: Optional[Callable]):
self,
has_logit: bool,
use_mixture: bool,
classifier: Optional[Callable],
regressor: Optional[Callable],
):
"""
Level 2: Redefined nodes according structure (parents)
:param classifier: an object to pass into logit, condLogit nodes
Expand All @@ -156,42 +159,43 @@ def overwrite_vertex(
for node_instance in self.vertices:
node = node_instance
if has_logit:
if 'Discrete' in node_instance.type:
if "Discrete" in node_instance.type:
if node_instance.cont_parents:
if not node_instance.disc_parents:
node = LogitNode(
name=node_instance.name, classifier=classifier)
name=node_instance.name, classifier=classifier
)

elif node_instance.disc_parents:
node = ConditionalLogitNode(
name=node_instance.name, classifier=classifier)
name=node_instance.name, classifier=classifier
)

if use_mixture:
if 'Gaussian' in node_instance.type:
if "Gaussian" in node_instance.type:
if not node_instance.disc_parents:
node = MixtureGaussianNode(
name=node_instance.name)
node = MixtureGaussianNode(name=node_instance.name)
elif node_instance.disc_parents:
node = ConditionalMixtureGaussianNode(
name=node_instance.name)
node = ConditionalMixtureGaussianNode(name=node_instance.name)
else:
continue
else:
if 'Gaussian' in node_instance.type:
if "Gaussian" in node_instance.type:
if node_instance.disc_parents:
node = ConditionalGaussianNode(
name=node_instance.name, regressor=regressor)
name=node_instance.name, regressor=regressor
)
else:
continue

if node_instance == node:
continue

id = self.skeleton['V'].index(node_instance)
id = self.skeleton["V"].index(node_instance)
node.disc_parents = node_instance.disc_parents
node.cont_parents = node_instance.cont_parents
node.children = node_instance.children
self.skeleton['V'][id] = node
self.skeleton["V"][id] = node


class EdgesDefiner(StructureBuilder):
Expand All @@ -200,15 +204,20 @@ def __init__(self, descriptor: Dict[str, Dict[str, str]]):


class BaseDefiner(VerticesDefiner, EdgesDefiner):
def __init__(self, data: DataFrame, descriptor: Dict[str, Dict[str, str]],
scoring_function: Union[Tuple[str, Callable], Tuple[str]],
regressor: Optional[object] = None):

def __init__(
self,
data: DataFrame,
descriptor: Dict[str, Dict[str, str]],
scoring_function: Union[Tuple[str, Callable], Tuple[str]],
regressor: Optional[object] = None,
):
self.scoring_function = scoring_function
self.params = {'init_edges': None,
'init_nodes': None,
'remove_init_edges': True,
'white_list': None,
'bl_add': None}
self.params = {
"init_edges": None,
"init_nodes": None,
"remove_init_edges": True,
"white_list": None,
"bl_add": None,
}
super().__init__(descriptor, regressor=regressor)
self.optimizer = None # will be defined in subclasses
Loading

0 comments on commit 33ad9be

Please sign in to comment.