diff --git a/Project.toml b/Project.toml index 07d14abd..9b4c73c4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LaplaceRedux" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" authors = ["Patrick Altmeyer"] -version = "0.1.2" +version = "0.1.3" [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" @@ -16,10 +16,8 @@ MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" @@ -37,7 +35,6 @@ MLJFlux = "0.2.10, 0.3" MLJModelInterface = "1.8.0" MLUtils = "0.4.3" Parameters = "0.12" -Plots = "1" ProgressMeter = "1.7.2" Tables = "1.10.1" Tullio = "0.3.5" diff --git a/src/LaplaceRedux.jl b/src/LaplaceRedux.jl index 8cf17b93..0f5e76bf 100644 --- a/src/LaplaceRedux.jl +++ b/src/LaplaceRedux.jl @@ -22,6 +22,4 @@ export optimize_prior!, include("mlj_flux.jl") export LaplaceApproximation -include("plotting.jl") - end diff --git a/src/plotting.jl b/src/plotting.jl deleted file mode 100644 index 13b216f8..00000000 --- a/src/plotting.jl +++ /dev/null @@ -1,124 +0,0 @@ -using Plots - -function Plots.plot( - la::Laplace, - X::AbstractArray, - y::AbstractArray; - link_approx::Symbol=:probit, - target::Union{Nothing,Real}=nothing, - colorbar=true, - title=nothing, - length_out=50, - zoom=-1, - xlims=nothing, - ylims=nothing, - linewidth=0.1, - lw=4, - kwargs..., -) - if la.likelihood == :regression - @assert size(X, 1) == 1 "Cannot plot regression for multiple input variables." - else - @assert size(X, 1) == 2 "Cannot plot classification for more than two input variables." - end - - if la.likelihood == :regression - - # REGRESSION - - # Surface range: - if isnothing(xlims) - xlims = (minimum(X), maximum(X)) .+ (zoom, -zoom) - else - xlims = xlims .+ (zoom, -zoom) - end - if isnothing(ylims) - ylims = (minimum(y), maximum(y)) .+ (zoom, -zoom) - else - ylims = ylims .+ (zoom, -zoom) - end - x_range = range(xlims[1]; stop=xlims[2], length=length_out) - y_range = range(ylims[1]; stop=ylims[2], length=length_out) - - title = isnothing(title) ? "" : title - - # Plot: - scatter( - vec(X), - vec(y); - label="ytrain", - xlim=xlims, - ylim=ylims, - lw=lw, - title=title, - kwargs..., - ) - _x = collect(x_range)[:, :]' - fμ, fvar = la(_x) - fμ = vec(fμ) - fσ = vec(sqrt.(fvar)) - pred_std = sqrt.(fσ .^ 2 .+ la.σ^2) - plot!( - x_range, - fμ; - color=2, - label="yhat", - ribbon=(1.96 * pred_std, 1.96 * pred_std), - lw=lw, - kwargs..., - ) # the specific values 1.96 are used here to create a 95% confidence interval - else - - # CLASSIFICATION - - # Surface range: - if isnothing(xlims) - xlims = (minimum(X[1, :]), maximum(X[1, :])) .+ (zoom, -zoom) - else - xlims = xlims .+ (zoom, -zoom) - end - if isnothing(ylims) - ylims = (minimum(X[2, :]), maximum(X[2, :])) .+ (zoom, -zoom) - else - ylims = ylims .+ (zoom, -zoom) - end - x_range = range(xlims[1]; stop=xlims[2], length=length_out) - y_range = range(ylims[1]; stop=ylims[2], length=length_out) - - # Plot - predict_ = function (X::AbstractVector) - z = la(X; link_approx=link_approx) - if outdim(la) == 1 # binary - z = [1.0 - z[1], z[1]] - end - return z - end - Z = [predict_([x, y]) for x in x_range, y in y_range] - Z = reduce(hcat, Z) - if outdim(la) > 1 - if isnothing(target) - @info "No target label supplied, using first." - end - target = isnothing(target) ? 1 : target - title = isnothing(title) ? "p̂(y=$(target))" : title - else - target = isnothing(target) ? 2 : target - title = isnothing(title) ? "p̂(y=$(target-1))" : title - end - - # Contour: - contourf( - x_range, - y_range, - Z[Int(target), :]; - colorbar=colorbar, - title=title, - linewidth=linewidth, - xlims=xlims, - ylims=ylims, - kwargs..., - ) - # Samples: - scatter!(X[1, :], X[2, :]; group=Int.(y), color=Int.(y), kwargs...) - end -end