Skip to content

Commit

Permalink
added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Aug 5, 2023
1 parent 4718e0e commit 1156368
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 59 deletions.
59 changes: 0 additions & 59 deletions tests/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,65 +730,6 @@ def test_build(self):
msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}",
)

class TestCompositeBuilder(unittest.TestCase):
def setUp(self):
self.data = pd.read_csv(r"data/benchmark/healthcare.csv", index_col=0)
self.descriptor = {
"types": {
"A": "disc",
"C": "disc",
"D": "cont",
"H": "disc",
"I": "cont",
"O": "cont",
"T": "cont",
},
"signs": {"D": "pos", "I": "neg", "O": "pos", "T": "pos"},
}
self.comp_builder = CompositeStructureBuilder(
data=self.data, descriptor=self.descriptor, regressor=None
)
self.reference_dag = [
("A", "C"),
("A", "D"),
("A", "H"),
("A", "O"),
("C", "I"),
("D", "I"),
("H", "D"),
("I", "T"),
("O", "T"),
("A", "C"),
("A", "D"),
("A", "H"),
("A", "O"),
("C", "I"),
("D", "I"),
("H", "D"),
("I", "T"),
("O", "T"),
]

def test_build(self):
kwargs = {}
self.comp_builder.build(
data=self.data, classifier=None, regressor=None, verbose=False, **kwargs
)

obtained_dag = self.comp_builder.skeleton["E"]
obtained_dag = [tuple([str(item) for item in inner_list]) for inner_list in obtained_dag]
num_edges = len(obtained_dag)
self.assertGreaterEqual(
num_edges, 1, msg="Obtained graph should have at least one edge."
)

dist = precision_recall(obtained_dag, self.reference_dag)["SHD"]
self.assertLess(
dist,
15,
msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}",
)


if __name__ == "__main__":
unittest.main(verbosity=2)
74 changes: 74 additions & 0 deletions tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
import pandas as pd

from bamt.networks.hybrid_bn import BaseNetwork, HybridBN
from bamt.networks.composite_bn import CompositeBN
import bamt.preprocessors as bp

from bamt.nodes.gaussian_node import GaussianNode
from bamt.nodes.discrete_node import DiscreteNode
from bamt.nodes.logit_node import LogitNode
from bamt import preprocessors
from bamt.utils.MathUtils import precision_recall

logging.getLogger("network").setLevel(logging.CRITICAL)

Expand Down Expand Up @@ -1054,5 +1057,76 @@ class TestBigBraveBN(unittest.SkipTest):
pass


class TestCompositeNetwork(unittest.TestCase):
def setUp(self):
self.data = pd.read_csv(r"data/benchmark/healthcare.csv", index_col=0)
self.descriptor = {
"types": {
"A": "disc",
"C": "disc",
"D": "cont",
"H": "disc",
"I": "cont",
"O": "cont",
"T": "cont",
},
"signs": {"D": "pos", "I": "neg", "O": "pos", "T": "pos"},
}
self.bn = CompositeBN()
self.reference_dag = [
("A", "C"),
("A", "D"),
("A", "H"),
("A", "O"),
("C", "I"),
("D", "I"),
("H", "D"),
("I", "T"),
("O", "T"),
("A", "C"),
("A", "D"),
("A", "H"),
("A", "O"),
("C", "I"),
("D", "I"),
("H", "D"),
("I", "T"),
("O", "T"),
]

def test_learning(self):
encoder = pp.LabelEncoder()
p = bp.Preprocessor([("encoder", encoder)])

_, _ = p.apply(self.data)

info = p.info

self.bn.add_nodes(info)

self.bn.add_edges(self.data)

self.bn.fit_parameters(self.data)

obtained_dag = self.bn.edges
num_edges = len(obtained_dag)
self.assertGreaterEqual(
num_edges, 1, msg="Obtained graph should have at least one edge."
)

dist = precision_recall(obtained_dag, self.reference_dag)["SHD"]
self.assertLess(
dist,
20,
msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}",
)

for node in self.bn.nodes:
if type(node).__name__ == "CompositeContinuousNode":
self.assertIsNotNone(node.regressor, msg="CompositeContinuousNode does not have regressor")
if type(node).__name__ == "CompositeDiscreteNode":
self.assertIsNotNone(node.classifier, msg="CompositeDiscreteNode does not have classifier")


if __name__ == "__main__":
unittest.main(verbosity=3)

0 comments on commit 1156368

Please sign in to comment.