From 8d14caf0e6b6ed94d0b21b92b2a39a001e19db08 Mon Sep 17 00:00:00 2001 From: Roman223 Date: Fri, 11 Aug 2023 14:10:42 +0300 Subject: [PATCH] tests and display improvements --- bamt/display/__init__.py | 2 +- bamt/display/display.py | 5 +- bamt/display/test.py | 35 -------------- bamt/logging.conf | 8 +++- bamt/networks/base.py | 4 +- bamt/utils/GraphUtils.py | 11 ++--- tests/test_builders.py | 5 +- tests/test_graph_analyzer.py | 90 ++++++++++++++++++++++++++++++++++++ 8 files changed, 108 insertions(+), 52 deletions(-) delete mode 100644 bamt/display/test.py create mode 100644 tests/test_graph_analyzer.py diff --git a/bamt/display/__init__.py b/bamt/display/__init__.py index 90ae7d9..a22b2c3 100644 --- a/bamt/display/__init__.py +++ b/bamt/display/__init__.py @@ -1,4 +1,4 @@ -from display import Display +from .display import Display def plot_(output, *args): diff --git a/bamt/display/display.py b/bamt/display/display.py index b12a020..20636ae 100644 --- a/bamt/display/display.py +++ b/bamt/display/display.py @@ -140,10 +140,7 @@ def build(self, nodes, edges, **kwargs): network.hrepulsion(node_distance=300, central_gravity=0.5) - if not (os.path.exists("visualization_result")): - os.mkdir("visualization_result") - - return network.show(f"visualization_result/" + self.output) + return network.show(self.output) @staticmethod def get_info(bn, as_df): diff --git a/bamt/display/test.py b/bamt/display/test.py deleted file mode 100644 index 82779be..0000000 --- a/bamt/display/test.py +++ /dev/null @@ -1,35 +0,0 @@ -from bamt.networks.hybrid_bn import HybridBN -from bamt.preprocessors import Preprocessor -import pandas as pd -from sklearn import preprocessing as pp - -data = pd.read_csv("../../data/real data/vk_data.csv").iloc[:1000, :10] -print(data.shape) -# print(data.columns) -# set encoder and discretizer -encoder = pp.LabelEncoder() -discretizer = pp.KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='uniform') - -# create preprocessor object with encoder and discretizer -p = Preprocessor([('encoder', encoder), ('discretizer', discretizer)]) - -# discretize data for structure learning -discretized_data, est = p.apply(data) - -# get information about data -info = p.info - -# initialize network object -bn = HybridBN() - -# add nodes to network -bn.add_nodes(info) - -# using mutual information as scoring function for structure learning -bn.add_edges(discretized_data, scoring_function=('K2',)) - -bn.get_info(as_df=False) -bn.plot("entire.html") -plot_to = "family.html" - -bn.find_family("has_high_education", height=1, depth=1, plot_to=plot_to) diff --git a/bamt/logging.conf b/bamt/logging.conf index 2384be9..d75c39f 100644 --- a/bamt/logging.conf +++ b/bamt/logging.conf @@ -1,5 +1,5 @@ [loggers] -keys=root, preprocessor, builder, nodes, network, py.warnings +keys=root, preprocessor, builder, nodes, network, display, py.warnings [handlers] keys=consoleHandler, fileHandler @@ -35,6 +35,12 @@ qualname=nodes handlers=consoleHandler, fileHandler propagate=0 +[logger_display] +level=INFO +qualname=display +handlers=consoleHandler, fileHandler +propagate=0 + [logger_py.warnings] level=INFO diff --git a/bamt/networks/base.py b/bamt/networks/base.py index a66eb3d..5ee1b91 100644 --- a/bamt/networks/base.py +++ b/bamt/networks/base.py @@ -766,7 +766,9 @@ def find_family( with_nodes: Optional[List] = None, plot_to: Optional[str] = None, ): - structure = GraphUtils.GraphAnalyzer(self).find_family(node_name, height, depth, with_nodes) + structure = GraphUtils.GraphAnalyzer(self).find_family( + node_name, height, depth, with_nodes + ) if plot_to: plot_( diff --git a/bamt/utils/GraphUtils.py b/bamt/utils/GraphUtils.py index 577209e..89e97b4 100644 --- a/bamt/utils/GraphUtils.py +++ b/bamt/utils/GraphUtils.py @@ -104,7 +104,7 @@ def markov_blanket(self, node_name: str): nodes = parents + children + fremd_eltern + [node_name] edges = self._isolate_structure(nodes) - return {"nodes": nodes, "edges": edges} + return {"nodes": list(set(nodes)), "edges": edges} def _collect_height(self, node_name, height): nodes = [] @@ -143,12 +143,11 @@ def find_family(self, *args): else: with_nodes = list(with_nodes) nodes = ( - self._collect_depth(node_name, depth) - + self._collect_height(node_name, height) - + [node_name] + self._collect_depth(node_name, depth) + + self._collect_height(node_name, height) + + [node_name] ) nodes = list(set(nodes + with_nodes)) - return {"nodes": nodes, - "edges": self._isolate_structure(nodes + with_nodes)} + return {"nodes": nodes, "edges": self._isolate_structure(nodes + with_nodes)} diff --git a/tests/test_builders.py b/tests/test_builders.py index 491a9ca..92111fe 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -1,6 +1,3 @@ -from contextlib import redirect_stdout, redirect_stderr -import os - import itertools import unittest @@ -9,7 +6,7 @@ import pandas as pd from bamt.builders.builders_base import StructureBuilder, VerticesDefiner -from bamt.builders.hc_builder import HCStructureBuilder, HillClimbDefiner +from bamt.builders.hc_builder import HillClimbDefiner from bamt.builders.evo_builder import EvoStructureBuilder from bamt.nodes.gaussian_node import GaussianNode diff --git a/tests/test_graph_analyzer.py b/tests/test_graph_analyzer.py new file mode 100644 index 0000000..6fd1ba0 --- /dev/null +++ b/tests/test_graph_analyzer.py @@ -0,0 +1,90 @@ +import unittest +import logging + +from bamt.utils import GraphUtils +from bamt.networks.discrete_bn import DiscreteBN +from bamt.nodes.discrete_node import DiscreteNode +from bamt.builders.builders_base import VerticesDefiner +from bamt.display import plot_ + +logging.getLogger("builder").setLevel(logging.CRITICAL) + + +class TestGraphAnalyzer(unittest.TestCase): + def setUp(self): + self.bn = DiscreteBN() + definer = VerticesDefiner(descriptor={"types": { + "Node0": "disc", + "Node1": "disc", + "Node2": "disc", + "Node3": "disc_num", + "Node4": "disc", + "Node5": "disc", + "Node6": "disc", + "Node7": "disc", + "Node8": "disc", + "Node9": "disc", + }, + "signs": {}, + }, + regressor=None) + + definer.skeleton["V"] = [ + DiscreteNode(name=f"Node{i}") for i in range(10) + ] + definer.skeleton["E"] = [ + ("Node0", "Node1"), + ("Node0", "Node2"), + ("Node2", "Node3"), + ("Node4", "Node7"), + ("Node1", "Node5"), + ("Node5", "Node6"), + ("Node7", "Node0"), + ("Node8", "Node1"), + ("Node9", "Node2"), + ] + definer.get_family() + + self.bn.nodes = definer.skeleton["V"] + self.bn.edges = definer.skeleton["E"] + + self.analyzer = GraphUtils.GraphAnalyzer(self.bn) + + def test_markov_blanket(self): + result = self.analyzer.markov_blanket("Node0") + result["nodes"] = sorted(result["nodes"]) + self.assertEqual( + {'edges': [('Node0', 'Node1'), + ('Node0', 'Node2'), + ('Node7', 'Node0'), + ('Node8', 'Node1'), + ('Node9', 'Node2')], + 'nodes': sorted(['Node0', 'Node1', 'Node2', 'Node7', 'Node8', 'Node9'])} + , + result + ) + + def test_find_family(self): + without_parents = self.analyzer.find_family("Node0", 0, 2, None) + without_parents["nodes"] = sorted(without_parents["nodes"]) + self.assertEqual( + {'nodes': sorted(['Node3', 'Node2', 'Node1', 'Node0', 'Node5']), + 'edges': [('Node0', 'Node1'), ('Node0', 'Node2'), ('Node2', 'Node3'), ('Node1', 'Node5')]}, + without_parents + ) + + without_children = self.analyzer.find_family("Node0", 2, 0, None) + without_children["nodes"] = sorted(without_children["nodes"]) + self.assertEqual( + {'nodes': sorted(['Node4', 'Node7', 'Node0']), + 'edges': [('Node4', 'Node7'), ('Node7', 'Node0')]}, + without_children + ) + + plot_("here.html", + [self.bn[node] for node in without_children["nodes"]], + without_children["edges"]) + + +if __name__ == "__main__": + unittest.main(verbosity=2)