Skip to content

Commit

Permalink
when pvals don't set, return marginal dist
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman223 committed Aug 11, 2023
1 parent 3006616 commit a361562
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
5 changes: 2 additions & 3 deletions bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ def plot(self, output: str):

return network.show(f"visualization_result/" + output)

def get_dist(self, node_name: str, pvals: dict):
def get_dist(self, node_name: str, pvals: Optional[dict] = None):
"""
Get a distribution from node with known parent values (conditional distribution).
Expand All @@ -877,8 +877,7 @@ def get_dist(self, node_name: str, pvals: dict):

parents = node.cont_parents + node.disc_parents
if not parents:
logger_network.error("No parents")
return
return self.distributions[node_name]

pvals = [pvals[parent] for parent in parents]

Expand Down
23 changes: 13 additions & 10 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
logging.getLogger("nodes").setLevel(logging.CRITICAL)


class DistributionAssertion(unittest.TestCase):
class MyTest(unittest.TestCase):
unittest.skip("This is an assertion.")

def assertDist(self, dist, node_type):
Expand All @@ -22,7 +22,7 @@ def assertDistMixture(self, dist):
return self.assertEqual(len(dist), 3, msg=f"Error on {dist}")


class TestBaseNode(unittest.TestCase):
class TestBaseNode(MyTest):
def setUp(self):
np.random.seed(510)

Expand Down Expand Up @@ -137,19 +137,26 @@ def test_get_dist_mixture(self):

for i in range(-2, 0, 2):
dist = hybrid_bn.get_dist(mixture_gauss.name, pvals={"Node0": i})
DistributionAssertion().assertDistMixture(dist)
self.assertDistMixture(dist)

for i in range(-2, 0, 2):
for j in self.data[cond_mixture_gauss.disc_parents[0]].unique().tolist():
dist = hybrid_bn.get_dist(
cond_mixture_gauss.name, pvals={"Node0": float(i), "Node4": j}
)
DistributionAssertion().assertDistMixture(dist)
self.assertDistMixture(dist)

def test_get_dist(self):
for node in self.bn.nodes:
if not node.cont_parents + node.disc_parents:
dist = self.bn.get_dist(node.name)
if "mean" in dist.keys():
self.assertTrue(isinstance(dist["mean"], float))
self.assertTrue(isinstance(dist["variance"], float))
else:
self.assertAlmostEqual(sum(dist["cprob"]), 1)
continue

if len(node.cont_parents + node.disc_parents) == 1:
if node.disc_parents:
pvals = self.data[node.disc_parents[0]].unique().tolist()
Expand All @@ -160,19 +167,15 @@ def test_get_dist(self):

for pval in pvals:
dist = self.bn.get_dist(node.name, {parent: pval})
DistributionAssertion().assertDist(
dist, self.info["types"][node.name]
)
self.assertDist(dist, self.info["types"][node.name])
else:
for i in self.data[node.disc_parents[0]].unique().tolist():
for j in range(-5, 5, 1):
dist = self.bn.get_dist(
node.name,
{node.cont_parents[0]: float(j), node.disc_parents[0]: i},
)
DistributionAssertion().assertDist(
dist, self.info["types"][node.name]
)
self.assertDist(dist, self.info["types"][node.name])

# ???
def test_choose_serialization(self):
Expand Down

0 comments on commit a361562

Please sign in to comment.