Skip to content

Commit e55d4ad

Browse files
committed
test: Basic test of string keys in Crosscat model
1 parent 689fe74 commit e55d4ad

File tree

2 files changed

+100
-28
lines changed

2 files changed

+100
-28
lines changed
Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,68 @@
11
(ns gensql.inference.gpm.crosscat-test
2-
(:require [gensql.inference.test-models.crosscat :refer [model]]
3-
[clojure.test :refer [deftest is]]
2+
(:require [gensql.inference.test-models.crosscat :refer [model str-model]]
3+
[clojure.test :refer [deftest is testing]]
44
[gensql.inference.gpm :as gpm]))
55

66
(deftest simulate
7-
(let [sim-no-constraint (gpm/simulate model [:color :height :flip] {})
8-
sim-constraint-not-target (gpm/simulate model [:color] {:height 0.4})
9-
sim-constraints-are-target (gpm/simulate model [:color :height] {:color "blue" :height 0.4})]
10-
;; Simply checking that we generated all columns without error.
11-
(is (= #{:color :height :flip} (set (keys sim-no-constraint))))
12-
(is (= #{:color} (set (keys sim-constraint-not-target))))
13-
(is (= "blue" (:color sim-constraints-are-target)))
14-
(is (= 0.4 (:height sim-constraints-are-target)))))
7+
(testing "keyword"
8+
(let [sim-no-constraint (gpm/simulate model [:color :height :flip] {})
9+
sim-constraint-not-target (gpm/simulate model [:color] {:height 0.4})
10+
sim-constraints-are-target (gpm/simulate model [:color :height] {:color "blue" :height 0.4})]
11+
;; Simply checking that we generated all columns without error.
12+
(is (= #{:color :height :flip} (set (keys sim-no-constraint))))
13+
(is (= #{:color} (set (keys sim-constraint-not-target))))
14+
(is (= "blue" (:color sim-constraints-are-target)))
15+
(is (= 0.4 (:height sim-constraints-are-target)))))
16+
17+
(testing "string"
18+
(let [sim-no-constraint (gpm/simulate str-model ["color" "height" "flip"] {})
19+
sim-constraint-not-target (gpm/simulate str-model ["color"] {"height" 0.4})
20+
sim-constraints-are-target (gpm/simulate str-model ["color" "height"] {"color" "blue" "height" 0.4})]
21+
;; Simply checking that we generated all columns without error.
22+
(is (= #{"color" "height" "flip"} (set (keys sim-no-constraint))))
23+
(is (= #{"color"} (set (keys sim-constraint-not-target))))
24+
(is (= "blue" (get sim-constraints-are-target "color")))
25+
(is (= 0.4 (get sim-constraints-are-target "height"))))))
1526

1627
(deftest logpdf
17-
(let [no-constraints (gpm/logpdf model {:color "red" :height 4.0 :flip true} {})
18-
constraints-match-target (gpm/logpdf model {:color "red" :height 4.0} {:color "red" :height 4.0})
19-
mistmatch (gpm/logpdf model
20-
{:color "red" :height 4.0 :flip true} {:color "blue" :height 4.0 :flip true})
21-
match-subset (gpm/logpdf model
22-
{:color "red" :height 4.0} {:color "red" :flip true})
23-
nonmatch-subset (gpm/logpdf model
24-
{:height 4.0} {:color "red" :flip true})
25-
no-target (gpm/logpdf model
26-
{} {:color "red" :height 4.0 :flip true})
27-
fully-constrained-target (gpm/logpdf model
28-
{:color "red" :height 4.0} {:color "red" :height 4.0 :flip true})]
29-
(is (number? no-constraints))
30-
(is (= 0.0 constraints-match-target))
31-
(is (= ##-Inf mistmatch))
32-
(is (= match-subset nonmatch-subset))
33-
(is (= no-target fully-constrained-target 0.0))))
28+
(testing "keyword"
29+
(let [no-constraints (gpm/logpdf model {:color "red" :height 4.0 :flip true} {})
30+
constraints-match-target (gpm/logpdf model {:color "red" :height 4.0} {:color "red" :height 4.0})
31+
mistmatch (gpm/logpdf model
32+
{:color "red" :height 4.0 :flip true} {:color "blue" :height 4.0 :flip true})
33+
match-subset (gpm/logpdf model
34+
{:color "red" :height 4.0} {:color "red" :flip true})
35+
nonmatch-subset (gpm/logpdf model
36+
{:height 4.0} {:color "red" :flip true})
37+
no-target (gpm/logpdf model
38+
{} {:color "red" :height 4.0 :flip true})
39+
fully-constrained-target (gpm/logpdf model
40+
{:color "red" :height 4.0} {:color "red" :height 4.0 :flip true})]
41+
(is (number? no-constraints))
42+
(is (= 0.0 constraints-match-target))
43+
(is (= ##-Inf mistmatch))
44+
(is (= match-subset nonmatch-subset))
45+
(is (= no-target fully-constrained-target 0.0))))
46+
47+
(testing "string"
48+
(let [no-constraints (gpm/logpdf str-model {"color" "red" "height" 4.0 "flip" true} {})
49+
constraints-match-target (gpm/logpdf str-model {"color" "red" "height" 4.0} {"color" "red" "height" 4.0})
50+
mistmatch (gpm/logpdf str-model
51+
{"color" "red" "height" 4.0 "flip" true} {"color" "blue" "height" 4.0 "flip" true})
52+
match-subset (gpm/logpdf str-model
53+
{"color" "red" "height" 4.0} {"color" "red" "flip" true})
54+
nonmatch-subset (gpm/logpdf str-model
55+
{"height" 4.0} {"color" "red" "flip" true})
56+
no-target (gpm/logpdf str-model
57+
{} {"color" "red" "height" 4.0 "flip" true})
58+
fully-constrained-target (gpm/logpdf str-model
59+
{"color" "red" "height" 4.0} {"color" "red" "height" 4.0 "flip" true})]
60+
(is (number? no-constraints))
61+
(is (= 0.0 constraints-match-target))
62+
(is (= ##-Inf mistmatch))
63+
(is (= match-subset nonmatch-subset))
64+
(is (= no-target fully-constrained-target 0.0)))))
3465

3566
(deftest variables
36-
(is (= #{:color :height :flip} (gpm/variables model))))
67+
(is (= #{:color :height :flip} (gpm/variables model)))
68+
(is (= #{"color" "height" "flip"} (gpm/variables str-model))))

test/gensql/inference/test_models/crosscat.cljc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,43 @@
4141
4 :two
4242
5 :two}}}}]
4343
(xcat/construct-xcat-from-latents xcat-spec xcat-latents data {:options options})))
44+
45+
46+
(def str-data
47+
{0 {"color" "red" "height" 6 "flip" true}
48+
1 {"color" "red" "height" 6 "flip" true}
49+
2 {"color" "red" "height" 6 "flip" true}
50+
3 {"color" "red" "height" 4 "flip" false}
51+
4 {"color" "blue" "height" 4 "flip" false}
52+
5 {"color" "green" "height" 4 "flip" false}})
53+
54+
(def str-model
55+
(let [options {"color" ["red" "blue" "green"]}
56+
view-1-name (str (gensym))
57+
view-2-name (str (gensym))
58+
59+
xcat-spec {:views {view-1-name {:hypers {"color" {:alpha 2}
60+
"height" {:m 0 :r 1 :s 2 :nu 3}}}
61+
view-2-name {:hypers {"flip" {:alpha 1 :beta 1}}}}
62+
:types {"color" :categorical
63+
"height" :gaussian
64+
"flip" :bernoulli}}
65+
66+
xcat-latents {:global {:alpha 0.5}
67+
:local {view-1-name {:alpha 1
68+
:counts {"one" 4 "two" 2}
69+
:y {0 "one"
70+
1 "one"
71+
2 "one"
72+
3 "one"
73+
4 "two"
74+
5 "two"}}
75+
view-2-name {:alpha 1
76+
:counts {"one" 3 "two" 3}
77+
:y {0 "one"
78+
1 "one"
79+
2 "one"
80+
3 "two"
81+
4 "two"
82+
5 "two"}}}}]
83+
(xcat/construct-xcat-from-latents xcat-spec xcat-latents str-data {:options options})))

0 commit comments

Comments
 (0)