diff --git a/project.clj b/project.clj
index 39573bc..f6e1a94 100644
--- a/project.clj
+++ b/project.clj
@@ -1,4 +1,4 @@
-(defproject hidden-markov-music "0.1.3-SNAPSHOT"
+(defproject hidden-markov-music "0.1.3"
   :description
   "Generate original musical scores by means of a hidden Markov model."
   :url "https://github.com/dwysocki/hidden-markov-music"
diff --git a/src/hidden_markov_music/hmm.clj b/src/hidden_markov_music/hmm.clj
index c7b1b5c..9feb8c2 100644
--- a/src/hidden_markov_music/hmm.clj
+++ b/src/hidden_markov_music/hmm.clj
@@ -4,6 +4,7 @@
             [hidden-markov-music.stats :as stats]
             [hidden-markov-music.random :refer [select-random-key]]
             [hidden-markov-music.util :refer [map-for
                                               map-vals
+                                              numbers-almost-equal?
                                               maps-almost-equal?]])
   (:use clojure.pprint))
@@ -744,8 +745,11 @@
                                       observations))]
     (- numerator denominator))))
 
-(defn train-model-helper
-  [model observations threshold likelihood]
+(defn train-model-seq
+  "Returns an infinite lazy sequence of models, each produced by applying
+  one re-estimation step to the model before it. The first element is the
+  result of a single step on `model`; `model` itself is not included."
+  [model observations]
   (let [alphas (forward-probability-seq model observations)
         betas  (reverse (backward-probability-seq model observations))
         gammas (gamma-seq model alphas betas)
@@ -758,17 +762,47 @@
         new-model (assoc model
                     :initial-prob     new-initial-probs
                     :transition-prob  new-transition-probs
-                    :observation-prob new-observation-probs)
-        new-likelihood (likelihood-forward new-model observations)]
-    (if (> (- new-likelihood likelihood)
-           threshold)
-      (recur new-model observations threshold new-likelihood)
-      new-model)))
+                    :observation-prob new-observation-probs)]
+    (cons new-model
+          (lazy-seq (train-model-seq new-model observations)))))
 
 (defn train-model
-  "Trains the model via the Baum-Welch algorithm."
-  ([model observations]
-     (train-model model observations 0.0))
-  ([model observations threshold]
-     (train-model-helper model observations threshold
-                         (likelihood-forward model observations))))
+  "Trains the model via the Baum-Welch algorithm.
+
+  Repeatedly re-estimates `model` against `observations`, stopping once
+  two consecutive models have likelihoods that `numbers-almost-equal?`
+  accepts at `:decimal` (default 15), or once `:max-iter` models
+  (default 100), counting `model` itself, have been examined.
+
+  Returns the last model whose likelihood still differed from its
+  predecessor's. When there is none, because `model` has already
+  converged or `:max-iter` is below 2, returns `model` unchanged."
+  ([model observations & {:keys [decimal max-iter]
+                          :or {decimal 15 max-iter 100}}]
+     (let [;; the infinite lazy seq of trained models, preceded by the
+           ;; initial model and capped at max-iter elements
+           trained-models
+           (take max-iter
+                 (cons model
+                       (train-model-seq model observations)))
+           ;; associate with each trained model its likelihood
+           trained-model-likelihoods
+           (map (fn [m]
+                  [m (likelihood-forward m observations)])
+                trained-models)
+           ;; create a sliding window of pairs of trained-models
+           trained-model-likelihood-pairs
+           (partition 2 1 trained-model-likelihoods)
+           ;; keep the windows up to, but excluding, the first one whose
+           ;; likelihoods agree to the given decimal place
+           unconverged-pairs
+           (take-while (fn [[[_ likelihood-prev] [_ likelihood]]]
+                         (not (numbers-almost-equal? likelihood-prev
+                                                     likelihood
+                                                     :decimal decimal)))
+                       trained-model-likelihood-pairs)]
+       ;; `last` yields nil when no window was kept; fall back to the
+       ;; untrained model instead of letting nil propagate to the caller
+       (if-let [[_ [last-model _]] (last unconverged-pairs)]
+         last-model
+         model))))
diff --git a/test/hidden_markov_music/baum_welch_test.clj b/test/hidden_markov_music/baum_welch_test.clj
index 3bbaafb..7a16e59 100644
--- a/test/hidden_markov_music/baum_welch_test.clj
+++ b/test/hidden_markov_music/baum_welch_test.clj
@@ -10,10 +10,12 @@
       (is (hmm/valid-hmm?
            (hmm/train-model tm/ibe-ex-11-log-model
                             [:good :good :so-so :bad :bad :good :bad :so-so]
-                            0.00001))))
+                            :decimal 5))))
     (testing "with Emilio Frazzoli's Baum-Welch example"
-      (pprint
-       (hmm/LogHMM->HMM
-        (hmm/train-model tm/frazzoli-ex-log-model
-                         tm/frazzoli-ex-observations
-                         0.0001))))))
+      (is (hmm/hmms-almost-equal?
+           tm/frazzoli-ex-trained-model
+           (hmm/LogHMM->HMM
+            (hmm/train-model tm/frazzoli-ex-log-model
+                             tm/frazzoli-ex-observations
+                             :max-iter 20))
+           :decimal 3)))))