diff --git a/src/gen/distribution/commons_math.clj b/src/gen/distribution/commons_math.clj
index 6d0c0df..2808fef 100644
--- a/src/gen/distribution/commons_math.clj
+++ b/src/gen/distribution/commons_math.clj
@@ -96,12 +96,44 @@
 (defn uniform-discrete-distribution [low high]
   (UniformIntegerDistribution. (rng) low high))
 
-(defn categorical-distribution [probabilities]
+(defn- v->categorical
+  "Builds an `EnumeratedIntegerDistribution` over the integer outcomes `0`
+  through `n - 1`, where `n == (count probabilities)`, pairing outcome `i` with
+  the `i`th entry of the sequence `probabilities`."
+  [probabilities]
   (let [n  (count probabilities)
         ks (int-array (range n))
         vs (double-array probabilities)]
     (EnumeratedIntegerDistribution. (rng) ks vs)))
 
+(defn- m->categorical
+  "Builds a categorical distribution from a map of key => probability.
+
+  The keys are numbered in the order returned by `keys`; `vals` returns the
+  probabilities in that same order, so index `i` lines up with the `i`th key.
+  The integer-valued distribution built from those values is then wrapped with
+  `d/->Encoded`, which receives the key => index and index => key lookup maps."
+  [probabilities]
+  (let [ks   (keys probabilities)
+        vs   (vals probabilities)
+        k->i (zipmap ks (range))
+        i->k (zipmap (range) ks)]
+    (-> (v->categorical vs)
+        (d/->Encoded k->i i->k))))
+
+(defn categorical-distribution
+  "Given either
+
+  - a sequence of `probabilities` that sum to 1.0
+  - a map of object => probability (whose values sum to 1.0)
+
+  returns a distribution that produces samples of an integer in the range $[0,
+  n)$ (where `n == (count probabilities)`), or of a map key (for map-shaped
+  `probabilities`)."
+  [probabilities]
+  (if (map? probabilities)
+    (m->categorical probabilities)
+    (v->categorical probabilities)))
+
 ;; ## Primitive generative functions
 
 (def bernoulli
diff --git a/test/gen/distribution/commons_math_test.clj b/test/gen/distribution/commons_math_test.clj
index 6abacd9..8068278 100644
--- a/test/gen/distribution/commons_math_test.clj
+++ b/test/gen/distribution/commons_math_test.clj
@@ -10,6 +10,9 @@
 (deftest beta-tests
   (dt/beta-tests commons/beta-distribution))
 
+(deftest categorical-tests
+  (dt/categorical-tests commons/categorical-distribution))
+
 (deftest uniform-tests
   (dt/uniform-tests commons/uniform-distribution))
 
diff --git a/test/gen/distribution_test.cljc b/test/gen/distribution_test.cljc
index 8938ba5..f90a3e7 100644
--- a/test/gen/distribution_test.cljc
+++ b/test/gen/distribution_test.cljc
@@ -46,6 +46,29 @@
                   (Math/exp (dist/logpdf (->bernoulli p) (not v)))))
         "All options sum to 1")))
 
+(defn categorical-tests
+  "Property checks for a categorical-distribution constructor `->cat`: for both
+  map-shaped and vector-shaped `probabilities`, the `logpdf` of each outcome
+  must match the log of the probability that outcome was given."
+  [->cat]
+  (checking "map => categorical properties"
+            [p (gen-double 0 1)]
+            (let [dist (->cat {:true p :false (- 1 p)})]
+              (is (ish? (Math/log p) (dist/logpdf dist :true))
+                  "prob of `:true` matches `p`")
+
+              (is (ish? (Math/log (- 1 p)) (dist/logpdf dist :false))
+                  "prob of `:false` matches `1-p`")))
+
+  (checking "vector => categorical properties"
+            [p (gen-double 0 1)]
+            (let [dist (->cat [p (- 1 p)])]
+              (is (ish? (Math/log p) (dist/logpdf dist 0))
+                  "prob of `0` matches `p`")
+
+              (is (ish? (Math/log (- 1 p)) (dist/logpdf dist 1))
+                  "prob of `1` matches `1-p`"))))
+
 (defn bernoulli-gfi-tests [bernoulli-dist]
   (testing "bernoulli-call-no-args"
     (is (boolean? (bernoulli-dist))))