Skip to content

Commit

Permalink
Began implementing Baum-Welch, but not finished
Browse files Browse the repository at this point in the history
Wait a second, those probabilities shouldn't sum to 1E-20... well,
better commit what I have and find the bugs tomorrow.
  • Loading branch information
dwysocki committed Mar 19, 2015
1 parent ced6db7 commit c8c830b
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 37 deletions.
184 changes: 147 additions & 37 deletions src/hidden_markov_music/hmm.clj
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
(ns hidden-markov-music.hmm
"General implementation of a hidden Markov model, and associated algorithms."
(:require [hidden-markov-music.stats :as stats]
[hidden-markov-music.random :refer [select-random-key]]))
[hidden-markov-music.random :refer [select-random-key]]
[hidden-markov-music.util :refer [map-for]]))

(defrecord HMM
[states
Expand Down Expand Up @@ -43,11 +44,10 @@
```"
[model obs]
;; map each state to its initial α
(zipmap (:states model)
(for [state (:states model)]
;; compute α_1 for the given state
(* (get-in model [:initial-prob state])
(get-in model [:observation-prob state obs])))))
(map-for [state (:states model)]
;; compute α_1 for the given state
(* (get-in model [:initial-prob state])
(get-in model [:observation-prob state obs]))))

(defn forward-probability-next
"Returns `α_t(i)`, for all states `i`, for `t > 1`, where `α_t(i)` is the
Expand All @@ -65,14 +65,13 @@
```"
[model obs alpha-prev]
;; map each state to its α
(zipmap (:states model)
(for [state (:states model)]
;; compute α_t for the given state
(* (get-in model [:observation-prob state obs])
(reduce +
(for [other-state (:states model)]
(* (get-in model [:transition-prob other-state state])
(alpha-prev other-state))))))))
(map-for [state (:states model)]
;; compute α_t for the given state
(* (get-in model [:observation-prob state obs])
(reduce +
(for [other-state (:states model)]
(* (get-in model [:transition-prob other-state state])
(alpha-prev other-state)))))))

(defn- forward-probability-helper
"Helper function for computing lazy seq of `α`'s.
Expand Down Expand Up @@ -135,16 +134,15 @@
```"
[model obs beta-next]
;; map each state to its β
(zipmap (:states model)
(for [state (:states model)]
;; compute β_t for the given state
(reduce +
(for [other-state (:states model)]
(* (get-in model
[:transition-prob state other-state])
(beta-next other-state)
(get-in model
[:observation-prob other-state obs])))))))
(map-for [state (:states model)]
;; compute β_t for the given state
(reduce +
(for [other-state (:states model)]
(* (get-in model
[:transition-prob state other-state])
(beta-next other-state)
(get-in model
[:observation-prob other-state obs]))))))

(defn- backward-probability-helper
"Helper function for computing lazy seq of `β`'s.
Expand Down Expand Up @@ -206,11 +204,10 @@
```"
[model obs]
{:delta
(zipmap (:states model)
(for [state (:states model)]
;; δ_1(i) = π(i)*b_i(o_1)
(* (get-in model [:initial-prob state])
(get-in model [:observation-prob state obs])))),
(map-for [state (:states model)]
;; δ_1(i) = π(i)*b_i(o_1)
(* (get-in model [:initial-prob state])
(get-in model [:observation-prob state obs]))),
;; initial state has no preceding states, so ψ_1(i) = nil
:psi nil})

Expand All @@ -228,13 +225,11 @@
(let [;; this is a mapping of
;; state-j -> state-i -> δ_{t-1}(i) p_{ij}
weighted-deltas
(zipmap (:states model)
(for [state (:states model)]
(zipmap (:states model)
(for [other-state (:states model)]
(* (get delta-prev other-state)
(get-in model [:transition-prob
other-state state]))))))
(map-for [state (:states model)
other-state (:states model)]
(* (get delta-prev other-state)
(get-in model [:transition-prob
other-state state])))
;; this is a mapping of
;; state-j -> [argmax(δ_{t-1}(i) p_{ij}),
;; max(δ_{t-1}(i) p_{ij})]
Expand Down Expand Up @@ -381,6 +376,121 @@
(map (partial random-emission model)
states)))

