Skip to content

Commit

Permalink
Implemented logarithmic Baum-Welch and unit tested
Browse files Browse the repository at this point in the history
It produces _almost_ the same results as Emilio Frazzoli's worked
example. There are several possible sources of discrepancy. For one, he
uses standard probabilities in place of logarithmic ones. He also sets a
fixed number of iterations at 20, which I can add as an option.
  • Loading branch information
dwysocki committed Apr 6, 2015
1 parent 61a3836 commit 1bb7951
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 33 deletions.
2 changes: 1 addition & 1 deletion project.clj
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
(defproject hidden-markov-music "0.1.2"
(defproject hidden-markov-music "0.1.3-SNAPSHOT"
:description "Generate original musical scores by means of a hidden Markov
model."
:url "https://github.com/dwysocki/hidden-markov-music"
Expand Down
97 changes: 70 additions & 27 deletions src/hidden_markov_music/hmm.clj
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
(map-vals (partial map-vals exp) (:observation-prob model))))

(defmulti hmms-almost-equal?
"Returns true if two HMMs are equal to the given precision"
"Returns true if two HMMs are equal to the given precision."
(fn [x y & {:keys [decimal] :or {decimal 6}}]
[(class x) (class y)]))

Expand Down Expand Up @@ -83,6 +83,29 @@
[x y & args]
false)

(defmulti valid-hmm?
"Returns true if the HMM has all stochastic probabilities to the given
precision."
model-class)

(defmethod valid-hmm? HMM
[model & {:keys [decimal] :or {decimal 10}}]
(and (stats/stochastic-map? (:initial-prob model)
:decimal decimal)
(stats/row-stochastic-map? (:transition-prob model)
:decimal decimal)
(stats/row-stochastic-map? (:observation-prob model)
:decimal decimal)))

(defmethod valid-hmm? LogHMM
[model & {:keys [decimal] :or {decimal 10}}]
(and (stats/log-stochastic-map? (:initial-prob model)
:decimal decimal)
(stats/log-row-stochastic-map? (:transition-prob model)
:decimal decimal)
(stats/log-row-stochastic-map? (:observation-prob model)
:decimal decimal)))

(defn random-HMM
"Returns an HMM with random probabilities, given the state and observation
labels."
Expand Down Expand Up @@ -662,16 +685,29 @@
(defmethod train-transition-probs HMM
[model gammas digammas]
(map-for [state-current (:states model)]
(let [expected-transitions
(->> gammas
butlast
(map #(get % state-current))
(reduce +))]
(map-for [state-next (:states model)]
(/ (->> digammas
(map #(get-in % [state-current state-next]))
(reduce +))
expected-transitions)))))
(let [expected-transitions
(->> gammas
butlast
(map #(get % state-current))
(reduce +))]
(map-for [state-next (:states model)]
(/ (->> digammas
(map #(get-in % [state-current state-next]))
(reduce +))
expected-transitions)))))

(defmethod train-transition-probs LogHMM
[model gammas digammas]
(map-for [state-i (:states model)
state-j (:states model)]
(let [[numerator denominator]
(reduce (fn [[num denom] [gamma digamma]]
[(log-sum num (get-in digamma [state-i state-j])),
(log-sum denom (get gamma state-i))])
[Double/NEGATIVE_INFINITY Double/NEGATIVE_INFINITY]
(map vector
gammas digammas))]
(- numerator denominator))))

(defmulti ^:private train-observation-probs
"Returns an updated observation probability matrix given the gammas computed
Expand All @@ -690,11 +726,25 @@
(map (fn [[g o]] (g state-current)))
(reduce +))))))

(defmulti ^:private train-model-helper
""
model-class)

