From 1156368b75d92b9567fccce5e5c508892f50300c Mon Sep 17 00:00:00 2001 From: jrzkaminski <86363785+jrzkaminski@users.noreply.github.com> Date: Sat, 5 Aug 2023 17:41:43 +0300 Subject: [PATCH] added tests --- tests/test_builders.py | 59 --------------------------------- tests/test_networks.py | 74 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 59 deletions(-) diff --git a/tests/test_builders.py b/tests/test_builders.py index eb62061..afbe3ea 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -730,65 +730,6 @@ def test_build(self): msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}", ) - class TestCompositeBuilder(unittest.TestCase): - def setUp(self): - self.data = pd.read_csv(r"data/benchmark/healthcare.csv", index_col=0) - self.descriptor = { - "types": { - "A": "disc", - "C": "disc", - "D": "cont", - "H": "disc", - "I": "cont", - "O": "cont", - "T": "cont", - }, - "signs": {"D": "pos", "I": "neg", "O": "pos", "T": "pos"}, - } - self.comp_builder = CompositeStructureBuilder( - data=self.data, descriptor=self.descriptor, regressor=None - ) - self.reference_dag = [ - ("A", "C"), - ("A", "D"), - ("A", "H"), - ("A", "O"), - ("C", "I"), - ("D", "I"), - ("H", "D"), - ("I", "T"), - ("O", "T"), - ("A", "C"), - ("A", "D"), - ("A", "H"), - ("A", "O"), - ("C", "I"), - ("D", "I"), - ("H", "D"), - ("I", "T"), - ("O", "T"), - ] - - def test_build(self): - kwargs = {} - self.comp_builder.build( - data=self.data, classifier=None, regressor=None, verbose=False, **kwargs - ) - - obtained_dag = self.comp_builder.skeleton["E"] - obtained_dag = [tuple([str(item) for item in inner_list]) for inner_list in obtained_dag] - num_edges = len(obtained_dag) - self.assertGreaterEqual( - num_edges, 1, msg="Obtained graph should have at least one edge." - ) - - dist = precision_recall(obtained_dag, self.reference_dag)["SHD"] - self.assertLess( - dist, - 15, - msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}", - ) - if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/test_networks.py b/tests/test_networks.py index 76649c9..9189745 100644 --- a/tests/test_networks.py +++ b/tests/test_networks.py @@ -13,11 +13,14 @@ import pandas as pd from bamt.networks.hybrid_bn import BaseNetwork, HybridBN +from bamt.networks.composite_bn import CompositeBN +import bamt.preprocessors as bp from bamt.nodes.gaussian_node import GaussianNode from bamt.nodes.discrete_node import DiscreteNode from bamt.nodes.logit_node import LogitNode from bamt import preprocessors +from bamt.utils.MathUtils import precision_recall logging.getLogger("network").setLevel(logging.CRITICAL) @@ -1054,5 +1057,76 @@ class TestBigBraveBN(unittest.SkipTest): pass +class TestCompositeNetwork(unittest.TestCase): + def setUp(self): + self.data = pd.read_csv(r"data/benchmark/healthcare.csv", index_col=0) + self.descriptor = { + "types": { + "A": "disc", + "C": "disc", + "D": "cont", + "H": "disc", + "I": "cont", + "O": "cont", + "T": "cont", + }, + "signs": {"D": "pos", "I": "neg", "O": "pos", "T": "pos"}, + } + self.bn = CompositeBN() + self.reference_dag = [ + ("A", "C"), + ("A", "D"), + ("A", "H"), + ("A", "O"), + ("C", "I"), + ("D", "I"), + ("H", "D"), + ("I", "T"), + ("O", "T"), + ("A", "C"), + ("A", "D"), + ("A", "H"), + ("A", "O"), + ("C", "I"), + ("D", "I"), + ("H", "D"), + ("I", "T"), + ("O", "T"), + ] + + def test_learning(self): + encoder = pp.LabelEncoder() + p = bp.Preprocessor([("encoder", encoder)]) + + _, _ = p.apply(self.data) + + info = p.info + + self.bn.add_nodes(info) + + self.bn.add_edges(self.data) + + self.bn.fit_parameters(self.data) + + obtained_dag = self.bn.edges + num_edges = len(obtained_dag) + self.assertGreaterEqual( + num_edges, 1, msg="Obtained graph should have at least one edge." + ) + + dist = precision_recall(obtained_dag, self.reference_dag)["SHD"] + self.assertLess( + dist, + 20, + msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}", + ) + + for node in self.bn.nodes: + if type(node).__name__ == "CompositeContinuousNode": + self.assertIsNotNone(node.regressor, msg="CompositeContinuousNode does not have regressor") + if type(node).__name__ == "CompositeDiscreteNode": + self.assertIsNotNone(node.classifier, msg="CompositeDiscreteNode does not have classifier") + + if __name__ == "__main__": unittest.main(verbosity=3)