Binomial distribution #66

Merged 3 commits on Mar 12, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -19,4 +19,5 @@ pom.xml
!template/pom.xml
pom.xml.asc
node_modules
**.shadow-cljs
**.shadow-cljs
.#*
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default.

7 changes: 7 additions & 0 deletions src/gen/distribution/commons_math.clj
@@ -68,6 +68,10 @@
0 false
1 true)))))

(defn binomial-distribution
([n ^double p]
(BinomialDistribution. (rng) n p)))

(defn beta-distribution
([] (beta-distribution 1.0 1.0))
([^double alpha ^double beta]
@@ -129,6 +133,9 @@
(def bernoulli
(d/->GenerativeFn bernoulli-distribution 1))

(def binomial
(d/->GenerativeFn binomial-distribution 2))

(def beta
(d/->GenerativeFn beta-distribution 2))

16 changes: 16 additions & 0 deletions src/gen/distribution/kixi.cljc
@@ -4,6 +4,7 @@
[kixi.stats.distribution :as k])
#?(:clj
(:import (kixi.stats.distribution Bernoulli Cauchy
Binomial
Exponential Beta
Gamma Normal Uniform T))))

@@ -22,6 +23,14 @@
(logpdf [this v]
(ll/bernoulli (.-p this) v)))

(extend-type #?(:clj Binomial :cljs k/Binomial)
d/Sample
(sample [this] (k/draw this))

d/LogPDF
(logpdf [this v]
(ll/binomial (.-n this) (.-p this) v)))

(extend-type #?(:clj Beta :cljs k/Beta)
d/Sample
(sample [this] (k/draw this))
@@ -106,6 +115,10 @@
([] (bernoulli-distribution 0.5))
([p] (k/bernoulli {:p p})))

(defn binomial-distribution
([n p]
(k/binomial {:n n :p p})))

(defn beta-distribution
([] (beta-distribution 1.0 1.0))
([alpha beta]
@@ -143,6 +156,9 @@
(def bernoulli
(d/->GenerativeFn bernoulli-distribution 1))

(def binomial
(d/->GenerativeFn binomial-distribution 2))

(def beta
(d/->GenerativeFn beta-distribution 2))

69 changes: 34 additions & 35 deletions src/gen/distribution/math/log_likelihood.cljc
@@ -1,9 +1,6 @@
(ns gen.distribution.math.log-likelihood
"Log-likelihood implementations for various primitive distributions.")

;; ## Helpful constants
;;
;; These come in handy in the implementations below and are worth caching.
"Log-likelihood implementations for various primitive distributions."
(:require [kixi.stats.math :as k]))

(def ^:no-doc log-pi
(Math/log Math/PI))
@@ -14,39 +11,11 @@
(def ^:no-doc sqrt-2pi
(Math/sqrt (* 2 Math/PI)))

;; ## Log-likelihood implementations

(def ^:no-doc gamma-coefficients
"Coefficients for the Lanczos approximation to the natural log of the Gamma
function described in [section 6.1 of Numerical
Recipes](http://phys.uri.edu/nigh/NumRec/bookfpdf/f6-1.pdf)."
[76.18009172947146
-86.50532032941677
24.01409824083091
-1.231739572450155
0.1208650973866179e-2
-0.5395239384953e-5])

(defn ^:no-doc log-gamma-fn
"Returns the natural log of the value of the [Gamma
function](https://en.wikipedia.org/wiki/Gamma_function) evaluated at `x`

This function implements the Lanczos approximation described in [section 6.1
of Numerical Recipes](http://phys.uri.edu/nigh/NumRec/bookfpdf/f6-1.pdf)."
function](https://en.wikipedia.org/wiki/Gamma_function) evaluated at `x`"
[x]
(let [tmp (+ x 5.5)
tmp (- (* (+ x 0.5) (Math/log tmp)) tmp)
n (dec (count gamma-coefficients))
ser (loop [i 0
x+1 (inc x)
acc 1.000000000190015]
(if (> i n)
acc
(let [coef (nth gamma-coefficients i nil)]
(recur (inc i)
(inc x+1)
(+ acc (/ coef x+1))))))]
(+ tmp (Math/log (* sqrt-2pi (/ ser x))))))
(k/log-gamma x))
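
A minimal sketch of why a log-gamma function is enough for the log-factorials that the binomial log-likelihood later in this file needs: for a non-negative integer n, Γ(n + 1) = n!, so log n! equals log-gamma at n + 1. The helper names below (`log-fact-naive`, `log-bico-naive`) are hypothetical and not part of the PR. They sum logs directly so that the snippet runs without kixi, and they are only meant as a reference for small n.

```clojure
;; Hypothetical reference helpers, not part of the PR.
(defn log-fact-naive
  "log(n!) computed as log(1) + log(2) + ... + log(n).
  Mathematically this equals log-gamma evaluated at n + 1."
  [n]
  (reduce + 0.0 (map #(Math/log (double %)) (range 1 (inc n)))))

(defn log-bico-naive
  "Log of the binomial coefficient C(n, k) = n! / (k! (n - k)!)."
  [n k]
  (- (log-fact-naive n)
     (log-fact-naive k)
     (log-fact-naive (- n k))))

;; C(5, 2) = 10, so exponentiating should give a value close to 10.
(Math/exp (log-bico-naive 5 2))
```

The naive sum is O(n) per call, which is presumably why a closed-form log-gamma approximation is preferred in library code.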

(defn gamma
"Returns the log-likelihood of the [Gamma
@@ -97,6 +66,36 @@
{:pre [(<= 0 p 1)]}
(Math/log (if v p (- 1.0 p))))

(defn binomial
"Returns the log-likelihood of a [Binomial
distribution](https://en.wikipedia.org/wiki/Binomial_distribution)
parameterized by `n` (number of trials) and `p` (probability of success in
each trial) at the value `v` (number of successes)."
[n p v]
{:pre [(integer? n)
Collaborator

I think in ClojureScript, number? checks if you have a goog.math/Integer... is that right? I want to check on how we did this in Emmy; I feel like there is some other predicate we should use. I could be wrong, flying blind here before a meeting.

Contributor Author

Yeah, not sure, but it doesn't look like cljs checks for goog.math/Integer:

https://github.com/clojure/clojurescript/blob/acbefb9b1e79b659b639919fbf96cb3726719e25/src/main/cljs/cljs/core.cljs#L2328-L2334

In Emmy I don't see an integer? defined, but the value namespace does have predicates like real? and number? that do use goog.math/...:

https://github.com/mentat-collective/emmy/blob/main/src/emmy/value.cljc

But if I understand the compatibility layer between clj and cljs, if we use integer? in Clojure code and then compile it to cljs, the cljs version of integer? should be used when the code runs in a JS environment. If that's the case, then it's plausible that this is the right predicate to use?

(integer? v)
(>= v 0)
(>= n v)
(<= 0 p 1)]}
(letfn [(log-fact
[x]
(log-gamma-fn (inc x)))
(log-bico
[n k]
(if (or (zero? k) (= k n))
0 ;; log(1)
(- (log-fact n) (log-fact k) (log-fact (- n k)))))]
(case p
0 (if (= v 0)
0.0 ;; log(1)
##-Inf) ;; log(0))
1 (if (= v n)
0.0 ;; log(1)
##-Inf) ;; log(0)
(+ (log-bico n v)
(* v (Math/log p))
(* (- n v) (Math/log (- 1 p)))))))
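
One point about the `case p` dispatch above that may be worth checking: on the JVM, `case` matches numeric constants with type-aware equality, so the integer literals `0` and `1` would not match the doubles `0.0` and `1.0`. The sketch below is an illustration under that assumption, not the PR's code; `boundary-via-case` and `boundary-via-cond` are hypothetical names. In ClojureScript all of these values are plain JS numbers, so the behavior may differ there.

```clojure
;; Hypothetical helpers that mirror only the shape of the dispatch on `p`.
(defn boundary-via-case [p]
  (case p
    0 :p-is-zero
    1 :p-is-one
    :interior))

(defn boundary-via-cond [p]
  (cond
    (zero? p) :p-is-zero   ; numeric test, true for 0 and 0.0
    (== 1 p)  :p-is-one    ; `==` compares across numeric types
    :else     :interior))

(boundary-via-case 0)    ; the long 0 matches the `0` branch
(boundary-via-case 0.0)  ; on the JVM the double falls through to the default
(boundary-via-cond 0.0)  ; the numeric predicates treat 0 and 0.0 alike

;; If the default branch runs with p = 0.0 and v = 0, the term
;; (* v (Math/log p)) becomes 0 times negative infinity, which is NaN.
(Double/isNaN (* 0 (Math/log 0.0)))
```

The scipy and Gen.jl spot checks in this PR pass `p` as the integers `0` and `1`, so they would not exercise the double case.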

(defn cauchy
"Returns the log-likelihood of a [Cauchy
distribution](https://en.wikipedia.org/wiki/Cauchy_distribution) parameterized
4 changes: 4 additions & 0 deletions test/gen/distribution/commons_math_test.clj
@@ -7,6 +7,10 @@
(dt/bernoulli-tests commons/bernoulli-distribution)
(dt/bernoulli-gfi-tests commons/bernoulli))

(deftest binomial-tests
(dt/binomial-tests commons/binomial-distribution)
(dt/binomial-gf-tests commons/binomial))

(deftest beta-tests
(dt/beta-tests commons/beta-distribution))

4 changes: 4 additions & 0 deletions test/gen/distribution/kixi_test.cljc
@@ -7,6 +7,10 @@
(dt/bernoulli-tests kixi/bernoulli-distribution)
(dt/bernoulli-gfi-tests kixi/bernoulli))

(deftest binomial-tests
(dt/binomial-tests kixi/binomial-distribution)
(dt/binomial-gf-tests kixi/binomial))

(deftest beta-tests
(dt/beta-tests kixi/beta-distribution))

3 changes: 3 additions & 0 deletions test/gen/distribution/math/log_likelihood_test.cljc
@@ -46,6 +46,9 @@
(deftest bernoulli-tests
(dt/bernoulli-tests (->logpdf ll/bernoulli)))

(deftest binomial-tests
(dt/binomial-tests (->logpdf ll/binomial)))

(deftest cauchy-tests
(dt/cauchy-tests (->logpdf ll/cauchy)))

72 changes: 72 additions & 0 deletions test/gen/distribution_test.cljc
@@ -59,6 +59,78 @@
(Math/exp (dist/logpdf (->bernoulli p) (not v)))))
"All options sum to 1")))

(defn binomial-gf-tests [->binomial-gf]
(checking "spot check gf score implementations"
[n (gen/choose 0 10000)
p (gen-double 0.11111 0.99999)]
(let [trace (gf/simulate ->binomial-gf [n p])
sample (trace/get-retval trace)]
(is (<= sample n)))))

(defn binomial-tests [->binomial]
;; boundaries...
(testing "when p = 0 and v = 0, probability is 1, log(1) = 0"
(is 0 (dist/logpdf (->binomial 10 0) 0)))

(testing "when p = 0 and v > 0, probability is 0, log(0) = -Inf"
(is ##-Inf (dist/logpdf (->binomial 10 0) 1)))

(testing "when p = 1 and v = n, probability is 1, log(1) = 0"
(is 0 (dist/logpdf (->binomial 10 1) 10)))

(testing "when p = 1 and v < n, probability is 0, log(0) = -Inf"
(is ##-Inf(dist/logpdf (->binomial 10 0) 1)))
Collaborator

whoops, missing a space

Contributor Author

fixed. should bb lint normally catch that?

Collaborator

looks like it's still there, I'm not sure, maybe it should?


;; properties...
(checking "sum of probabilities equals 1"
[n (gen/choose 0 10000)
p (gen-double 0.11111 0.99999)]
(let [log-probs (map (fn [k] (dist/logpdf (->binomial n p) k)) (range 0 (inc n)))
probs (map (fn [x] (Math/exp x)) log-probs)
sum-probs (reduce + probs)]
(with-comparator (within 1e-9)
(is (ish? 1.0 sum-probs)))))

;; A binomial distribution is symmetrical if the probability of observing $k$
;; successes in $n$ trials is the same as observing $n - k$ successes, which
;; should be true when $p = 0.5$.
(checking "symmetrical shape when $p = 0.5$"
[n (gen/choose 0 10000)]
(with-comparator (within 1e-9)
(let [p 0.5
ks (range 0 (inc n))
k (map (fn [k] (dist/logpdf (->binomial n p) k)) ks)
n-k (map (fn [k] (dist/logpdf (->binomial n p) (- n k))) ks)]
(is (ish? k n-k)))))

(testing "spot check against scipy.stats.binom.logpmf (v1.12.0)"
(with-comparator (within 1e-12)
(let [scipy-data [[5 0.2 5 -8.047189562170502]
[50 0.99 49 -1.1856136373815076]
[50 0.01 1 -1.185613637381508]
[10 0 0 0]
[10 1 10 0]
[100 0.9 90 -2.02597397686619]
[500 0.1 0 -52.680257828913156]]]
(doseq [[n p v expected] scipy-data]
(let [actual (dist/logpdf (->binomial n p) v)]
(is (ish? expected actual)
(str "n=" n ", p=" p ", v=" v)))))))

(testing "spot check against gen.jl logpdf (v0.4.6)"
(with-comparator (within 1e-12)
(let [gen-data [[5 0.2 5 -8.047189562170502]
[50 0.99 49 -1.185613637381516]
[50 0.01 1 -1.1856136373815152]
[10 0 0 0]
[10 1 10 0]
[100 0.9 90 -2.025973976866184]
[500 0.1 0 -52.680257828913156]]]
(doseq [[n p v expected] gen-data]
(let [actual (dist/logpdf (->binomial n p) v)]
(is (ish? expected actual)
(str "n=" n ", p=" p ", v=" v))))))))
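
A note on the boundary assertions at the top of `binomial-tests`: `clojure.test/is` takes `(is form msg)`, so in `(is 0 (dist/logpdf ...))` the constant `0` is the asserted form and the `logpdf` call is only a message. The last boundary case also appears to build the distribution with `p = 0` even though its description says `p = 1`. The sketch below shows the difference; `logpdf-stub` is a hypothetical stand-in for the real `dist/logpdf` call, and the rewritten assertions are a suggestion rather than what was merged.

```clojure
(require '[clojure.test :refer [is]])

;; Hypothetical stand-in for (dist/logpdf (->binomial 10 0) 0).
(defn logpdf-stub [] 0.0)

;; `0` is the asserted form and is truthy, so this passes no matter
;; what the second argument evaluates to.
(is 0 (logpdf-stub))

;; Wrapping the comparison in a predicate makes the value count.
;; `==` compares numerically, so a long 0 and a double 0.0 agree;
;; `=` does not treat 0 and 0.0 as equal on the JVM.
(is (== 0 (logpdf-stub)))

;; Negative infinity can be compared with `=` when both sides are doubles.
(is (= ##-Inf (Math/log 0.0)))
```

With assertions in this shape, a wrong boundary value would fail the test instead of passing silently.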

(defn categorical-tests [->cat]
(checking "map => categorical properties"
[p (gen-double 0 1)]