Skip to content

Commit

Permalink
debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
sandreza committed Nov 6, 2024
1 parent f9636f3 commit 73830aa
Showing 1 changed file with 101 additions and 4 deletions.
105 changes: 101 additions & 4 deletions examples/statistical_model_prototype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ struct GeneralGaussianMixture{W, M, C}
covariances::C
end

struct ScoreModel{P, S}
probability_model::P
inverse_covariances::S
end

function determine_statistical_model(trajectory, tree_type::Tree{Val{false},S}; override=false) where {S}
if typeof(tree_type.arguments) <: NamedTuple
Expand Down Expand Up @@ -66,17 +70,17 @@ function determine_statistical_model(trajectory, tree_type::Tree{Val{false},S};
return embedding, probability_weights, means, partitions
end


##
# include the lorenz file here
method = Tree(false, 0.01)
method = Tree(false, 0.001)

emb, pr, ms, partitions = determine_statistical_model(trajectory, method)

empirical_centers = zeros(size(trajectory)[1], maximum(partitions))
empirical_covariance = zeros(size(trajectory)[1], size(trajectory)[1], maximum(partitions))
empirical_count = zeros(Int64, maximum(partitions))
for (i, state) in ProgressBar(enumerate(eachcol(trajectory)))
cell_index = embedding(state)
cell_index = emb(state)
partitions[i] = cell_index
empirical_count[cell_index] += 1
empirical_centers[:, cell_index] .+= state
Expand Down Expand Up @@ -107,6 +111,14 @@ function rand(Σmodel::GeneralGaussianMixture)
return rand(MvNormal(Σmodel.means[:, cell_index], Σmodel.covariances[:, :, cell_index]))
end

function rand(Σmodel::GeneralGaussianMixture, N::Int)
samples = zeros(size(Σmodel.means)[1], N)
for i in 1:N
samples[:, i] .= rand(Σmodel)
end
return samples
end

function mean(δmodel::DeltaFunction)
ensemble_mean = zeros(size(δmodel.means)[1])
for i in 1:size(δmodel.means)[2]
Expand Down Expand Up @@ -168,4 +180,89 @@ scatter(samples)
##
totcov = cov(trajectory')
cov(Σmodel)
cov(δmodel)
cov(δmodel)

fig = Figure()
ax = Axis(fig[1, 1])
density!(ax, samples[1, :], normalization = :pdf, color = (:red, 0.5))
hist!(ax, trajectory[1, :], normalization = :pdf, color = (:blue, 0.5))
display(fig)

##
function ScoreModel(gm::GeneralGaussianMixture)
n = size(gm.means)[1]
m = size(gm.means)[2]
Σinv = zeros(n, n, m)
for i in 1:m
Σinv[:, :, i] = pinv(gm.covariances[:, :, i])
end
return ScoreModel(gm, Σinv)
end

score = ScoreModel(Σmodel)

function (score::ScoreModel)(x)
n = size(score.probability_model.means)[1]
m = size(score.probability_model.means)[2]
score_value = zeros(n)
denominator = [0.0]
for i in 1:m
Δ = score.probability_model.means[:, i] - x
Σ⁻¹Δ = score.inverse_covariances[:, :, i] * Δ
normalization = sqrt(det(2π * score.probability_model.covariances[:, :, i]))
U = exp(-0.5 * Δ' * Σ⁻¹Δ) / normalization
weightedU = score.probability_model.weights[i] * U
score_value .+= weightedU * Σ⁻¹Δ
denominator .+= weightedU
end
return score_value / denominator[1]
end

##
dt = 0.01
iterations = 10^6


trajectory = zeros(1, iterations)
trajectory[:, 1] .= [0.0]
step = RungeKutta4(1)
for i in ProgressBar(2:iterations)
step(linear, trajectory[:, i-1], dt)
trajectory[:, i] .= step.xⁿ⁺¹ .+ sqrt(dt) * randn(1)
end

method = Tree(false, 0.1)

emb, pr, ms, partitions = determine_statistical_model(trajectory, method)

empirical_centers = zeros(size(trajectory)[1], maximum(partitions))
empirical_covariance = zeros(size(trajectory)[1], size(trajectory)[1], maximum(partitions))
empirical_count = zeros(Int64, maximum(partitions))
for (i, state) in ProgressBar(enumerate(eachcol(trajectory)))
cell_index = emb(state)
partitions[i] = cell_index
empirical_count[cell_index] += 1
empirical_centers[:, cell_index] .+= state
empirical_covariance[:, :, cell_index] .+= state * state'
end
# adjust
adj_empirical_centers = empirical_centers ./ reshape(empirical_count, 1, length(empirical_count))
adj_empirical_covariance = empirical_covariance ./ reshape(empirical_count .- 1, 1, 1, length(empirical_count))
for i in 1:length(probability_weights)
adj_empirical_covariance[:, :, i] .-= (adj_empirical_centers[:, i] * adj_empirical_centers[:, i]') * empirical_count[i] / (empirical_count[i] - 1)
end
probability_weights = empirical_count / sum(empirical_count)

δmodel = DeltaFunction(probability_weights, adj_empirical_centers)
Σmodel = GeneralGaussianMixture(probability_weights, adj_empirical_centers, adj_empirical_covariance * 20)

score = ScoreModel(Σmodel)

σemp = cov(trajectory')
xs = range(-3, 3, length=100)
scorevals = [score([x])[1] for x in xs]
fig = Figure()
ax = Axis(fig[1, 1])
scatter!(ax, xs, scorevals)
lines!(ax, xs, -xs / σemp^2, color=:red)
display(fig)

0 comments on commit 73830aa

Please sign in to comment.