(defn gamma
  "Returns `γ_t(i)` for all states `i` at a single time step `t`, computed as

  ```
  γ_t(i) = α_t(i) β_t(i) / Σ_j α_t(j) β_t(j)
  ```

  `forward-prob` and `backward-prob` are the `α_t` and `β_t` lookups for that
  time step (called as functions of a state). The result maps each state in
  `(:states model)` to its `γ`."
  [model forward-prob backward-prob]
  (let [;; the normalizing sum does not depend on the state being mapped, so
        ;; compute it once instead of once per state
        normalizer (reduce +
                           (for [other-state (:states model)]
                             (* (backward-prob other-state)
                                (forward-prob other-state))))]
    (map-for [state (:states model)]
      (/ (* (forward-prob state)
            (backward-prob state))
         normalizer))))

(defn gamma-seq
  "Returns a lazy seq of `γ_t` mappings, one per time step, obtained by calling
  `gamma` on each pairing of an element of `forward-probs` with the element of
  `backward-probs` at the same position. Stops at the shorter of the two
  inputs."
  [model forward-probs backward-probs]
  (map (fn [alpha-t beta-t]
         (gamma model alpha-t beta-t))
       forward-probs
       backward-probs))


(defn digamma
  "Returns `ξ_t(i,j)` for all pairs of states `i`, `j` at a single time step,
  computed as

  ```
  ξ_t(i,j) = α_t(i) a_{ij} b_j(o_{t+1}) β_{t+1}(j)
             / Σ_i Σ_j α_t(i) a_{ij} b_j(o_{t+1}) β_{t+1}(j)
  ```

  `forward-prob` is the `α_t` lookup, `backward-prob-next` is the `β_{t+1}`
  lookup, and `observation-next` is `o_{t+1}`. The result is a nested mapping
  `state-current -> state-next -> ξ`."
  [model forward-prob backward-prob-next observation-next]
  (let [states (:states model)
        ;; unnormalized joint term for a single (i, j) pair; shared by the
        ;; normalizer and by each entry of the result
        joint (fn [state-i state-j]
                (* (forward-prob state-i)
                   (get-in model [:transition-prob
                                  state-i
                                  state-j])
                   (get-in model [:observation-prob
                                  state-j
                                  observation-next])
                   (backward-prob-next state-j)))
        ;; a two-binding `for` already yields a flat seq of numbers, so it
        ;; can be summed directly
        likelihood (reduce +
                           (for [state-i states
                                 state-j states]
                             (joint state-i state-j)))]
    (map-for [state-current states
              state-next    states]
      (/ (joint state-current state-next)
         likelihood))))

(defn digamma-seq
  "Returns a lazy seq of `ξ_t` mappings by calling `digamma` with each element
  of `forward-probs` paired with the *following* element of `backward-probs`
  and of `observations`.

  Because the latter two are advanced by one, the result is one element
  shorter than those inputs when all three have equal length."
  [model forward-probs backward-probs observations]
  (let [backward-probs-next (rest backward-probs)
        observations-next   (rest observations)]
    (map (fn [alpha-t beta-next obs-next]
           (digamma model alpha-t beta-next obs-next))
         forward-probs
         backward-probs-next
         observations-next)))

(defn- train-initial-probs
  "Returns the re-estimated initial state distribution, which is simply the
  first element of `gammas` (`nil` when `gammas` is empty)."
  [gammas]
  (-> gammas
      seq
      first))

