Skip to content

Commit

Permalink
tests and display improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman223 committed Aug 11, 2023
1 parent 887aa99 commit 8d14caf
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 52 deletions.
2 changes: 1 addition & 1 deletion bamt/display/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from display import Display
from .display import Display


def plot_(output, *args):
Expand Down
5 changes: 1 addition & 4 deletions bamt/display/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,7 @@ def build(self, nodes, edges, **kwargs):

network.hrepulsion(node_distance=300, central_gravity=0.5)

if not (os.path.exists("visualization_result")):
os.mkdir("visualization_result")

return network.show(f"visualization_result/" + self.output)
return network.show(self.output)

@staticmethod
def get_info(bn, as_df):
Expand Down
35 changes: 0 additions & 35 deletions bamt/display/test.py

This file was deleted.

8 changes: 7 additions & 1 deletion bamt/logging.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[loggers]
keys=root, preprocessor, builder, nodes, network, py.warnings
keys=root, preprocessor, builder, nodes, network, display, py.warnings

[handlers]
keys=consoleHandler, fileHandler
Expand Down Expand Up @@ -35,6 +35,12 @@ qualname=nodes
handlers=consoleHandler, fileHandler
propagate=0

[logger_display]
level=INFO
qualname=display
handlers=consoleHandler, fileHandler
propagate=0


[logger_py.warnings]
level=INFO
Expand Down
4 changes: 3 additions & 1 deletion bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,9 @@ def find_family(
with_nodes: Optional[List] = None,
plot_to: Optional[str] = None,
):
structure = GraphUtils.GraphAnalyzer(self).find_family(node_name, height, depth, with_nodes)
structure = GraphUtils.GraphAnalyzer(self).find_family(
node_name, height, depth, with_nodes
)

if plot_to:
plot_(
Expand Down
11 changes: 5 additions & 6 deletions bamt/utils/GraphUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def markov_blanket(self, node_name: str):
nodes = parents + children + fremd_eltern + [node_name]

edges = self._isolate_structure(nodes)
return {"nodes": nodes, "edges": edges}
return {"nodes": list(set(nodes)), "edges": edges}

def _collect_height(self, node_name, height):
nodes = []
Expand Down Expand Up @@ -143,12 +143,11 @@ def find_family(self, *args):
else:
with_nodes = list(with_nodes)
nodes = (
self._collect_depth(node_name, depth)
+ self._collect_height(node_name, height)
+ [node_name]
self._collect_depth(node_name, depth)
+ self._collect_height(node_name, height)
+ [node_name]
)

nodes = list(set(nodes + with_nodes))

return {"nodes": nodes,
"edges": self._isolate_structure(nodes + with_nodes)}
return {"nodes": nodes, "edges": self._isolate_structure(nodes + with_nodes)}
5 changes: 1 addition & 4 deletions tests/test_builders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from contextlib import redirect_stdout, redirect_stderr
import os

import itertools
import unittest

Expand All @@ -9,7 +6,7 @@
import pandas as pd

from bamt.builders.builders_base import StructureBuilder, VerticesDefiner
from bamt.builders.hc_builder import HCStructureBuilder, HillClimbDefiner
from bamt.builders.hc_builder import HillClimbDefiner
from bamt.builders.evo_builder import EvoStructureBuilder

from bamt.nodes.gaussian_node import GaussianNode
Expand Down
90 changes: 90 additions & 0 deletions tests/test_graph_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import unittest
import logging

from bamt.utils import GraphUtils
from bamt.networks.discrete_bn import DiscreteBN
from bamt.nodes.discrete_node import DiscreteNode
from bamt.builders.builders_base import VerticesDefiner
from bamt.display import plot_

logging.getLogger("builder").setLevel(logging.CRITICAL)


class TestGraphAnalyzer(unittest.TestCase):
def setUp(self):
self.bn = DiscreteBN()
definer = VerticesDefiner(descriptor={"types": {
"Node0": "disc",
"Node1": "disc",
"Node2": "disc",
"Node3": "disc_num",
"Node4": "disc",
"Node5": "disc",
"Node6": "disc",
"Node7": "disc",
"Node8": "disc",
"Node9": "disc",
},
"signs": {},
},
regressor=None)

definer.skeleton["V"] = [
DiscreteNode(name=f"Node{i}") for i in range(10)
]
definer.skeleton["E"] = [
("Node0", "Node1"),
("Node0", "Node2"),
("Node2", "Node3"),
("Node4", "Node7"),
("Node1", "Node5"),
("Node5", "Node6"),
("Node7", "Node0"),
("Node8", "Node1"),
("Node9", "Node2"),
]
definer.get_family()

self.bn.nodes = definer.skeleton["V"]
self.bn.edges = definer.skeleton["E"]

self.analyzer = GraphUtils.GraphAnalyzer(self.bn)

def test_markov_blanket(self):
result = self.analyzer.markov_blanket("Node0")
result["nodes"] = sorted(result["nodes"])
self.assertEqual(
{'edges': [('Node0', 'Node1'),
('Node0', 'Node2'),
('Node7', 'Node0'),
('Node8', 'Node1'),
('Node9', 'Node2')],
'nodes': sorted(['Node0', 'Node1', 'Node2', 'Node7', 'Node8', 'Node9'])}
,
result
)

def test_find_family(self):
without_parents = self.analyzer.find_family("Node0", 0, 2, None)
without_parents["nodes"] = sorted(without_parents["nodes"])
self.assertEqual(
{'nodes': sorted(['Node3', 'Node2', 'Node1', 'Node0', 'Node5']),
'edges': [('Node0', 'Node1'), ('Node0', 'Node2'), ('Node2', 'Node3'), ('Node1', 'Node5')]},
without_parents
)

without_children = self.analyzer.find_family("Node0", 2, 0, None)
without_children["nodes"] = sorted(without_children["nodes"])
self.assertEqual(
{'nodes': sorted(['Node4', 'Node7', 'Node0']),
'edges': [('Node4', 'Node7'), ('Node7', 'Node0')]},
without_children
)

plot_("here.html",
[self.bn[node] for node in without_children["nodes"]],
without_children["edges"])


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 8d14caf

Please sign in to comment.