-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtest_cpd.py
50 lines (41 loc) · 2.06 KB
/
test_cpd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
import unittest
import numpy as np
import pytest
from pycid.core.cpd import ConstantCPD, StochasticFunctionCPD
from pycid.examples.simple_cids import get_minimal_cid
from pycid.examples.story_cids import get_introduced_bias
# TODO: add more tests for StochasticFunctionCPD (only deterministic functions are covered below) and add tests for DecisionDomain
class TestCPD(unittest.TestCase):
    """Tests for constructing ConstantCPD and StochasticFunctionCPD objects."""

    def test_initialize_uniform_random_cpd(self) -> None:
        """A ConstantCPD built from an empty dictionary is uniform over its domain."""
        cid = get_minimal_cid()

        cpd_a = ConstantCPD("A", {}, cid, [0, 2])
        # np.testing helpers verify the array shape as well as the values;
        # `(x == y).all()` would let broadcasting hide a shape mismatch.
        np.testing.assert_allclose(cpd_a.get_values(), np.array([[0.5], [0.5]]))
        # The second state (index 1) of A's domain [0, 2] is the value 2.
        self.assertEqual(cpd_a.get_state_names("A", 1), 2)

        cpd_b = ConstantCPD("B", {}, cid, [0, 1])
        # One column per configuration of B's parents in the minimal CID
        # (the expected table implies two such configurations).
        np.testing.assert_allclose(cpd_b.get_values(), np.array([[0.5, 0.5], [0.5, 0.5]]))

    def test_initialize_function_cpd(self) -> None:
        """A deterministic StochasticFunctionCPD infers a single-state domain."""
        cid = get_minimal_cid()

        cpd_a = StochasticFunctionCPD("A", lambda: 2, cid)
        # assert_array_equal checks the shape and reports a readable diff;
        # assertEqual on arrays only works by accident for 1x1 arrays.
        np.testing.assert_array_equal(cpd_a.get_values(), np.array([[1]]))
        self.assertEqual(cpd_a.get_cardinality(["A"])["A"], 1)
        self.assertEqual(cpd_a.get_state_names("A", 0), 2)

        # B's CPD is computed from A's CPD, so A's new CPD must be in the CID first.
        cid.add_cpds(cpd_a)
        cpd_b = StochasticFunctionCPD("B", lambda A: A, cid)
        np.testing.assert_array_equal(cpd_b.get_values(), np.array([[1]]))
        self.assertEqual(cpd_b.get_cardinality(["B"])["B"], 1)
        self.assertEqual(cpd_b.get_state_names("B", 0), 2)

    def test_updated_decision_names(self) -> None:
        """Imputing a new policy for decision D replaces D's original [0, 1] state names."""
        cid = get_introduced_bias()
        self.assertEqual(cid.get_cpds("D").state_names["D"], [0, 1])

        cid.impute_conditional_expectation_decision("D", "Y")
        self.assertNotEqual(cid.get_cpds("D").state_names["D"], [0, 1])

        cid.impute_random_policy()
        self.assertNotEqual(cid.get_cpds("D").state_names["D"], [0, 1])

        # TODO: Imputing an optimal policy after imputing a conditional-expectation
        # policy does not always work, possibly because the decision domain is then
        # real-valued. Re-enable these checks once that is resolved:
        # cid.impute_optimal_policy()
        # eu = cid.expected_utility({})
        # self.assertGreater(eu, -0.2)
if __name__ == "__main__":
    # pytest.main returns the session exit code; hand it to sys.exit so that
    # running this file directly reports test failures to the shell/CI
    # instead of always exiting with status 0.
    sys.exit(pytest.main(sys.argv))