(defn- train-transition-probs
  "Returns the re-estimated transition probabilities as a nested mapping
  `state-current -> state-next -> probability`, computed as

  ```
  a_{ij} = Σ_{t=1}^{T-1} ξ_t(i,j) / Σ_{t=1}^{T-1} γ_t(i)
  ```

  `gammas` is expected to hold one entry per time step (`T` entries), and
  `digammas` one entry per transition (`T-1` entries), as produced by
  `digamma-seq`."
  [model gammas digammas]
  (map-for [state-current (:states model)]
    (let [;; expected number of transitions out of state-current; the final
          ;; time step has no outgoing transition, hence `butlast`
          expected-transitions
          (->> gammas
               butlast
               (map #(get % state-current))
               (reduce +))]
      (map-for [state-next (:states model)]
        ;; `digammas` already covers only t = 1..T-1, so every entry is
        ;; summed; dropping its last element would omit the final transition
        ;; and leave each row summing to less than 1
        (/ (->> digammas
                (map #(get-in % [state-current state-next]))
                (reduce +))
           expected-transitions)))))

(defn- train-observation-probs
  "Returns the re-estimated observation probabilities as a nested mapping
  `state -> observation -> probability`, computed as

  ```
  b_j(k) = Σ_{t : o_t = k} γ_t(j) / Σ_{t=1}^{T} γ_t(j)
  ```

  `gammas` and `observations` are paired position by position."
  [model gammas observations]
  (map-for [state-current (:states model)]
    (let [;; expected number of time steps spent in state-current; unlike the
          ;; transition update this runs over *all* time steps, since every
          ;; time step emits an observation
          expected-visits (->> gammas
                               (map #(get % state-current))
                               (reduce +))]
      (map-for [obs (:observations model)]
        ;; expected number of times state-current emitted obs, normalized so
        ;; that each state's distribution sums to 1
        (/ (->> (map vector gammas observations)
                (filter (fn [[_ o]] (= o obs)))
                (map (fn [[g _]] (get g state-current)))
                (reduce +))
           expected-visits)))))

(defn- train-model-helper
  "Performs Baum-Welch re-estimation steps until the likelihood stops
  improving by more than `threshold`.

  Each step derives the `γ` and `ξ` sequences from the forward and backward
  probabilities of `model`, builds a new `HMM` with the same states and
  observations but re-estimated transition, observation and initial
  probabilities, and evaluates it with `likelihood-forward`. If the new
  likelihood exceeds `likelihood` (the likelihood of `model`) by more than
  `threshold`, the step repeats on the new model; otherwise the new model is
  returned."
  [model observations threshold likelihood]
  (let [forward-probs  (forward-probability-seq model observations)
        ;; NOTE(review): the `reverse` presumes `backward-probability-seq`
        ;; yields its elements last-time-step-first -- confirm against its
        ;; definition
        backward-probs (reverse (backward-probability-seq model observations))
        state-posteriors      (gamma-seq model forward-probs backward-probs)
        transition-posteriors (digamma-seq model forward-probs backward-probs
                                           observations)

        initial-probs'     (train-initial-probs state-posteriors)
        transition-probs'  (train-transition-probs model state-posteriors
                                                   transition-posteriors)
        observation-probs' (train-observation-probs model state-posteriors
                                                    observations)
        retrained (HMM.
                   (:states model)
                   (:observations model)
                   transition-probs'
                   observation-probs'
                   initial-probs')

        retrained-likelihood (likelihood-forward retrained observations)
        improvement          (- retrained-likelihood likelihood)]
    (if (> improvement threshold)
      ;; still improving: iterate again from the retrained model
      (recur retrained observations threshold retrained-likelihood)
      retrained)))

(defn train-model
  "Trains the model via the Baum-Welch algorithm.

  With two arguments, uses a likelihood-improvement threshold of `0.0`. With
  three, `threshold` is the minimum increase in likelihood between successive
  re-estimations required to keep iterating. The starting likelihood is that
  of `model` itself, as computed by `likelihood-forward`."
  ([model observations]
   (train-model model observations 0.0))
  ([model observations threshold]
   (train-model-helper model observations threshold
                       (likelihood-forward model observations))))

10 changes: 10 additions & 0 deletions test/hidden_markov_music/baum_welch_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
(ns hidden-markov-music.baum-welch-test
(:require [hidden-markov-music.hmm :as hmm]
[hidden-markov-music.test-models :as tm])
(:use clojure.test
clojure.pprint))

;; `testing` only adds context and must run inside a test function, so it is
;; wrapped in a `deftest` that the test runner can discover.
(deftest baum-welch-test
  (testing "Baum-Welch algorithm"
    (let [trained (hmm/train-model tm/ibe-ex-11-model
                                   tm/ibe-ex-11-observations
                                   0.0001)]
      ;; printed for manual inspection while the implementation is unfinished
      (pprint trained)
      ;; minimal check: training ran to completion and produced a model
      (is (some? trained)))))

0 comments on commit c8c830b

Please sign in to comment.