From 1bb7951c4bba3c9eb31dbe466ffcc79da59ceb8d Mon Sep 17 00:00:00 2001 From: Dan Wysocki Date: Sun, 5 Apr 2015 23:49:32 -0400 Subject: [PATCH] Implemented logarithmic Baum-Welch and unit tested It produces _almost_ the same results as Emilio Frazzoli's worked example. There are several possible sources of discrepancy. For one, he uses standard probabilities in place of logarithmic ones. He also sets a fixed number of iterations at 20, which I can add as an option. --- project.clj | 2 +- src/hidden_markov_music/hmm.clj | 97 ++++++++++++++------ src/hidden_markov_music/stats.clj | 32 ++++++- test/hidden_markov_music/baum_welch_test.clj | 17 +++- test/hidden_markov_music/stats_test.clj | 14 +++ test/hidden_markov_music/test_models.clj | 72 +++++++++++++++ 6 files changed, 201 insertions(+), 33 deletions(-) create mode 100644 test/hidden_markov_music/stats_test.clj diff --git a/project.clj b/project.clj index bdfded9..39573bc 100644 --- a/project.clj +++ b/project.clj @@ -1,4 +1,4 @@ -(defproject hidden-markov-music "0.1.2" +(defproject hidden-markov-music "0.1.3-SNAPSHOT" :description "Generate original musical scores by means of a hidden Markov model." :url "https://github.com/dwysocki/hidden-markov-music" diff --git a/src/hidden_markov_music/hmm.clj b/src/hidden_markov_music/hmm.clj index 925b911..c7b1b5c 100644 --- a/src/hidden_markov_music/hmm.clj +++ b/src/hidden_markov_music/hmm.clj @@ -55,7 +55,7 @@ (map-vals (partial map-vals exp) (:observation-prob model)))) (defmulti hmms-almost-equal? - "Returns true if two HMMs are equal to the given precision" + "Returns true if two HMMs are equal to the given precision." (fn [x y & {:keys [decimal] :or {decimal 6}}] [(class x) (class y)])) @@ -83,6 +83,29 @@ [x y & args] false) +(defmulti valid-hmm? + "Returns true if the HMM has all stochastic probabilities to the given + precision." + model-class) + +(defmethod valid-hmm? HMM + [model & {:keys [decimal] :or {decimal 10}}] + (and (stats/stochastic-map? 
(:initial-prob model) + :decimal decimal) + (stats/row-stochastic-map? (:transition-prob model) + :decimal decimal) + (stats/row-stochastic-map? (:observation-prob model) + :decimal decimal))) + +(defmethod valid-hmm? LogHMM + [model & {:keys [decimal] :or {decimal 10}}] + (and (stats/log-stochastic-map? (:initial-prob model) + :decimal decimal) + (stats/log-row-stochastic-map? (:transition-prob model) + :decimal decimal) + (stats/log-row-stochastic-map? (:observation-prob model) + :decimal decimal))) + (defn random-HMM "Returns an HMM with random probabilities, given the state and observation labels." @@ -662,16 +685,29 @@ (defmethod train-transition-probs HMM [model gammas digammas] (map-for [state-current (:states model)] - (let [expected-transitions - (->> gammas - butlast - (map #(get % state-current)) - (reduce +))] - (map-for [state-next (:states model)] - (/ (->> digammas - (map #(get-in % [state-current state-next])) - (reduce +)) - expected-transitions))))) + (let [expected-transitions + (->> gammas + butlast + (map #(get % state-current)) + (reduce +))] + (map-for [state-next (:states model)] + (/ (->> digammas + (map #(get-in % [state-current state-next])) + (reduce +)) + expected-transitions))))) + +(defmethod train-transition-probs LogHMM + [model gammas digammas] + (map-for [state-i (:states model) + state-j (:states model)] + (let [[numerator denominator] + (reduce (fn [[num denom] [gamma digamma]] + [(log-sum num (get-in digamma [state-i state-j])), + (log-sum denom (get gamma state-i))]) + [Double/NEGATIVE_INFINITY Double/NEGATIVE_INFINITY] + (map vector + gammas digammas))] + (- numerator denominator)))) (defmulti ^:private train-observation-probs "Returns an updated observation probability matrix given the gammas computed @@ -690,11 +726,25 @@ (map (fn [[g o]] (g state-current))) (reduce +)))))) -(defmulti ^:private train-model-helper - "" - model-class) - -(defmethod train-model-helper HMM +(defmethod train-observation-probs LogHMM + [model 
gammas observations] + (map-for [state (:states model) + obs (:observations model)] + (let [[numerator denominator] + (reduce (fn [[num denom] [gamma o]] + [;; update numerator if observation at time t is obs + (if (= obs o) + (log-sum num (get gamma state)) + num) + ;; update denominator + (log-sum denom (get gamma state))]) + [Double/NEGATIVE_INFINITY Double/NEGATIVE_INFINITY] + (map vector + gammas + observations))] + (- numerator denominator)))) + +(defn train-model-helper [model observations threshold likelihood] (let [alphas (forward-probability-seq model observations) betas (reverse (backward-probability-seq model observations)) @@ -705,23 +755,16 @@ new-transition-probs (train-transition-probs model gammas digammas) new-observation-probs (train-observation-probs model gammas observations) - new-model (HMM. - (:states model) - (:observations model) - new-initial-probs - new-transition-probs - new-observation-probs) - + new-model (assoc model + :initial-prob new-initial-probs + :transition-prob new-transition-probs + :observation-prob new-observation-probs) new-likelihood (likelihood-forward new-model observations)] (if (> (- new-likelihood likelihood) threshold) (recur new-model observations threshold new-likelihood) new-model))) -(defmethod train-model-helper LogHMM - [model observations threshold likelihood] - ) - (defn train-model "Trains the model via the Baum-Welch algorithm." ([model observations] diff --git a/src/hidden_markov_music/stats.clj b/src/hidden_markov_music/stats.clj index 33afb4e..9f868ac 100644 --- a/src/hidden_markov_music/stats.clj +++ b/src/hidden_markov_music/stats.clj @@ -1,6 +1,8 @@ (ns hidden-markov-music.stats "General statistical functions." - (:require [hidden-markov-music.util :refer [map-for]])) + (:require [hidden-markov-music.math :refer [log exp log-sum]] + [hidden-markov-music.util :refer [map-for + numbers-almost-equal?]])) (defn normalize "Normalizes a sequence." 
@@ -43,3 +45,31 @@ (map-for [r row-keys] (let [col-probs (random-stochastic-vector n-cols)] (zipmap col-keys col-probs))))) + +(defn stochastic-map? + "Returns true if the map is stochastic to the given precision." + [m & {:keys [decimal] :or {decimal 10}}] + (numbers-almost-equal? (reduce + (vals m)) + 1.0 + :decimal decimal)) + +(defn log-stochastic-map? + "Returns true if the map is logarithmically stochastic to the given + precision." + [m & {:keys [decimal] :or {decimal 10}}] + (numbers-almost-equal? (reduce log-sum (vals m)) + 0.0 + :decimal decimal)) + +(defn row-stochastic-map? + "Returns true if the map is row stochastic to the given precision." + [m & {:keys [decimal] :or {decimal 10}}] + (every? #(stochastic-map? % :decimal decimal) + (vals m))) + +(defn log-row-stochastic-map? + "Returns true if the map is logarithmically row stochastic to the given + precision." + [m & {:keys [decimal] :or {decimal 10}}] + (every? #(log-stochastic-map? % :decimal decimal) + (vals m))) diff --git a/test/hidden_markov_music/baum_welch_test.clj b/test/hidden_markov_music/baum_welch_test.clj index fe72fb5..3bbaafb 100644 --- a/test/hidden_markov_music/baum_welch_test.clj +++ b/test/hidden_markov_music/baum_welch_test.clj @@ -4,7 +4,16 @@ (:use clojure.test clojure.pprint)) -(testing "Baum-Welch algorithm" - (hmm/train-model tm/ibe-ex-11-model - [:good :good :so-so :bad :bad :good :bad :so-so] - 0.0001)) +(deftest log-baum-welch-algorithm-test + (testing "logarithmic Baum-Welch algorithm" + (testing "with Oliver Ibe's Example 11" + (is (hmm/valid-hmm? 
+ (hmm/train-model tm/ibe-ex-11-log-model + [:good :good :so-so :bad :bad :good :bad :so-so] + 0.00001)))) + (testing "with Emilio Frazzoli's Baum-Welch example" + (pprint + (hmm/LogHMM->HMM + (hmm/train-model tm/frazzoli-ex-log-model + tm/frazzoli-ex-observations + 0.0001)))))) diff --git a/test/hidden_markov_music/stats_test.clj b/test/hidden_markov_music/stats_test.clj new file mode 100644 index 0000000..ca9d3ac --- /dev/null +++ b/test/hidden_markov_music/stats_test.clj @@ -0,0 +1,14 @@ +(ns hidden-markov-music.stats-test + (:require [hidden-markov-music.stats :as stats]) + (:use clojure.test)) + +(deftest stochastic-map-test + (testing "random stochastic map is stochastic" + (dotimes [_ 10] + (is (stats/stochastic-map? + (stats/random-stochastic-map [:a :b :c :d :e]))))) + (testing "random row-stochastic map is row-stochastic" + (dotimes [_ 10] + (is (stats/row-stochastic-map? + (stats/random-row-stochastic-map [:a :b :c :d :e] + [:A :B :C :D :E])))))) diff --git a/test/hidden_markov_music/test_models.clj b/test/hidden_markov_music/test_models.clj index 94d3a62..6b067dc 100644 --- a/test/hidden_markov_music/test_models.clj +++ b/test/hidden_markov_music/test_models.clj @@ -43,6 +43,78 @@ #{[:sunny :sunny :sunny :rainy :rainy] [:sunny :sunny :cloudy :rainy :rainy]}) + +;; models taken from Larry Moss' Baum-Welch examples +(def moss-ex-1-model + (HMM. [:s :t] + + [:A :B] + + {:s 0.85, + :t 0.15} + + {:s {:s 0.3, + :t 0.7}, + :t {:s 0.1, + :t 0.9}}, + + {:s {:A 0.4, + :B 0.6}, + :t {:A 0.5, + :B 0.5}})) + +(def moss-ex-1-log-model + (hmm/HMM->LogHMM moss-ex-1-model)) + + +;; model taken from Emilio Frazzoli's Baum-Welch example +(def frazzoli-ex-model + (HMM. 
[:LA :NY] + + [:LA :NY :null] + + {:LA 0.5, + :NY 0.5} + + {:LA {:LA 0.5, + :NY 0.5}, + :NY {:LA 0.5, + :NY 0.5}} + + {:LA {:LA 0.4, + :NY 0.1, + :null 0.5}, + :NY {:LA 0.1, + :NY 0.5, + :null 0.4}})) + +(def frazzoli-ex-log-model + (hmm/HMM->LogHMM frazzoli-ex-model)) + +(def frazzoli-ex-observations + [:null :LA :LA :null :NY :null :NY :NY :NY :null + :NY :NY :NY :NY :NY :null :null :LA :LA :NY ]) + +(def frazzoli-ex-trained-model + (HMM. [:LA :NY] + + [:LA :NY :null] + + {:LA 1.0, + :NY 0.0} + + {:LA {:LA 0.6909 + :NY 0.3091}, + :NY {:LA 0.0934 + :NY 0.9066}} + + {:LA {:LA 0.5807 + :NY 0.0010 + :null 0.4183}, + :NY {:LA 0.0000 + :NY 0.7621 + :null 0.2379}})) + ;; fully deterministic HMM, whose states must be ;; :A -> :B -> :C -> :A -> ... ;; and whose emissions must be