From 4718e0e6a736f216aa9112c1b4e8441c926423e9 Mon Sep 17 00:00:00 2001 From: jrzkaminski <86363785+jrzkaminski@users.noreply.github.com> Date: Sat, 5 Aug 2023 13:52:49 +0300 Subject: [PATCH] minor documentation changes and fixes --- bamt/networks/composite_bn.py | 11 ++----- bamt/nodes/base.py | 1 - tests/test_builders.py | 60 +++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/bamt/networks/composite_bn.py b/bamt/networks/composite_bn.py index 9345680..bbb63d2 100644 --- a/bamt/networks/composite_bn.py +++ b/bamt/networks/composite_bn.py @@ -1,15 +1,8 @@ -import os import re -import random -import numpy as np -from sklearn.preprocessing import LabelEncoder - -from tqdm import tqdm -from bamt.log import logger_network -from bamt.networks.base import BaseNetwork, STORAGE +from bamt.networks.base import BaseNetwork import pandas as pd -from typing import Optional, Dict, Union, List +from typing import Optional, Dict from bamt.builders.composite_builder import CompositeStructureBuilder, CompositeDefiner from bamt.utils.composite_utils.MLUtils import MlModels diff --git a/bamt/nodes/base.py b/bamt/nodes/base.py index 74029e1..1c5c2c8 100644 --- a/bamt/nodes/base.py +++ b/bamt/nodes/base.py @@ -29,7 +29,6 @@ def __init__(self, name: str): self.disc_parents = [] self.cont_parents = [] self.children = [] - self.encoders = {} def __repr__(self): return f"{self.name}" diff --git a/tests/test_builders.py b/tests/test_builders.py index 972241b..eb62061 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -11,6 +11,7 @@ from bamt.builders.builders_base import StructureBuilder, VerticesDefiner from bamt.builders.hc_builder import HCStructureBuilder, HillClimbDefiner from bamt.builders.evo_builder import EvoStructureBuilder +from bamt.builders.composite_builder import CompositeStructureBuilder from bamt.nodes.gaussian_node import GaussianNode from bamt.nodes.discrete_node import DiscreteNode @@ -729,6 +730,65 @@ def test_build(self): msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}", ) + class TestCompositeBuilder(unittest.TestCase): + def setUp(self): + self.data = pd.read_csv(r"data/benchmark/healthcare.csv", index_col=0) + self.descriptor = { + "types": { + "A": "disc", + "C": "disc", + "D": "cont", + "H": "disc", + "I": "cont", + "O": "cont", + "T": "cont", + }, + "signs": {"D": "pos", "I": "neg", "O": "pos", "T": "pos"}, + } + self.comp_builder = CompositeStructureBuilder( + data=self.data, descriptor=self.descriptor, regressor=None + ) + self.reference_dag = [ + ("A", "C"), + ("A", "D"), + ("A", "H"), + ("A", "O"), + ("C", "I"), + ("D", "I"), + ("H", "D"), + ("I", "T"), + ("O", "T"), + ("A", "C"), + ("A", "D"), + ("A", "H"), + ("A", "O"), + ("C", "I"), + ("D", "I"), + ("H", "D"), + ("I", "T"), + ("O", "T"), + ] + + def test_build(self): + kwargs = {} + self.comp_builder.build( + data=self.data, classifier=None, regressor=None, verbose=False, **kwargs + ) + + obtained_dag = self.comp_builder.skeleton["E"] + obtained_dag = [tuple([str(item) for item in inner_list]) for inner_list in obtained_dag] + num_edges = len(obtained_dag) + self.assertGreaterEqual( + num_edges, 1, msg="Obtained graph should have at least one edge." + ) + + dist = precision_recall(obtained_dag, self.reference_dag)["SHD"] + self.assertLess( + dist, + 15, + msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}", + ) + if __name__ == "__main__": unittest.main(verbosity=2)