diff --git a/bamt/networks/base.py b/bamt/networks/base.py index 3f390e4..d44d974 100644 --- a/bamt/networks/base.py +++ b/bamt/networks/base.py @@ -863,7 +863,7 @@ def plot(self, output: str): return network.show(f"visualization_result/" + output) - def get_dist(self, node_name: str, pvals: dict): + def get_dist(self, node_name: str, pvals: Optional[dict] = None): """ Get a distribution from node with known parent values (conditional distribution). @@ -877,8 +877,7 @@ def get_dist(self, node_name: str, pvals: dict): parents = node.cont_parents + node.disc_parents if not parents: - logger_network.error("No parents") - return + return self.distributions[node_name] pvals = [pvals[parent] for parent in parents] diff --git a/tests/test_nodes.py b/tests/test_nodes.py index c48e4a0..6a70264 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -9,7 +9,7 @@ logging.getLogger("nodes").setLevel(logging.CRITICAL) -class DistributionAssertion(unittest.TestCase): +class MyTest(unittest.TestCase): unittest.skip("This is an assertion.") def assertDist(self, dist, node_type): @@ -22,7 +22,7 @@ def assertDistMixture(self, dist): return self.assertEqual(len(dist), 3, msg=f"Error on {dist}") -class TestBaseNode(unittest.TestCase): +class TestBaseNode(MyTest): def setUp(self): np.random.seed(510) @@ -137,19 +137,26 @@ def test_get_dist_mixture(self): for i in range(-2, 0, 2): dist = hybrid_bn.get_dist(mixture_gauss.name, pvals={"Node0": i}) - DistributionAssertion().assertDistMixture(dist) + self.assertDistMixture(dist) for i in range(-2, 0, 2): for j in self.data[cond_mixture_gauss.disc_parents[0]].unique().tolist(): dist = hybrid_bn.get_dist( cond_mixture_gauss.name, pvals={"Node0": float(i), "Node4": j} ) - DistributionAssertion().assertDistMixture(dist) + self.assertDistMixture(dist) def test_get_dist(self): for node in self.bn.nodes: if not node.cont_parents + node.disc_parents: + dist = self.bn.get_dist(node.name) + if "mean" in dist.keys(): + self.assertTrue(isinstance(dist["mean"], float)) + self.assertTrue(isinstance(dist["variance"], float)) + else: + self.assertAlmostEqual(sum(dist["cprob"]), 1) continue + if len(node.cont_parents + node.disc_parents) == 1: if node.disc_parents: pvals = self.data[node.disc_parents[0]].unique().tolist() @@ -160,9 +167,7 @@ def test_get_dist(self): for pval in pvals: dist = self.bn.get_dist(node.name, {parent: pval}) - DistributionAssertion().assertDist( - dist, self.info["types"][node.name] - ) + self.assertDist(dist, self.info["types"][node.name]) else: for i in self.data[node.disc_parents[0]].unique().tolist(): for j in range(-5, 5, 1): @@ -170,9 +175,7 @@ def test_get_dist(self): node.name, {node.cont_parents[0]: float(j), node.disc_parents[0]: i}, ) - DistributionAssertion().assertDist( - dist, self.info["types"][node.name] - ) + self.assertDist(dist, self.info["types"][node.name]) # ??? def test_choose_serialization(self):