(defmethod train-model-helper HMM
(defmethod train-observation-probs LogHMM
[model gammas observations]
(map-for [state (:states model)
obs (:observations model)]
(let [[numerator denominator]
(reduce (fn [[num denom] [gamma o]]
[;; update numerator if observation at time t is obs
(if (= obs o)
(log-sum num (get gamma state))
num)
;; update denominator
(log-sum denom (get gamma state))])
[Double/NEGATIVE_INFINITY Double/NEGATIVE_INFINITY]
(map vector
gammas
observations))]
(- numerator denominator))))

(defn train-model-helper
[model observations threshold likelihood]
(let [alphas (forward-probability-seq model observations)
betas (reverse (backward-probability-seq model observations))
Expand All @@ -705,23 +755,16 @@
new-transition-probs (train-transition-probs model gammas digammas)
new-observation-probs (train-observation-probs model gammas
observations)
new-model (HMM.
(:states model)
(:observations model)
new-initial-probs
new-transition-probs
new-observation-probs)

new-model (assoc model
:initial-prob new-initial-probs
:transition-prob new-transition-probs
:observation-prob new-observation-probs)
new-likelihood (likelihood-forward new-model observations)]
(if (> (- new-likelihood likelihood)
threshold)
(recur new-model observations threshold new-likelihood)
new-model)))

(defmethod train-model-helper LogHMM
[model observations threshold likelihood]
)

(defn train-model
"Trains the model via the Baum-Welch algorithm."
([model observations]
Expand Down
32 changes: 31 additions & 1 deletion src/hidden_markov_music/stats.clj
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
(ns hidden-markov-music.stats
"General statistical functions."
(:require [hidden-markov-music.util :refer [map-for]]))
(:require [hidden-markov-music.math :refer [log exp log-sum]]
[hidden-markov-music.util :refer [map-for
numbers-almost-equal?]]))

(defn normalize
"Normalizes a sequence."
Expand Down Expand Up @@ -43,3 +45,31 @@
(map-for [r row-keys]
(let [col-probs (random-stochastic-vector n-cols)]
(zipmap col-keys col-probs)))))

(defn stochastic-map?
"Returns true if the map is stochastic to the given precision."
[m & {:keys [decimal] :or {decimal 10}}]
(numbers-almost-equal? (reduce + (vals m))
1.0
:decimal decimal))

(defn log-stochastic-map?
"Returns true if the map is logarithmically stochastic to the given
precision."
[m & {:keys [decimal] :or {decimal 10}}]
(numbers-almost-equal? (reduce log-sum (vals m))
0.0
:decimal decimal))

(defn row-stochastic-map?
"Returns true if the map is row stochastic to the given precision."
[m & {:keys [decimal] :or {decimal 10}}]
(every? #(stochastic-map? % :decimal decimal)
(vals m)))

(defn log-row-stochastic-map?
"Returns true if the map is logarithmically stochastic to the given
precision."
[m & {:keys [decimal] :or {decimal 10}}]
(every? #(log-stochastic-map? % :decimal decimal)
(vals m)))
17 changes: 13 additions & 4 deletions test/hidden_markov_music/baum_welch_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
(:use clojure.test
clojure.pprint))

(testing "Baum-Welch algorithm"
(hmm/train-model tm/ibe-ex-11-model
[:good :good :so-so :bad :bad :good :bad :so-so]
0.0001))
(deftest log-baum-welch-algorithm-test
(testing "logarithmic Baum-Welch algorithm"
(testing "with Oliver Ibe's Example 11"
(is (hmm/valid-hmm?
(hmm/train-model tm/ibe-ex-11-log-model
[:good :good :so-so :bad :bad :good :bad :so-so]
0.00001))))
(testing "with Emilio Frazzoli's Baum-Welch example"
(pprint
(hmm/LogHMM->HMM
(hmm/train-model tm/frazzoli-ex-log-model
tm/frazzoli-ex-observations
0.0001))))))
14 changes: 14 additions & 0 deletions test/hidden_markov_music/stats_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
(ns hidden-markov-music.stats-test
(:require [hidden-markov-music.stats :as stats])
(:use clojure.test))

(deftest stochastic-map-test
(testing "random stochastic map is stochastic"
(dotimes [_ 10]
(is (stats/stochastic-map?
(stats/random-stochastic-map [:a :b :c :d :e])))))
(testing "random row-stochastic map is row-stochastic"
(dotimes [_ 10]
(is (stats/row-stochastic-map?
(stats/random-row-stochastic-map [:a :b :c :d :e]
[:A :B :C :D :E]))))))
72 changes: 72 additions & 0 deletions test/hidden_markov_music/test_models.clj
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,78 @@
#{[:sunny :sunny :sunny :rainy :rainy]
[:sunny :sunny :cloudy :rainy :rainy]})


;; models taken from Larry Moss' Baum-Welch examples
(def moss-ex-1-model
(HMM. [:s :t]

[:A :B]

{:s 0.85,
:t 0.15}

{:s {:s 0.3,
:t 0.7},
:t {:s 0.1,
:t 0.9}},

{:s {:A 0.4,
:B 0.6},
:t {:A 0.5,
:B 0.5}}))

(def moss-ex-1-log-model
(hmm/HMM->LogHMM moss-ex-1-model))


;; model taken from Emilio Frazzoli's Baum-Welch example
(def frazzoli-ex-model
(HMM. [:LA :NY]

[:LA :NY :null]

{:LA 0.5,
:NY 0.5}

{:LA {:LA 0.5,
:NY 0.5},
:NY {:LA 0.5,
:NY 0.5}}

{:LA {:LA 0.4,
:NY 0.1,
:null 0.5},
:NY {:LA 0.1,
:NY 0.5,
:null 0.4}}))

(def frazzoli-ex-log-model
(hmm/HMM->LogHMM frazzoli-ex-model))

(def frazzoli-ex-observations
[:null :LA :LA :null :NY :null :NY :NY :NY :null
:NY :NY :NY :NY :NY :null :null :LA :LA :NY ])

(def frazzoli-ex-trained-model
(HMM. [:LA :NY]

[:LA :NY :null]

{:LA 1.0,
:NY 0.0}

{:LA {:LA 0.6909
:NY 0.3091},
:NY {:LA 0.0934
:NY 0.9066}}

{:LA {:LA 0.5807
:NY 0.0010
:null 0.4183},
:NY {:LA 0.0000
:NY 0.7621
:null 0.2379}}))

;; fully deterministic HMM, whose states must be
;; :A -> :B -> :C -> :A -> ...
;; and whose emissions must be
Expand Down

0 comments on commit 1bb7951

Please sign in to comment.