Skip to content

Commit

Permalink
minor documentation changes and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Aug 5, 2023
1 parent ade378b commit 4718e0e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 10 deletions.
11 changes: 2 additions & 9 deletions bamt/networks/composite_bn.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import os
import re
import random

import numpy as np
from sklearn.preprocessing import LabelEncoder

from tqdm import tqdm
from bamt.log import logger_network
from bamt.networks.base import BaseNetwork, STORAGE
from bamt.networks.base import BaseNetwork
import pandas as pd
from typing import Optional, Dict, Union, List
from typing import Optional, Dict
from bamt.builders.composite_builder import CompositeStructureBuilder, CompositeDefiner
from bamt.utils.composite_utils.MLUtils import MlModels

Expand Down
1 change: 0 additions & 1 deletion bamt/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(self, name: str):
self.disc_parents = []
self.cont_parents = []
self.children = []
self.encoders = {}

def __repr__(self):
return f"{self.name}"
Expand Down
60 changes: 60 additions & 0 deletions tests/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from bamt.builders.builders_base import StructureBuilder, VerticesDefiner
from bamt.builders.hc_builder import HCStructureBuilder, HillClimbDefiner
from bamt.builders.evo_builder import EvoStructureBuilder
from bamt.builders.composite_builder import CompositeStructureBuilder

from bamt.nodes.gaussian_node import GaussianNode
from bamt.nodes.discrete_node import DiscreteNode
Expand Down Expand Up @@ -729,6 +730,65 @@ def test_build(self):
msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}",
)

class TestCompositeBuilder(unittest.TestCase):
def setUp(self):
self.data = pd.read_csv(r"data/benchmark/healthcare.csv", index_col=0)
self.descriptor = {
"types": {
"A": "disc",
"C": "disc",
"D": "cont",
"H": "disc",
"I": "cont",
"O": "cont",
"T": "cont",
},
"signs": {"D": "pos", "I": "neg", "O": "pos", "T": "pos"},
}
self.comp_builder = CompositeStructureBuilder(
data=self.data, descriptor=self.descriptor, regressor=None
)
self.reference_dag = [
("A", "C"),
("A", "D"),
("A", "H"),
("A", "O"),
("C", "I"),
("D", "I"),
("H", "D"),
("I", "T"),
("O", "T"),
("A", "C"),
("A", "D"),
("A", "H"),
("A", "O"),
("C", "I"),
("D", "I"),
("H", "D"),
("I", "T"),
("O", "T"),
]

def test_build(self):
kwargs = {}
self.comp_builder.build(
data=self.data, classifier=None, regressor=None, verbose=False, **kwargs
)

obtained_dag = self.comp_builder.skeleton["E"]
obtained_dag = [tuple([str(item) for item in inner_list]) for inner_list in obtained_dag]
num_edges = len(obtained_dag)
self.assertGreaterEqual(
num_edges, 1, msg="Obtained graph should have at least one edge."
)

dist = precision_recall(obtained_dag, self.reference_dag)["SHD"]
self.assertLess(
dist,
15,
msg=f"Structural Hamming Distance should be less than 15, obtained SHD = {dist}",
)


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 4718e0e

Please sign in to comment.