diff --git a/dev/.documenter-siteinfo.json b/dev/.documenter-siteinfo.json index a164a073..df19cef6 100644 --- a/dev/.documenter-siteinfo.json +++ b/dev/.documenter-siteinfo.json @@ -1 +1 @@ -{"documenter":{"julia_version":"1.10.4","generation_timestamp":"2024-08-13T19:07:09","documenter_version":"1.5.0"}} \ No newline at end of file +{"documenter":{"julia_version":"1.10.5","generation_timestamp":"2024-09-03T10:07:17","documenter_version":"1.6.0"}} \ No newline at end of file diff --git a/dev/assets/documenter.js b/dev/assets/documenter.js index b2bdd43e..82252a11 100644 --- a/dev/assets/documenter.js +++ b/dev/assets/documenter.js @@ -77,30 +77,35 @@ require(['jquery'], function($) { let timer = 0; var isExpanded = true; -$(document).on("click", ".docstring header", function () { - let articleToggleTitle = "Expand docstring"; - - debounce(() => { - if ($(this).siblings("section").is(":visible")) { - $(this) - .find(".docstring-article-toggle-button") - .removeClass("fa-chevron-down") - .addClass("fa-chevron-right"); - } else { - $(this) - .find(".docstring-article-toggle-button") - .removeClass("fa-chevron-right") - .addClass("fa-chevron-down"); +$(document).on( + "click", + ".docstring .docstring-article-toggle-button", + function () { + let articleToggleTitle = "Expand docstring"; + const parent = $(this).parent(); + + debounce(() => { + if (parent.siblings("section").is(":visible")) { + parent + .find("a.docstring-article-toggle-button") + .removeClass("fa-chevron-down") + .addClass("fa-chevron-right"); + } else { + parent + .find("a.docstring-article-toggle-button") + .removeClass("fa-chevron-right") + .addClass("fa-chevron-down"); - articleToggleTitle = "Collapse docstring"; - } + articleToggleTitle = "Collapse docstring"; + } - $(this) - .find(".docstring-article-toggle-button") - .prop("title", articleToggleTitle); - $(this).siblings("section").slideToggle(); - }); -}); + parent + .children(".docstring-article-toggle-button") + .prop("title", articleToggleTitle); + parent.siblings("section").slideToggle(); + }); + } +); $(document).on("click", ".docs-article-toggle-button", function (event) { let articleToggleTitle = "Expand docstring"; @@ -110,7 +115,7 @@ $(document).on("click", ".docs-article-toggle-button", function (event) { debounce(() => { if (isExpanded) { $(this).removeClass("fa-chevron-up").addClass("fa-chevron-down"); - $(".docstring-article-toggle-button") + $("a.docstring-article-toggle-button") .removeClass("fa-chevron-down") .addClass("fa-chevron-right"); @@ -119,7 +124,7 @@ $(document).on("click", ".docs-article-toggle-button", function (event) { $(".docstring section").slideUp(animationSpeed); } else { $(this).removeClass("fa-chevron-down").addClass("fa-chevron-up"); - $(".docstring-article-toggle-button") + $("a.docstring-article-toggle-button") .removeClass("fa-chevron-right") .addClass("fa-chevron-down"); diff --git a/dev/index.html b/dev/index.html index 7c7b3448..14cf98ac 100644 --- a/dev/index.html +++ b/dev/index.html @@ -1,5 +1,5 @@ -Home · LaplaceRedux.jl

Documentation for LaplaceRedux.jl.

LaplaceRedux

LaplaceRedux.jl is a library written in pure Julia that can be used for effortless Bayesian Deep Learning through Laplace Approximation (LA). In the development of this package I have drawn inspiration from this Python library and its companion paper (Daxberger et al. 2021).

🚩 Installation

The stable version of this package can be installed as follows:

using Pkg
+Home · LaplaceRedux.jl

Documentation for LaplaceRedux.jl.

LaplaceRedux

LaplaceRedux.jl is a library written in pure Julia that can be used for effortless Bayesian Deep Learning through Laplace Approximation (LA). In the development of this package I have drawn inspiration from this Python library and its companion paper (Daxberger et al. 2021).

🚩 Installation

The stable version of this package can be installed as follows:

using Pkg
 Pkg.add("LaplaceRedux.jl")

The development version can be installed like so:

using Pkg
 Pkg.add("https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl")

🏃 Getting Started

If you are new to Deep Learning in Julia or simply prefer learning through videos, check out this awesome YouTube tutorial by doggo.jl 🐶. You can also find a video of my presentation at JuliaCon 2022 on YouTube.

🖥️ Basic Usage

LaplaceRedux.jl can be used for any neural network trained in Flux.jl. Below we show basic usage examples involving two simple models for a regression and a classification task, respectively.

Regression

A complete worked example for a regression model can be found in the docs. Here we jump straight to Laplace Approximation and take the pre-trained model nn as given. Then LA can be implemented as follows, where we specify the model likelihood. The plot shows the fitted values overlaid with a 95% confidence interval. As expected, predictive uncertainty quickly increases in areas that are not populated by any training data.

la = Laplace(nn; likelihood=:regression)
 fit!(la, data)
@@ -14,4 +14,4 @@
 p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1))
 p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
 p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
-plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))

📢 JuliaCon 2022

This project was presented at JuliaCon 2022 in July 2022. See here for details.

🛠️ Contribute

Contributions are very much welcome! Please follow the SciML ColPrac guide. You may want to start by having a look at any open issues.

🎓 References

Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2021. “Laplace Redux-Effortless Bayesian Deep Learning.” Advances in Neural Information Processing Systems 34.

+plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))

📢 JuliaCon 2022

This project was presented at JuliaCon 2022 in July 2022. See here for details.

🛠️ Contribute

Contributions are very much welcome! Please follow the SciML ColPrac guide. You may want to start by having a look at any open issues.

🎓 References

Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2021. “Laplace Redux-Effortless Bayesian Deep Learning.” Advances in Neural Information Processing Systems 34.

diff --git a/dev/index_files/figure-commonmark/cell-4-output-1.svg b/dev/index_files/figure-commonmark/cell-4-output-1.svg
index f4226f0e..5451cb31 100644
[SVG plot markup regenerated; no recoverable textual content]
diff --git a/dev/index_files/figure-commonmark/cell-7-output-1.svg b/dev/index_files/figure-commonmark/cell-7-output-1.svg
index 01c9e1a8..17bde69b 100644
[SVG plot markup regenerated; no recoverable textual content]
diff --git a/dev/mlj_interface.qmd b/dev/mlj_interface.qmd
new file mode 100644
index 00000000..20c6b740
--- /dev/null
+++ b/dev/mlj_interface.qmd
@@ -0,0 +1,6 @@
+
+
+``` @meta
+CurrentModule = LaplaceRedux
+```
+# Interface to the MLJ framework
\ No newline at end of file
diff --git a/dev/mlj_interface/index.html b/dev/mlj_interface/index.html
new file mode 100644
index 00000000..1c6d4153
--- /dev/null
+++ b/dev/mlj_interface/index.html
@@ -0,0 +1,2 @@
+
+MLJ interface · LaplaceRedux.jl
diff --git a/dev/objects.inv b/dev/objects.inv index 4688855a..92d4f8bd 100644 Binary files a/dev/objects.inv and b/dev/objects.inv differ diff --git a/dev/reference/index.html b/dev/reference/index.html index fbe414e1..4280f153 100644 --- a/dev/reference/index.html +++ b/dev/reference/index.html @@ -1,21 +1,21 @@ -Reference · LaplaceRedux.jl

All functions and types

Exported functions

LaplaceRedux.LaplaceType
Laplace

Concrete type for Laplace approximation. This type is a subtype of AbstractLaplace and is used to store all the necessary information for a Laplace approximation.

Fields

  • model::Flux.Chain: The model to be approximated.
  • likelihood::Symbol: The likelihood function to be used.
  • est_params::EstimationParams: The estimation parameters.
  • prior::Prior: The parameters defining the prior distribution.
  • posterior::Posterior: The posterior distribution.
source
LaplaceRedux.LaplaceMethod
Laplace(model::Any; likelihood::Symbol, kwargs...)

Outer constructor for Laplace approximation. This function constructs a Laplace object from a given model and likelihood function.

Arguments

  • model::Any: The model to be approximated (a Flux.Chain).
  • likelihood::Symbol: The likelihood function to be used. Possible values are :regression and :classification.

Keyword Arguments

See LaplaceParams for a description of the keyword arguments.

Returns

  • la::Laplace: The Laplace object.

Examples

using Flux, LaplaceRedux
+Reference · LaplaceRedux.jl

All functions and types

Exported functions

LaplaceRedux.LaplaceType
Laplace

Concrete type for Laplace approximation. This type is a subtype of AbstractLaplace and is used to store all the necessary information for a Laplace approximation.

Fields

  • model::Flux.Chain: The model to be approximated.
  • likelihood::Symbol: The likelihood function to be used.
  • est_params::EstimationParams: The estimation parameters.
  • prior::Prior: The parameters defining the prior distribution.
  • posterior::Posterior: The posterior distribution.
source
LaplaceRedux.LaplaceMethod
Laplace(model::Any; likelihood::Symbol, kwargs...)

Outer constructor for Laplace approximation. This function constructs a Laplace object from a given model and likelihood function.

Arguments

  • model::Any: The model to be approximated (a Flux.Chain).
  • likelihood::Symbol: The likelihood function to be used. Possible values are :regression and :classification.

Keyword Arguments

See LaplaceParams for a description of the keyword arguments.

Returns

  • la::Laplace: The Laplace object.

Examples

using Flux, LaplaceRedux
 nn = Chain(Dense(2,1))
-la = Laplace(nn, likelihood=:regression)
source
LaplaceRedux.LaplaceClassificationType
MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic

A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.

The model is defined by the following default parameters for all MLJFlux models:

  • builder: a Flux model that constructs the neural network.
  • finaliser: a Flux model that processes the output of the neural network.
  • optimiser: a Flux optimiser.
  • loss: a loss function that takes the predicted output and the true output as arguments.
  • epochs: the number of epochs.
  • batch_size: the size of a batch.
  • lambda: the regularization strength.
  • alpha: the regularization mix (0 for all l2, 1 for all l1).
  • rng: a random number generator.
  • optimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.
  • acceleration: the computational resource to use.

The model also has the following parameters, which are specific to the Laplace approximation:

  • subset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.
  • subnetwork_indices: the indices of the subnetworks.
  • hessian_structure: the structure of the Hessian matrix, either :full or :diagonal.
  • backend: the backend to use, either :GGN or :EmpiricalFisher.
  • σ: the standard deviation of the prior distribution.
  • μ₀: the mean of the prior distribution.
  • P₀: the covariance matrix of the prior distribution.
  • link_approx: the link approximation to use, either :probit or :plugin.
  • predict_proba: a boolean that selects whether to predict probabilities or not.
  • ret_distr: a boolean that tells predict whether to return distribution objects from Distributions.jl (true) or just the probabilities (false).
  • fit_prior_nsteps: the number of steps used to fit the priors.
source
LaplaceRedux.LaplaceRegressionType
MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic

A mutable struct representing a Laplace regression model that extends the MLJFlux.MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.

The model is defined by the following default parameters for all MLJFlux models:

  • builder: a Flux model that constructs the neural network.
  • optimiser: a Flux optimiser.
  • loss: a loss function that takes the predicted output and the true output as arguments.
  • epochs: the number of epochs.
  • batch_size: the size of a batch.
  • lambda: the regularization strength.
  • alpha: the regularization mix (0 for all l2, 1 for all l1).
  • rng: a random number generator.
  • optimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.
  • acceleration: the computational resource to use.

The model also has the following parameters, which are specific to the Laplace approximation:

  • subset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.
  • subnetwork_indices: the indices of the subnetworks.
  • hessian_structure: the structure of the Hessian matrix, either :full or :diagonal.
  • backend: the backend to use, either :GGN or :EmpiricalFisher.
  • σ: the standard deviation of the prior distribution.
  • μ₀: the mean of the prior distribution.
  • P₀: the covariance matrix of the prior distribution.
  • ret_distr: a boolean that tells predict whether to return distribution objects from Distributions.jl (true) or just the probabilities (false).
  • fit_prior_nsteps: the number of steps used to fit the priors.
source
LaplaceRedux.fit!Method
fit!(la::AbstractLaplace,data)

Fits the Laplace approximation for a data set. The function returns the number of observations (n_data) that were used to update the Laplace object. It does not return the updated Laplace object itself because the function modifies the input Laplace object in place (as denoted by the use of '!' in the function's name).

Examples

using Flux, LaplaceRedux
+la = Laplace(nn, likelihood=:regression)
source
LaplaceRedux.LaplaceClassificationType
MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic

A mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.

The model is defined by the following default parameters for all MLJFlux models:

  • builder: a Flux model that constructs the neural network.
  • finaliser: a Flux model that processes the output of the neural network.
  • optimiser: a Flux optimiser.
  • loss: a loss function that takes the predicted output and the true output as arguments.
  • epochs: the number of epochs.
  • batch_size: the size of a batch.
  • lambda: the regularization strength.
  • alpha: the regularization mix (0 for all l2, 1 for all l1).
  • rng: a random number generator.
  • optimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.
  • acceleration: the computational resource to use.

The model also has the following parameters, which are specific to the Laplace approximation:

  • subset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.
  • subnetwork_indices: the indices of the subnetworks.
  • hessian_structure: the structure of the Hessian matrix, either :full or :diagonal.
  • backend: the backend to use, either :GGN or :EmpiricalFisher.
  • σ: the standard deviation of the prior distribution.
  • μ₀: the mean of the prior distribution.
  • P₀: the covariance matrix of the prior distribution.
  • link_approx: the link approximation to use, either :probit or :plugin.
  • predict_proba: a boolean that selects whether to predict probabilities or not.
  • ret_distr: a boolean that tells predict whether to return distribution objects from Distributions.jl (true) or just the probabilities (false).
  • fit_prior_nsteps: the number of steps used to fit the priors.
source
LaplaceRedux.LaplaceRegressionType
MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic

A mutable struct representing a Laplace regression model that extends the MLJFlux.MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network.

The model is defined by the following default parameters for all MLJFlux models:

  • builder: a Flux model that constructs the neural network.
  • optimiser: a Flux optimiser.
  • loss: a loss function that takes the predicted output and the true output as arguments.
  • epochs: the number of epochs.
  • batch_size: the size of a batch.
  • lambda: the regularization strength.
  • alpha: the regularization mix (0 for all l2, 1 for all l1).
  • rng: a random number generator.
  • optimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.
  • acceleration: the computational resource to use.

The model also has the following parameters, which are specific to the Laplace approximation:

  • subset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.
  • subnetwork_indices: the indices of the subnetworks.
  • hessian_structure: the structure of the Hessian matrix, either :full or :diagonal.
  • backend: the backend to use, either :GGN or :EmpiricalFisher.
  • σ: the standard deviation of the prior distribution.
  • μ₀: the mean of the prior distribution.
  • P₀: the covariance matrix of the prior distribution.
  • ret_distr: a boolean that tells predict whether to return distribution objects from Distributions.jl (true) or just the probabilities (false).
  • fit_prior_nsteps: the number of steps used to fit the priors.
source
LaplaceRedux.empirical_frequency_binary_classificationMethod
empirical_frequency_binary_classification(y_binary, distributions::Vector{Bernoulli{Float64}}; n_bins::Int=20)

FOR BINARY CLASSIFICATION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ {1,...,T}$, let $p_t = H(x_t) ∈ [0,1]$ be the forecasted probability.
We group the $p_t$ into intervals $I_j$ for $j = 1,2,...,m$ that form a partition of $[0,1]$. The function computes the observed average $\hat{p}_j = T_j^{-1} \sum_{t:\, p_t ∈ I_j} y_t$ in each interval $I_j$.
Source: Kuleshov, Fenner, Ermon 2018

Inputs:
- y_binary: the array of outputs $y_t$ numerically coded: 1 for the target class, 0 for the null class.
- distributions: an array of Bernoulli distributions
- n_bins: number of equally spaced bins to use.

Outputs:
- num_p_per_interval: array with the number of probabilities falling within each interval.
- emp_avg: array with the observed empirical average per interval.
- bin_centers: array with the centers of the bins.

source
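For intuition, the following is a minimal sketch of the binning step described above; the helper name and return order are illustrative assumptions, not the package implementation.

function empirical_frequency_sketch(y_binary::Vector{Int}, p::Vector{Float64}; n_bins::Int=20)
    # Equal-width partition of [0,1] into n_bins intervals I_j.
    edges = range(0, 1; length=n_bins + 1)
    bin_centers = collect((edges[1:end-1] .+ edges[2:end]) ./ 2)
    num_p_per_interval = zeros(Int, n_bins)
    sums = zeros(n_bins)
    for (yt, pt) in zip(y_binary, p)
        j = clamp(searchsortedlast(edges, pt), 1, n_bins)   # interval I_j containing p_t
        num_p_per_interval[j] += 1
        sums[j] += yt
    end
    # Observed average of y_t within each interval (NaN where no forecast fell in the bin).
    emp_avg = [n > 0 ? s / n : NaN for (s, n) in zip(sums, num_p_per_interval)]
    return num_p_per_interval, emp_avg, bin_centers
end

Feeding it the labels together with mean.(distributions) would yield the kind of reliability-diagram quantities listed under Outputs.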
LaplaceRedux.empirical_frequency_regressionMethod
empirical_frequency_regression(Y_cal, distributions::Distributions.Normal, n_bins=20)

Dispatched version for Normal distributions FOR REGRESSION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ {1,...,T}$ and an array of predicted distributions, the function calculates the empirical frequency

\[\hat{p}_j = \frac{\lvert\{y_t \mid F_t(y_t) \leq p_j,\ t = 1,\dots,T\}\rvert}{T},\]

where $T$ is the number of calibration points, $p_j$ is the confidence level and $F_t$ is the cumulative distribution function of the predicted distribution targeting $y_t$.
Source: Kuleshov, Fenner, Ermon 2018

Inputs:
- Y_cal: a vector of values $y_t$
- distributions: a Vector{Distributions.Normal{Float64}} of distributions stacked row-wise.
For example, the output of LaplaceRedux.predict(la, Xcal).
- n_bins: number of equally spaced bins to use.

Outputs:
- counts: an array containing the empirical frequencies for each quantile interval.

source
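To make the formula concrete, here is a hedged sketch of the quantile-calibration count; the helper name and the choice of confidence levels are assumptions, not the package code.

using Distributions

function empirical_frequency_regression_sketch(Y_cal::Vector{<:Real},
                                               distributions::Vector{<:Normal};
                                               n_bins::Int=20)
    T = length(Y_cal)
    p_levels = range(0, 1; length=n_bins)            # confidence levels p_j
    F = cdf.(distributions, Y_cal)                   # F_t(y_t) for every calibration point
    return [count(<=(p), F) / T for p in p_levels]   # empirical frequency p̂_j per level
end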
LaplaceRedux.extract_mean_and_varianceMethod
extract_mean_and_variance(distr::Vector{Normal{<: AbstractFloat}})

Extract the mean and the variance of each distribution and return them in two separate lists.

Inputs:
- distributions: a Vector of Normal distributions

Outputs:
- means: the list of the means
- variances: the list of the variances

source
LaplaceRedux.fit!Method
fit!(la::AbstractLaplace,data)

Fits the Laplace approximation for a data set. The function returns the number of observations (n_data) that were used to update the Laplace object. It does not return the updated Laplace object itself because the function modifies the input Laplace object in place (as denoted by the use of '!' in the function's name).

Examples

using Flux, LaplaceRedux
 x, y = LaplaceRedux.Data.toy_data_linear()
 data = zip(x,y)
 nn = Chain(Dense(2,1))
 la = Laplace(nn)
-fit!(la, data)
source
LaplaceRedux.glm_predictive_distributionMethod
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)

Computes the linearized GLM predictive.

Arguments

  • la::AbstractLaplace: A Laplace object.
  • X::AbstractArray: Input data.

Returns

  • normal_distr A normal distribution N(fμ,fvar) approximating the predictive distribution p(y|X) given the input data X.
  • fμ::AbstractArray: Mean of the predictive distribution. The output shape is column-major as in Flux.
  • fvar::AbstractArray: Variance of the predictive distribution. The output shape is column-major as in Flux.

Examples

using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))

source
LaplaceRedux.glm_predictive_distributionMethod
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)

Computes the linearized GLM predictive.

Arguments

  • la::AbstractLaplace: A Laplace object.
  • X::AbstractArray: Input data.

Returns

  • normal_distr A normal distribution N(fμ,fvar) approximating the predictive distribution p(y|X) given the input data X.
  • fμ::AbstractArray: Mean of the predictive distribution. The output shape is column-major as in Flux.
  • fvar::AbstractArray: Variance of the predictive distribution. The output shape is column-major as in Flux.

Examples

using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))

source
LaplaceRedux.optimize_prior!Method
optimize_prior!(
     la::AbstractLaplace; 
     n_steps::Int=100, lr::Real=1e-1,
     λinit::Union{Nothing,Real}=nothing,
     σinit::Union{Nothing,Real}=nothing
-)

Optimize the prior precision post-hoc through Empirical Bayes (marginal log-likelihood maximization).

source
LaplaceRedux.posterior_covarianceFunction
posterior_covariance(la::AbstractLaplace, P=la.P)

Computes the posterior covariance $∑$ as the inverse of the posterior precision: $\Sigma=P^{-1}$.

source
LaplaceRedux.posterior_precisionFunction
posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀)

Computes the posterior precision $P$ for a fitted Laplace Approximation as follows,

$P = \sum_{n=1}^N\nabla_{\theta}^2 \log p(\mathcal{D}_n|\theta)|_{\hat\theta} + \nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}$

where $\sum_{n=1}^N\nabla_{\theta}^2\log p(\mathcal{D}_n|\theta)|_{\hat\theta}=H$ is the Hessian and $\nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}=P_0$ is the prior precision and $\hat\theta$ is the MAP estimate.

source
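As a toy illustration of the two relations above (posterior precision as Hessian term plus prior precision, covariance as its inverse), with made-up numbers rather than package internals:

using LinearAlgebra

H  = [4.0 1.0; 1.0 3.0]          # Hessian term H at the MAP estimate (toy values)
P₀ = 0.5 * I(2)                  # prior precision λI with λ = 0.5
P  = H + P₀                      # posterior precision P = H + P₀
Σ  = inv(Symmetric(Matrix(P)))   # posterior covariance Σ = P⁻¹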
LaplaceRedux.predictMethod
predict(la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true)

Computes predictions from Bayesian neural network.

Arguments

  • la::AbstractLaplace: A Laplace object.
  • X::AbstractArray: Input data.
  • link_approx::Symbol=:probit: Link function approximation. Options are :probit and :plugin.
  • predict_proba::Bool=true: If true (default) apply a sigmoid or a softmax function to the output of the Flux model.
  • return_distr::Bool=false: if false (default), the function outputs either the direct output of the chain or pseudo-probabilities (if predict_proba=true). If true, predict returns a Bernoulli distribution in binary classification tasks and a Categorical distribution in multiclass classification tasks.

Returns

For classification tasks, LaplaceRedux provides different options. If ret_distr is false:

  • fμ::AbstractArray: Mean of the predictive distribution if the link function is set to :plugin, otherwise the probit approximation. The output shape is column-major as in Flux.

If ret_distr is true:

  • a Bernoulli distribution in binary classification tasks and a Categorical distribution in multiclass classification tasks.

For regression tasks:

  • normal_distr::Distributions.Normal: the array of Normal distributions computed by glm_predictive_distribution.

Examples

using Flux, LaplaceRedux
+)

Optimize the prior precision post-hoc through Empirical Bayes (marginal log-likelihood maximization).

source
LaplaceRedux.posterior_covarianceFunction
posterior_covariance(la::AbstractLaplace, P=la.P)

Computes the posterior covariance $∑$ as the inverse of the posterior precision: $\Sigma=P^{-1}$.

source
LaplaceRedux.posterior_precisionFunction
posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀)

Computes the posterior precision $P$ for a fitted Laplace Approximation as follows,

$P = \sum_{n=1}^N\nabla_{\theta}^2 \log p(\mathcal{D}_n|\theta)|_{\hat\theta} + \nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}$

where $\sum_{n=1}^N\nabla_{\theta}^2\log p(\mathcal{D}_n|\theta)|_{\hat\theta}=H$ is the Hessian and $\nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}=P_0$ is the prior precision and $\hat\theta$ is the MAP estimate.

source
LaplaceRedux.predictMethod
predict(la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true)

Computes predictions from Bayesian neural network.

Arguments

  • la::AbstractLaplace: A Laplace object.
  • X::AbstractArray: Input data.
  • link_approx::Symbol=:probit: Link function approximation. Options are :probit and :plugin.
  • predict_proba::Bool=true: If true (default) apply a sigmoid or a softmax function to the output of the Flux model.
  • return_distr::Bool=false: if false (default), the function outputs either the direct output of the chain or pseudo-probabilities (if predict_proba=true). If true, predict returns a Bernoulli distribution in binary classification tasks and a Categorical distribution in multiclass classification tasks.

Returns

For classification tasks, LaplaceRedux provides different options. If ret_distr is false:

  • fμ::AbstractArray: Mean of the predictive distribution if the link function is set to :plugin, otherwise the probit approximation. The output shape is column-major as in Flux.

If ret_distr is true:

  • a Bernoulli distribution in binary classification tasks and a Categorical distribution in multiclass classification tasks.

For regression tasks:

  • normal_distr::Distributions.Normal: the array of Normal distributions computed by glm_predictive_distribution.

Examples

using Flux, LaplaceRedux
 using LaplaceRedux.Data: toy_data_linear
 x, y = toy_data_linear()
 data = zip(x,y)
 nn = Chain(Dense(2,1))
 la = Laplace(nn; likelihood=:classification)
 fit!(la, data)
-predict(la, hcat(x...))
source
MLJModelInterface.predictMethod
predict(model::LaplaceClassification, Xnew)

Predicts the class labels for new data using the LaplaceClassification model.

Arguments

  • model::LaplaceClassification: The trained LaplaceClassification model.
  • fitresult: the fitresult output produced by MLJFlux.fit!
  • Xnew: The new data to make predictions on.

Returns

An array of predicted class labels.

source
MLJModelInterface.predictMethod
predict(model::LaplaceRegression, Xnew)

Predict the output for new input data using a Laplace regression model.

Arguments

  • model::LaplaceRegression: The trained Laplace regression model.
  • the fitresult output produced by MLJFlux.fit!
  • Xnew: The new input data.

Returns

  • The predicted output for the new input data.
source

Internal functions

LaplaceRedux.AbstractLaplaceMethod
(la::AbstractLaplace)(X::AbstractArray)

Calling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the predict function.

source
LaplaceRedux.EstimationParamsType
EstimationParams

Container for the parameters of a Laplace approximation.

Fields

  • subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
  • subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
  • hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
  • curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
source
LaplaceRedux.KronType

Kronecker-factored approximate curvature representation for a neural network model. Each element in kfacs represents two Kronecker factors (𝐆, 𝐀), such that the full block Hessian approximation would be approximated as 𝐀⊗𝐆.

source
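To make the factorization concrete, a small hedged example with toy numbers (not taken from the package) showing the block implied by one pair of factors:

using LinearAlgebra

A = [2.0 0.5; 0.5 1.0]    # factor 𝐀 (toy values)
G = [1.0 0.2; 0.2 0.8]    # factor 𝐆 (toy values)
H_block = kron(A, G)      # the implied 4×4 Hessian block 𝐀 ⊗ 𝐆

Storing A and G costs O(p² + q²) instead of O(p²q²) for the full block, which is the point of the Kronecker factorization.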
LaplaceRedux.KronDecomposedType
KronDecomposed

Decomposed Kronecker-factored approximate curvature representation for a neural network model.

Decomposition is required to add the prior (diagonal matrix) to the posterior (KronDecomposed). It also has the benefits of reducing the costs for computation of inverses and log-determinants.

source
LaplaceRedux.LaplaceParamsType
LaplaceParams

Container for the parameters of a Laplace approximation.

Fields

  • subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
  • subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
  • hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
  • backend::Symbol: the backend to use. Possible values are :GGN and :Fisher.
  • curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
  • σ::Real: the observation noise
  • μ₀::Real: the prior mean
  • λ::Real: the prior precision
  • P₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix
source
LaplaceRedux.PosteriorType
Posterior

Container for the results of a Laplace approximation.

Fields

  • μ::AbstractVector: the MAP estimate of the parameters
  • H::Union{AbstractArray,AbstractDecomposition,Nothing}: the Hessian matrix
  • P::Union{AbstractArray,AbstractDecomposition,Nothing}: the posterior precision matrix
  • Σ::Union{AbstractArray,Nothing}: the posterior covariance matrix
  • n_data::Union{Int,Nothing}: the number of data points
  • n_params::Union{Int,Nothing}: the number of parameters
  • n_out::Union{Int,Nothing}: the number of outputs
  • loss::Real: the loss value
source
LaplaceRedux.PriorType
Prior

Container for the prior parameters of a Laplace approximation.

Fields

  • σ::Real: the observation noise
  • μ₀::Real: the prior mean
  • λ::Real: the prior precision
  • P₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix
source
LaplaceRedux.PriorMethod
Prior(params::LaplaceParams)

Extracts the prior parameters from a LaplaceParams object.

source
Base.:*Method

Multiply by a scalar by changing the eigenvalues. Distribute the scalar along the factors of a block.

source
Base.:*Method

Kronecker-factored curvature scalar scaling.

source
Base.:+Method

Shift the factors by a diagonal (assumed uniform scaling)

source
Base.:+Method

Shift the factors by a scalar across the diagonal.

source
Flux.paramsMethod
Flux.params(model::Any, params::EstimationParams)

Extracts the parameters of a model based on the subset of weights specified in the EstimationParams object.

source
Flux.paramsMethod
Flux.params(la::Laplace)

Overloads the params function for a Laplace object.

source
LaplaceRedux._H_factorMethod
_H_factor(la::AbstractLaplace)

Returns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I)

source
LaplaceRedux._fit!Method
_fit!(la::Laplace, hessian_structure::FullHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)

Fit a Laplace approximation to the posterior distribution of a model using the full Hessian.

source
LaplaceRedux._fit!Method
_fit!(la::Laplace, hessian_structure::KronHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)

Fit a Laplace approximation to the posterior distribution of a model using the Kronecker-factored Hessian.

source
LaplaceRedux._weight_penaltyMethod
_weight_penalty(la::AbstractLaplace)

The weight penalty term is a regularization term used to prevent overfitting. Weight regularization methods such as weight decay introduce a penalty to the loss function when training a neural network to encourage the network to use small weights. Smaller weights in a neural network can result in a model that is more stable and less likely to overfit the training dataset, in turn having better performance when making a prediction on new data.

source
LaplaceRedux.approximateMethod
approximate(curvature::CurvatureInterface, hessian_structure::FullHessian, d::Tuple; batched::Bool=false)

Compute the full approximation, for either a single input-output datapoint or a batch of such.

source
LaplaceRedux.approximateMethod
approximate(curvature::CurvatureInterface, hessian_structure::KronHessian, data; batched::Bool=false)

Compute the eigendecomposed Kronecker-factored approximate curvature as the Fisher information matrix.

Note that, since the network predictive distribution is used in a weighted sum, the number of backward passes is linear in the number of target classes, e.g. 100 for CIFAR-100.

source
LaplaceRedux.clampMethod

Clamp eigenvalues in an eigendecomposition to be non-negative.

Since the Fisher information matrix is positive semi-definite by construction, the (near-zero) negative eigenvalues should be neglected.

source
LaplaceRedux.convert_subnetwork_indicesMethod

convert_subnetwork_indices(subnetwork_indices::AbstractArray)

Converts the subnetwork indices from the user-given format [theta, row, column] to an Int i that corresponds to the index of that weight in the flattened array of weights.

source
LaplaceRedux.functional_varianceMethod
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)

Compute the functional variance for the GLM predictive. Dispatches to the appropriate method based on the Hessian structure.

source
LaplaceRedux.functional_varianceMethod

functional_variance(la::Laplace,𝐉)

Compute the linearized GLM predictive variance as 𝐉ₙΣ𝐉ₙ' where 𝐉=∇f(x;θ)|θ̂ is the Jacobian evaluated at the MAP estimate and Σ = P⁻¹.

source
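A hedged sketch of that quantity for a dense posterior covariance, returning only the diagonal of 𝐉Σ𝐉ᵀ; the names are illustrative, not the package API:

using LinearAlgebra

# fvar[k] = 𝐉[k, :]' * Σ * 𝐉[k, :], i.e. the k-th diagonal entry of 𝐉 Σ 𝐉ᵀ.
glm_predictive_var(𝐉::AbstractMatrix, Σ::AbstractMatrix) = [dot(j, Σ, j) for j in eachrow(𝐉)]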
LaplaceRedux.functional_varianceMethod

functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)

Compute functional variance for the GLM predictive: as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ, where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of model output 𝐉=∇f(x;θ)|θ̂.

source
LaplaceRedux.get_map_estimateMethod
get_map_estimate(model::Any, est_params::EstimationParams)

Helper function to extract the MAP estimate of the parameters for the model based on the subset of weights specified in the EstimationParams object.

source
LaplaceRedux.has_softmax_or_sigmoid_final_layerMethod
has_softmax_or_sigmoid_final_layer(model::Flux.Chain)

Check if the Flux model ends with a sigmoid or with a softmax layer.

Input:
- model: the Flux Chain object that represents the neural network.

Return:
- has_finaliser: true if the check is positive, false otherwise.

source
LaplaceRedux.instantiate_curvature!Method
instantiate_curvature!(params::EstimationParams, model::Any, likelihood::Symbol, backend::Symbol)

Instantiates the curvature interface for a Laplace approximation. The curvature interface is a concrete subtype of CurvatureInterface and is used to compute the Hessian matrix. The curvature interface is stored in the curvature field of the EstimationParams object.

source
LaplaceRedux.inv_square_formMethod

function inv_square_form(K::KronDecomposed, W::Matrix)

Special function to compute the inverse square form 𝐉𝐏⁻¹𝐉ᵀ (or 𝐖𝐊⁻¹𝐖ᵀ)

source
LaplaceRedux.logdetblockMethod
logdetblock(block::Tuple{Eigen,Eigen}, delta::Number)

Log-determinant of a block in KronDecomposed, shifted by delta on the diagonal.

source
LaplaceRedux.mmMethod

Matrix-multiply for the KronDecomposed Hessian approximation K and a 2-d matrix W, applying an exponent to K and transposing W before multiplication. Return (K^x)W^T, where x is the exponent.

source
LaplaceRedux.n_paramsMethod
n_params(model::Any, params::EstimationParams)

Helper function to determine the number of parameters of a Flux.Chain with Laplace approximation.

source
LaplaceRedux.outdimMethod
outdim(model::Chain)

Helper function to determine the output dimension of a Flux.Chain, corresponding to the number of neurons on the last layer of the NN.

source
LaplaceRedux.outdimMethod
outdim(la::AbstractLaplace)

Helper function to determine the output dimension, corresponding to the number of neurons on the last layer of the NN, of a Flux.Chain with Laplace approximation.

source
LaplaceRedux.probitMethod
probit(fμ::AbstractArray, fvar::AbstractArray)

Compute the probit approximation of the predictive distribution.

source
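The usual probit approximation shrinks the mean logit according to the predictive variance before the sigmoid is applied; a minimal sketch of that formula (assumed form, not the package source):

# p(y = 1 | x) ≈ sigmoid(κ * fμ) with κ = 1 / sqrt(1 + π/8 * fvar)
sigmoid(z) = 1 / (1 + exp(-z))
probit_approx(fμ::AbstractArray, fvar::AbstractArray) = sigmoid.(fμ ./ sqrt.(1 .+ (π / 8) .* fvar))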
LinearAlgebra.detMethod
det(K::KronDecomposed)

Determinant of the KronDecomposed block-diagonal matrix, computed as the exponentiated log-determinant.

source
LinearAlgebra.logdetMethod
logdet(K::KronDecomposed)

Log-determinant of the KronDecomposed block-diagonal matrix, as the product of the determinants of the blocks

source
MLJFlux.buildMethod
MLJFlux.build(model::LaplaceClassification, rng, shape)

Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by shape.

Arguments

  • model::LaplaceClassification: The Laplace classification model.
  • rng: A random number generator to ensure reproducibility.
  • shape: A tuple or array specifying the dimensions of the input and output layers.

Returns

  • The constructed MLJFlux model, compatible with the specified input and output dimensions.
source
MLJFlux.buildMethod
MLJFlux.build(model::LaplaceRegression, rng, shape)

Builds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by shape.

Arguments

  • model::LaplaceRegression: The Laplace regression model.
  • rng: A random number generator to ensure reproducibility.
  • shape: A tuple or array specifying the dimensions of the input and output layers.

Returns

  • The constructed MLJFlux model, compatible with the specified input and output dimensions.
source
MLJFlux.fitresultMethod
MLJFlux.fitresult(model::LaplaceClassification, chain, y)

Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data.

Arguments

  • model::LaplaceClassification: The Laplace classification model to be evaluated.
  • chain: The trained model chain.
  • y: The target data, typically a vector of class labels.

Returns

A tuple containing:

  • The trained Flux chain.
  • a deepcopy of the laplace model.
source
MLJFlux.fitresultMethod
MLJFlux.fitresult(model::LaplaceRegression, chain, y)

Computes the fit result for a Laplace Regression model, returning the model chain and the number of output variables in the target data.

Arguments

  • model::LaplaceRegression: The Laplace Regression model to be evaluated.
  • chain: The trained model chain.
  • y: The target data, typically a vector of class labels.

Returns

A tuple containing:

  • The trained Flux chain.
  • a deepcopy of the laplace model.
source
MLJFlux.shapeMethod
MLJFlux.shape(model::LaplaceRegression, X, y)

Compute the number of features of the X input dataset and the number of variables to predict from the y output dataset.

Arguments

  • model::LaplaceRegression: The LaplaceRegression model to fit.
  • X: The input data for training.
  • y: The target labels for training, one-hot encoded.

Returns

  • (input size, output size)
source
MLJFlux.trainMethod
MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)

Fit the LaplaceClassification model using Flux.jl.

Arguments

  • model::LaplaceClassification: The LaplaceClassification model.
  • regularized_optimiser: the regularized optimiser to apply to the loss function.
  • optimiser_state: the state of the optimiser.
  • epochs: The number of epochs for training.
  • verbosity: The verbosity level for training.
  • X: The input data for training.
  • y: The target labels for training.

Returns (la, optimiser_state, history)

where

  • la: the fitted Laplace model.
  • optimiser_state: the state of the optimiser.
  • history: the training loss history.
source
MLJFlux.trainMethod
MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)

Fit the LaplaceRegression model using Flux.jl.

Arguments

  • model::LaplaceRegression: The LaplaceRegression model.
  • regularized_optimiser: the regularized optimiser to apply to the loss function.
  • optimiser_state: the state of the optimiser.
  • epochs: The number of epochs for training.
  • verbosity: The verbosity level for training.
  • X: The input data for training.
  • y: The target labels for training.

Returns (la, optimiser_state, history )

where

  • la: the fitted Laplace model.
  • optimiser_state: the state of the optimiser.
  • history: the training loss history.
source
LaplaceRedux.Curvature.gradientsMethod
gradients(curvature::CurvatureInterface, X::AbstractArray, y::Number)

Compute the gradients with respect to the loss function: ∇ℓ(f(x;θ),y) where f: ℝᴰ ↦ ℝᴷ.

source
LaplaceRedux.Curvature.jacobians_unbatchedMethod
jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)

Compute the Jacobian of the model output w.r.t. model parameters for the point X, without batching. Here, the nn function is wrapped in an anonymous function using the () -> syntax, which allows it to be differentiated using automatic differentiation.

source
+predict(la, hcat(x...))
source
LaplaceRedux.rescale_stddevMethod
rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat}

Rescale the standard deviation of the Normal distributions received as argument and return a vector of rescaled Normal distributions.

Inputs:
- distr: a Vector of Normal distributions
- s: a scale factor of type T.

Outputs:
- Vector{Normal{T}}: a Vector of rescaled Normal distributions.

source
LaplaceRedux.sharpness_classificationMethod
sharpness_classification(y_binary,distributions::Distributions.Bernoulli)

Dispatched version for Bernoulli distributions FOR BINARY CLASSIFICATION MODELS.
Assess the sharpness of the model by looking at the distribution of model predictions. When forecasts are sharp, most predictions are close to either 0 or 1.
Source: Kuleshov, Fenner, Ermon 2018

Inputs:
- y_binary: the array of outputs $y_t$ numerically coded: 1 for the target class, 0 for the negative result.
- distributions: an array of Bernoulli distributions describing the probability of the output belonging to the target class
Outputs:
- mean_class_one: a scalar that measures the average prediction for the target class
- mean_class_zero: a scalar that measures the average prediction for the null class

source
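One plausible reading of these two outputs, written as a hedged sketch (the grouping of forecasts by observed class is an assumption on my part):

using Distributions, Statistics

function sharpness_classification_sketch(y_binary::Vector{Int}, distributions::Vector{<:Bernoulli})
    p = mean.(distributions)                  # forecast probability of the target class per point
    mean_class_one  = mean(p[y_binary .== 1]) # average forecast where the target class was observed
    mean_class_zero = mean(p[y_binary .== 0]) # average forecast where the null class was observed
    return mean_class_one, mean_class_zero
end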
LaplaceRedux.sharpness_regressionMethod
sharpness_regression(distributions::Distributions.Normal)

Dispatched version for Normal distributions FOR REGRESSION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ {1,...,T}$ and an array of predicted distributions, the function calculates the sharpness of the predicted distributions, i.e., the average of the variances $\sigma^2(F_t)$ predicted by the forecaster for each $x_t$.
source: Kuleshov, Fenner, Ermon 2018

Inputs:
- distributions: an array of normal distributions $F(x_t)$ stacked row-wise.
Outputs:
- sharpness: a scalar that measures the level of sharpness of the regressor

source
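Since sharpness here is just the average predictive variance, the computation reduces to one line; a hedged sketch assuming Distributions.jl Normals:

using Distributions, Statistics

sharpness_sketch(distributions::Vector{<:Normal}) = mean(var.(distributions))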
LaplaceRedux.sigma_scalingMethod
sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}) where T <: AbstractFloat

Compute the value of Σ that maximizes the conditional log-likelihood:

\[ m \ln(\Sigma) + \frac{1}{2} \Sigma^{-2} \sum_{i=1}^{m} \frac{\lVert y_{\text{cal},i} - \bar{y}_{\text{mean},i} \rVert^2}{\sigma^2_i} \]

where m is the number of elements in the calibration set (x_cal, y_cal).
Source: Laves, Ihler, Fast, Kahrs, Ortmaier, 2020.
Inputs:
- distr: a Vector of Normal distributions
- y_cal: a Vector of true results.

Outputs:
- sigma: the scalar that maximizes the likelihood.

source
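Setting the derivative of the objective above with respect to Σ to zero gives the closed-form scale Σ² = m⁻¹ ∑ᵢ ‖y_cal_i − ȳ_i‖²/σᵢ²; a hedged sketch of that solution (assuming the predicted means and standard deviations come from the Normal distributions):

using Distributions, Statistics

sigma_scaling_sketch(distr::Vector{<:Normal}, y_cal::Vector{<:AbstractFloat}) =
    sqrt(mean(((y_cal .- mean.(distr)) .^ 2) ./ std.(distr) .^ 2))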
MLJModelInterface.predictMethod
predict(model::LaplaceClassification, Xnew)

Predicts the class labels for new data using the LaplaceClassification model.

Arguments

  • model::LaplaceClassification: The trained LaplaceClassification model.
  • fitresult: the fitresult output produced by MLJFlux.fit!
  • Xnew: The new data to make predictions on.

Returns

An array of predicted class labels.

source
MLJModelInterface.predictMethod
predict(model::LaplaceRegression, Xnew)

Predict the output for new input data using a Laplace regression model.

Arguments

  • model::LaplaceRegression: The trained Laplace regression model.
  • the fitresult output produced by MLJFlux.fit!
  • Xnew: The new input data.

Returns

  • The predicted output for the new input data.
source

Internal functions

LaplaceRedux.AbstractLaplaceMethod
(la::AbstractLaplace)(X::AbstractArray)

Calling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the predict function.

source
LaplaceRedux.EstimationParamsType
EstimationParams

Container for the parameters of a Laplace approximation.

Fields

  • subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
  • subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
  • hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
  • curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
source
LaplaceRedux.KronType

Kronecker-factored approximate curvature representation for a neural network model. Each element in kfacs represents two Kronecker factors (𝐆, 𝐀), such that the corresponding block of the full Hessian approximation is given by 𝐀⊗𝐆.

source
LaplaceRedux.KronDecomposedType
KronDecomposed

Decomposed Kronecker-factored approximate curvature representation for a neural network model.

Decomposition is required to add the prior (diagonal matrix) to the posterior (KronDecomposed). It also has the benefits of reducing the costs for computation of inverses and log-determinants.

source
LaplaceRedux.LaplaceParamsType
LaplaceParams

Container for the parameters of a Laplace approximation.

Fields

  • subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
  • subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
  • hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
  • backend::Symbol: the backend to use. Possible values are :GGN and :Fisher.
  • curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
  • σ::Real: the observation noise
  • μ₀::Real: the prior mean
  • λ::Real: the prior precision
  • P₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix
source
LaplaceRedux.PosteriorType
Posterior

Container for the results of a Laplace approximation.

Fields

  • μ::AbstractVector: the MAP estimate of the parameters
  • H::Union{AbstractArray,AbstractDecomposition,Nothing}: the Hessian matrix
  • P::Union{AbstractArray,AbstractDecomposition,Nothing}: the posterior precision matrix
  • Σ::Union{AbstractArray,Nothing}: the posterior covariance matrix
  • n_data::Union{Int,Nothing}: the number of data points
  • n_params::Union{Int,Nothing}: the number of parameters
  • n_out::Union{Int,Nothing}: the number of outputs
  • loss::Real: the loss value
source
LaplaceRedux.PriorType
Prior

Container for the prior parameters of a Laplace approximation.

Fields

  • σ::Real: the observation noise
  • μ₀::Real: the prior mean
  • λ::Real: the prior precision
  • P₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix
source
LaplaceRedux.PriorMethod
Prior(params::LaplaceParams)

Extracts the prior parameters from a LaplaceParams object.

source
Base.:*Method

Multiply by a scalar by changing the eigenvalues. Distribute the scalar along the factors of a block.

source
Base.:*Method

Kronecker-factored curvature scalar scaling.

source
Base.:+Method

Shift the factors by a diagonal (assumed uniform scaling)

source
Base.:+Method

Shift the factors by a scalar across the diagonal.

source
Flux.paramsMethod
Flux.params(model::Any, params::EstimationParams)

Extracts the parameters of a model based on the subset of weights specified in the EstimationParams object.

source
Flux.paramsMethod
Flux.params(la::Laplace)

Overloads the params function for a Laplace object.

source
LaplaceRedux._H_factorMethod
_H_factor(la::AbstractLaplace)

Returns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I)

source
LaplaceRedux._fit!Method
_fit!(la::Laplace, hessian_structure::FullHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)

Fit a Laplace approximation to the posterior distribution of a model using the full Hessian.

source
LaplaceRedux._fit!Method
_fit!(la::Laplace, hessian_structure::KronHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)

Fit a Laplace approximation to the posterior distribution of a model using the Kronecker-factored Hessian.

source
LaplaceRedux._weight_penaltyMethod
_weight_penalty(la::AbstractLaplace)

The weight penalty term is a regularization term used to prevent overfitting. Weight regularization methods such as weight decay introduce a penalty to the loss function when training a neural network to encourage the network to use small weights. Smaller weights in a neural network can result in a model that is more stable and less likely to overfit the training dataset, in turn having better performance when making a prediction on new data.

source
LaplaceRedux.approximateMethod
approximate(curvature::CurvatureInterface, hessian_structure::FullHessian, d::Tuple; batched::Bool=false)

Compute the full approximation, for either a single input-output datapoint or a batch of such.

source
LaplaceRedux.approximateMethod
approximate(curvature::CurvatureInterface, hessian_structure::KronHessian, data; batched::Bool=false)

Compute the eigendecomposed Kronecker-factored approximate curvature as the Fisher information matrix.

Note that, since the network predictive distribution is used in a weighted sum, the number of backward passes is linear in the number of target classes, e.g. 100 for CIFAR-100.

source
LaplaceRedux.clampMethod

Clamp eigenvalues in an eigendecomposition to be non-negative.

Since the Fisher information matrix is positive semi-definite by construction, the (near-zero) negative eigenvalues should be neglected.

source
LaplaceRedux.convert_subnetwork_indicesMethod

convert_subnetwork_indices(subnetwork_indices::AbstractArray)

Converts the subnetwork indices from the user-given format [theta, row, column] to an Int i that corresponds to the index of that weight in the flattened array of weights.

source
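A hypothetical illustration of that conversion for column-major Julia arrays; flat_index below is an assumption made for exposition, not the package function:

# indices = [theta, row, column]; θ is a vector of parameter arrays flattened column-major.
function flat_index(indices, θ)
    offset = sum(length, θ[1:(indices[1] - 1)]; init=0)
    return offset + (indices[3] - 1) * size(θ[indices[1]], 1) + indices[2]
end

θ = [randn(3, 2), randn(3)]  # e.g. a weight matrix and a bias vector
flat_index([1, 2, 2], θ)     # position of W[2, 2] in vcat(vec.(θ)...), i.e. 5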
LaplaceRedux.functional_varianceMethod
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)

Compute the functional variance for the GLM predictive. Dispatches to the appropriate method based on the Hessian structure.

source
LaplaceRedux.functional_varianceMethod

functional_variance(la::Laplace,𝐉)

Compute the linearized GLM predictive variance as 𝐉ₙΣ𝐉ₙ' where 𝐉=∇f(x;θ)|θ̂ is the Jacobian evaluated at the MAP estimate and Σ = P⁻¹.

source
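A minimal numerical sketch of that quantity, with an illustrative Jacobian and covariance rather than values taken from a fitted model:

using LinearAlgebra
J = randn(3, 5)            # Jacobian: K outputs × P parameters
Σ = 0.1 * Matrix(I, 5, 5)  # posterior covariance Σ = P⁻¹
fvar = diag(J * Σ * J')    # per-output predictive variance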
LaplaceRedux.functional_varianceMethod

functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)

Compute the functional variance for the GLM predictive as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ, where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 = ∇f(x;θ)|θ̂ is the Jacobian of the model output.

source
LaplaceRedux.get_map_estimateMethod
get_map_estimate(model::Any, est_params::EstimationParams)

Helper function to extract the MAP estimate of the parameters for the model based on the subset of weights specified in the EstimationParams object.

source
LaplaceRedux.has_softmax_or_sigmoid_final_layerMethod
has_softmax_or_sigmoid_final_layer(model::Flux.Chain)

Check if the Flux model ends with a sigmoid or a softmax layer.

Input: - model: the Flux Chain object that represents the neural network. Returns: - has_finaliser: true if the check is positive, false otherwise.

source
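An assumed sketch of such a check; the helper below is hypothetical and may differ from the package's implementation:

using Flux
ends_in_finaliser(model::Chain) = model.layers[end] in (Flux.softmax, Flux.sigmoid)
ends_in_finaliser(Chain(Dense(2, 3), softmax))  # true
ends_in_finaliser(Chain(Dense(2, 1)))           # false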
LaplaceRedux.instantiate_curvature!Method
instantiate_curvature!(params::EstimationParams, model::Any, likelihood::Symbol, backend::Symbol)

Instantiates the curvature interface for a Laplace approximation. The curvature interface is a concrete subtype of CurvatureInterface and is used to compute the Hessian matrix. The curvature interface is stored in the curvature field of the EstimationParams object.

source
LaplaceRedux.inv_square_formMethod

inv_square_form(K::KronDecomposed, W::Matrix)

Special function to compute the inverse square form 𝐉𝐏⁻¹𝐉ᵀ (or 𝐖𝐊⁻¹𝐖ᵀ).

source
LaplaceRedux.logdetblockMethod
logdetblock(block::Tuple{Eigen,Eigen}, delta::Number)

Log-determinant of a block in KronDecomposed, shifted by delta on the diagonal.

source
LaplaceRedux.mmMethod

Matrix multiplication for the KronDecomposed Hessian approximation K and a 2-d matrix W, applying an exponent to K and transposing W before multiplication. Returns (K^x)W^T, where x is the exponent.

source
LaplaceRedux.n_paramsMethod
n_params(model::Any, params::EstimationParams)

Helper function to determine the number of parameters of a Flux.Chain with Laplace approximation.

source
LaplaceRedux.outdimMethod
outdim(model::Chain)

Helper function to determine the output dimension of a Flux.Chain, corresponding to the number of neurons on the last layer of the NN.

source
LaplaceRedux.outdimMethod
outdim(la::AbstractLaplace)

Helper function to determine the output dimension, corresponding to the number of neurons on the last layer of the NN, of a Flux.Chain with Laplace approximation.

source
LaplaceRedux.probitMethod
probit(fμ::AbstractArray, fvar::AbstractArray)

Compute the probit approximation of the predictive distribution.

source
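For reference, the standard probit approximation rescales the latent mean by the latent variance before applying the sigmoid; the values of fμ and fvar below are illustrative:

using Flux: sigmoid
fμ, fvar = 1.2, 0.8              # latent mean and variance from the GLM predictive
κ = 1 / sqrt(1 + π / 8 * fvar)   # MacKay's scaling factor
p = sigmoid(κ * fμ)              # approximate predictive probability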
LinearAlgebra.detMethod
det(K::KronDecomposed)

Determinant of the KronDecomposed block-diagonal matrix, computed as the exponentiated log-determinant.

source
LinearAlgebra.logdetMethod
logdet(K::KronDecomposed)

Log-determinant of the KronDecomposed block-diagonal matrix, computed as the sum of the log-determinants of the blocks.

source
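This rests on the standard identities for block-diagonal and Kronecker-structured matrices; a quick numerical check of the Kronecker part (illustrative matrices only):

using LinearAlgebra
A = [2.0 0.1; 0.1 2.0]; B = [3.0 0.0; 0.0 1.5]
# For square A (m×m) and B (n×n): logdet(A ⊗ B) = n*logdet(A) + m*logdet(B).
logdet(kron(A, B)) ≈ 2 * logdet(A) + 2 * logdet(B)  # true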
MLJFlux.buildMethod
MLJFlux.build(model::LaplaceClassification, rng, shape)

Builds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by shape.

Arguments

  • model::LaplaceClassification: The Laplace classification model.
  • rng: A random number generator to ensure reproducibility.
  • shape: A tuple or array specifying the dimensions of the input and output layers.

Returns

  • The constructed MLJFlux model, compatible with the specified input and output dimensions.
source
MLJFlux.buildMethod
MLJFlux.build(model::LaplaceRegression, rng, shape)

Builds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by shape.

Arguments

  • model::LaplaceRegression: The Laplace regression model.
  • rng: A random number generator to ensure reproducibility.
  • shape: A tuple or array specifying the dimensions of the input and output layers.

Returns

  • The constructed MLJFlux model, compatible with the specified input and output dimensions.
source
MLJFlux.fitresultMethod
MLJFlux.fitresult(model::LaplaceClassification, chain, y)

Computes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data.

Arguments

  • model::LaplaceClassification: The Laplace classification model to be evaluated.
  • chain: The trained model chain.
  • y: The target data, typically a vector of class labels.

Returns

A tuple containing:

  • The trained Flux chain.
  • A deepcopy of the Laplace model.
source
MLJFlux.fitresultMethod
MLJFlux.fitresult(model::LaplaceRegression, chain, y)

Computes the fit result for a Laplace Regression model, returning the model chain and the number of output variables in the target data.

Arguments

  • model::LaplaceRegression: The Laplace Regression model to be evaluated.
  • chain: The trained model chain.
  • y: The target data, typically a vector of continuous target values.

Returns

A tuple containing:

  • The trained Flux chain.
  • A deepcopy of the Laplace model.
source
MLJFlux.shapeMethod
MLJFlux.shape(model::LaplaceRegression, X, y)

Compute the number of features of the input dataset X and the number of variables to predict from the output dataset y.

Arguments

  • model::LaplaceRegression: The LaplaceRegression model to fit.
  • X: The input data for training.
  • y: The target labels for training, one-hot encoded.

Returns

  • (input size, output size)
source
MLJFlux.trainMethod
MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)

Fit the LaplaceClassification model using Flux.jl.

Arguments

  • model::LaplaceClassification: The LaplaceClassification model.
  • regularized_optimiser: the regularized optimiser to apply to the loss function.
  • optimiser_state: the state of the optimiser.
  • epochs: The number of epochs for training.
  • verbosity: The verbosity level for training.
  • X: The input data for training.
  • y: The target labels for training.

Returns (la, optimiser_state, history)

where

  • la: the fitted Laplace model.
  • optimiser_state: the state of the optimiser.
  • history: the training loss history.
source
MLJFlux.trainMethod
MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)

Fit the LaplaceRegression model using Flux.jl.

Arguments

  • model::LaplaceRegression: The LaplaceRegression model.
  • regularized_optimiser: the regularized optimiser to apply to the loss function.
  • optimiser_state: the state of the optimiser.
  • epochs: The number of epochs for training.
  • verbosity: The verbosity level for training.
  • X: The input data for training.
  • y: The target labels for training.

Returns (la, optimiser_state, history)

where

  • la: the fitted Laplace model.
  • optimiser_state: the state of the optimiser.
  • history: the training loss history.
source
LaplaceRedux.Curvature.gradientsMethod
gradients(curvature::CurvatureInterface, X::AbstractArray, y::Number)

Compute the gradients with respect to the loss function: ∇ℓ(f(x;θ),y) where f: ℝᴰ ↦ ℝᴷ.

source
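A sketch of the same computation using Flux's implicit-parameter gradients; the model, data, and loss below are illustrative rather than the package's internal code:

using Flux
nn = Chain(Dense(2, 1))
x, y = randn(Float32, 2), 1.0f0
gs = Flux.gradient(() -> Flux.Losses.logitbinarycrossentropy(nn(x), y), Flux.params(nn))
gs[nn.layers[1].weight]  # gradient of the loss w.r.t. the first layer's weights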
LaplaceRedux.Curvature.jacobians_unbatchedMethod
jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)

Compute the Jacobian of the model output w.r.t. model parameters for the point X, without batching. Here, the nn function is wrapped in an anonymous function using the () -> syntax, which allows it to be differentiated using automatic differentiation.

source
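A sketch of the wrapping described above, using Zygote's implicit-parameter jacobian; the model and input are illustrative, not the package's internal code:

using Flux, Zygote
nn = Chain(Dense(2, 1))
x = randn(Float32, 2)
J = Zygote.jacobian(() -> nn(x), Flux.params(nn))
J[nn.layers[1].weight]  # Jacobian of the output w.r.t. the first layer's weights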
diff --git a/dev/resources/_resources/index.html b/dev/resources/_resources/index.html index 9f173099..b4a710a6 100644 --- a/dev/resources/_resources/index.html +++ b/dev/resources/_resources/index.html @@ -1,2 +1,2 @@ -Additional Resources · LaplaceRedux.jl
+Additional Resources · LaplaceRedux.jl
diff --git a/dev/search_index.js b/dev/search_index.js index 69b76749..9085c0f0 100644 --- a/dev/search_index.js +++ b/dev/search_index.js @@ -1,3 +1,3 @@ var documenterSearchIndex = {"docs": -[{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/mlp/#Bayesian-MLP","page":"MLP Binary Classifier","title":"Bayesian MLP","text":"","category":"section"},{"location":"tutorials/mlp/#Libraries","page":"MLP Binary Classifier","title":"Libraries","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)","category":"page"},{"location":"tutorials/mlp/#Data","page":"MLP Binary Classifier","title":"Data","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"This time we use a synthetic dataset containing samples that are not linearly separable:","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_non_linear(200)\nX = hcat(xs...) # bring into tabular format\ndata = zip(xs,ys)","category":"page"},{"location":"tutorials/mlp/#Model","page":"MLP Binary Classifier","title":"Model","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"For the classification task we build a neural network with weight decay composed of a single hidden layer.","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"n_hidden = 10\nD = size(X,1)\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) ","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"The model is trained until training loss stagnates.","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"using Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend","category":"page"},{"location":"tutorials/mlp/#Laplace-Approximation","page":"MLP Binary Classifier","title":"Laplace Approximation","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"Laplace approximation can be implemented as follows:","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"la = Laplace(nn; likelihood=:classification, subset_of_weights=:all)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbose=true, n_steps=500)","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"The plot below shows the resulting posterior predictive surface for the plugin 
estimator (left) and the Laplace approximation (right).","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"# Plot the posterior distribution with a contour plot.\nzoom=0\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"Zooming out we can note that the plugin estimator produces high-confidence estimates in regions scarce of any samples. The Laplace approximation is much more conservative about these regions.","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"zoom=-50\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/logit/#Bayesian-Logistic-Regression","page":"Logistic Regression","title":"Bayesian Logistic Regression","text":"","category":"section"},{"location":"tutorials/logit/#Libraries","page":"Logistic Regression","title":"Libraries","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)","category":"page"},{"location":"tutorials/logit/#Data","page":"Logistic Regression","title":"Data","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"We will use synthetic data with linearly separable samples:","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_linear(100)\nX = hcat(xs...) 
# bring into tabular format\ndata = zip(xs,ys)","category":"page"},{"location":"tutorials/logit/#Model","page":"Logistic Regression","title":"Model","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"Logistic regression with weight decay can be implemented in Flux.jl as a single dense (linear) layer with binary logit crossentropy loss:","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"nn = Chain(Dense(2,1))\nλ = 0.5\nsqnorm(x) = sum(abs2, x)\nweight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))\nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) + weight_regularization()","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"The code below simply trains the model. After about 50 training epochs training loss stagnates.","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"using Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 50\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend","category":"page"},{"location":"tutorials/logit/#Laplace-approximation","page":"Logistic Regression","title":"Laplace approximation","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"Laplace approximation for the posterior predictive can be implemented as follows:","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"la = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbose=true, n_steps=500)","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"zoom = 0\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/prior/#Libraries","page":"A note on the prior ...","title":"Libraries","text":"","category":"section"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the 
prior ...","text":"note: In Progress\n","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"    This documentation is still incomplete.","category":"page"},{"location":"tutorials/prior/#A-quick-note-on-the-prior","page":"A note on the prior ...","title":"A quick note on the prior","text":"","category":"section"},{"location":"tutorials/prior/#General-Effect","page":"A note on the prior ...","title":"General Effect","text":"","category":"section"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"High prior precision rightarrow only observation noise. Low prior precision rightarrow high posterior uncertainty.","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"using LaplaceRedux.Data\nn = 150 # number of observations\nσtrue = 0.30 # true observational noise\nx, y = Data.toy_data_regression(n;noise=σtrue)\nxs = [[x] for x in x]\nX = permutedims(x)","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"data = zip(xs,y)\nn_hidden = 10\nD = size(X,1)\nΛ = [1e5, nothing, 1e-5]\nplts = []\nnns = []\nopt=Flux.Adam(1e-3)\nfor λ ∈ Λ\n nn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n ) \n loss(x, y) = Flux.Losses.mse(nn(x), y)\n # train\n epochs = 1000\n for epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n Flux.update!(opt, Flux.params(nn), gs)\n end\n end\n # laplace\n if !isnothing(λ)\n la = Laplace(nn; likelihood=:regression, λ=λ)\n fit!(la, data) \n else\n la = Laplace(nn; likelihood=:regression)\n fit!(la, data) \n optimize_prior!(la)\n end\n \n _suffix = isnothing(λ) ? 
\" (optimal)\" : \"\"\n λ = unique(diag(la.prior.P₀))[1]\n title = \"λ=$(round(λ,digits=2))$(_suffix)\"\n\n # plot \n plt = plot(la, X, y; title=title, zoom=-5)\n plts = vcat(plts..., plt)\n nns = vcat(nns..., nn)\nend\nplot(plts..., layout=(1,3), size=(1200,300))","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"(Image: )","category":"page"},{"location":"tutorials/prior/#Effect-of-Model-Size-on-Optimal-Choice","page":"A note on the prior ...","title":"Effect of Model Size on Optimal Choice","text":"","category":"section"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"For larger models, the optimal prior precision lambda as evaluated through Empirical Bayes tends to be smaller.","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"data = zip(xs,y)\nn_hiddens = [5, 10, 50]\nD = size(X,1)\nplts = []\nnns = []\nopt=Flux.Adam(1e-3)\nfor n_hidden ∈ n_hiddens\n nn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n ) \n loss(x, y) = Flux.Losses.mse(nn(x), y)\n # train\n epochs = 1000\n for epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n Flux.update!(opt, Flux.params(nn), gs)\n end\n end\n # laplace\n la = Laplace(nn; likelihood=:regression)\n fit!(la, data) \n optimize_prior!(la)\n \n λ = unique(diag(la.prior.P₀))[1]\n title = \"n_params=$(LaplaceRedux.n_params(la)),λ=$(round(λ,digits=2))\"\n\n # plot \n plt = plot(la, X, y; title=title, zoom=-5)\n plts = vcat(plts..., plt)\n nns = vcat(nns..., nn)\nend\nplot(plts..., layout=(1,3), size=(1200,300))","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"(Image: )","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_non_linear(200)\nX = hcat(xs...) 
# bring into tabular format\ndata = zip(xs,ys)\n\nn_hiddens = [5, 10, 50]\nD = size(X,1)\nplts = []\nnns = []\nopt=Flux.Adam(1e-3)\nfor n_hidden ∈ n_hiddens\n nn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, 1)\n ) \n loss(x, y) = Flux.Losses.mse(nn(x), y)\n # train\n epochs = 100\n for epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n Flux.update!(opt, Flux.params(nn), gs)\n end\n end\n # laplace\n la = Laplace(nn; likelihood=:classification)\n fit!(la, data) \n optimize_prior!(la)\n \n λ = unique(diag(la.prior.P₀))[1]\n title = \"n_params=$(LaplaceRedux.n_params(la)),λ=$(round(λ,digits=2))\"\n\n # plot \n plt = plot(la, X, ys; title=title, zoom=-1, clim=(0,1))\n plts = vcat(plts..., plt)\n nns = vcat(nns..., nn)\nend\nplot(plts..., layout=(1,3), size=(1200,300))","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"(Image: )","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/regression/#Libraries","page":"MLP Regression","title":"Libraries","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Import the libraries required to run this example","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:wong)","category":"page"},{"location":"tutorials/regression/#Data","page":"MLP Regression","title":"Data","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"We first generate some synthetic data:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"using LaplaceRedux.Data\nn = 300 # number of observations\nσtrue = 0.30 # true observational noise\nx, y = Data.toy_data_regression(n;noise=σtrue)\nxs = [[x] for x in x]\nX = permutedims(x)","category":"page"},{"location":"tutorials/regression/#MLP","page":"MLP Regression","title":"MLP","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"We set up a model and loss with weight regularization:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"data = zip(xs,y)\nn_hidden = 50\nD = size(X,1)\nnn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.mse(nn(x), y)","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"We train the model:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"using Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 1000\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend","category":"page"},{"location":"tutorials/regression/#Laplace-Approximation","page":"MLP Regression","title":"Laplace 
Approximation","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Laplace approximation can be implemented as follows:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"subset_w = :all\nla = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)\nfit!(la, data)\nplot(la, X, y; zoom=-5, size=(400,400))","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"(Image: )","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Next we optimize the prior precision P_0 and and observational noise sigma using Empirical Bayes:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"optimize_prior!(la; verbose=true)\nplot(la, X, y; zoom=-5, size=(400,400))","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"loss(exp.(logP₀), exp.(logσ)) = 104.78561546028183\nLog likelihood: -70.48742092717352\nLog det ratio: 41.1390695290454\nScatter: 27.45731953717124\nloss(exp.(logP₀), exp.(logσ)) = 104.9736282327825\nLog likelihood: -74.85481357633174\nLog det ratio: 46.59827618892447\nScatter: 13.639353123977058\nloss(exp.(logP₀), exp.(logσ)) = 84.38222356291794\nLog likelihood: -54.86985627702764\nLog det ratio: 49.92347667032635\nScatter: 9.101257901454279\n\nloss(exp.(logP₀), exp.(logσ)) = 84.53493863039972\nLog likelihood: -55.013137224636\nLog det ratio: 51.43622180356522\nScatter: 7.607381007962245\nloss(exp.(logP₀), exp.(logσ)) = 83.95921598606084\nLog likelihood: -54.41492266831395\nLog det ratio: 51.794520967146354\nScatter: 7.294065668347427\nloss(exp.(logP₀), exp.(logσ)) = 83.03505059021086\nLog likelihood: -53.50540374805591\nLog det ratio: 51.574749787874794\nScatter: 7.484543896435117\n\nloss(exp.(logP₀), exp.(logσ)) = 82.97840036025443\nLog likelihood: -53.468475394115416\nLog det ratio: 51.17273666609066\nScatter: 7.847113266187348\nloss(exp.(logP₀), exp.(logσ)) = 82.98550025321256\nLog likelihood: -53.48508828283467\nLog det ratio: 50.81442045868749\nScatter: 8.186403482068298\nloss(exp.(logP₀), exp.(logσ)) = 82.9584040552644\nLog likelihood: -53.45989630330948\nLog det ratio: 50.59063282947659\nScatter: 8.406382674433235\n\n\nloss(exp.(logP₀), exp.(logσ)) = 82.94465052328141\nLog likelihood: -53.44600301956443\nLog det ratio: 50.500079294094405\nScatter: 8.497215713339543","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"(Image: )","category":"page"},{"location":"tutorials/multi/#Multi-class-problem","page":"MLP Multi-Label Classifier","title":"Multi-class problem","text":"","category":"section"},{"location":"tutorials/multi/#Libraries","page":"MLP Multi-Label Classifier","title":"Libraries","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:lime)","category":"page"},{"location":"tutorials/multi/#Data","page":"MLP Multi-Label Classifier","title":"Data","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"using LaplaceRedux.Data\nx, y = 
Data.toy_data_multi()\nX = hcat(x...)\ny_train = Flux.onehotbatch(y, unique(y))\ny_train = Flux.unstack(y_train',1)","category":"page"},{"location":"tutorials/multi/#MLP","page":"MLP Multi-Label Classifier","title":"MLP","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"We set up a model","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"data = zip(x,y_train)\nn_hidden = 3\nD = size(X,1)\nout_dim = length(unique(y))\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, out_dim)\n) \nloss(x, y) = Flux.Losses.logitcrossentropy(nn(x), y)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"training:","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"using Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend","category":"page"},{"location":"tutorials/multi/#Laplace-Approximation","page":"MLP Multi-Label Classifier","title":"Laplace Approximation","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"The Laplace approximation can be implemented as follows:","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"la = Laplace(nn; likelihood=:classification)\nfit!(la, data)\noptimize_prior!(la; verbose=true, n_steps=100)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X, y; target=target, clim=(0,1))\n push!(plt_list, plt)\nend\nplot(plt_list...)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X, y; target=target, clim=(0,1), link_approx=:plugin)\n push!(plt_list, plt)\nend\nplot(plt_list...)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"(Image: )","category":"page"},{"location":"reference/","page":"Reference","title":"Reference","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"reference/#All-functions-and-types","page":"Reference","title":"All functions and types","text":"","category":"section"},{"location":"reference/","page":"Reference","title":"Reference","text":"","category":"page"},{"location":"reference/#Exported-functions","page":"Reference","title":"Exported functions","text":"","category":"section"},{"location":"reference/","page":"Reference","title":"Reference","text":"Modules = [\n LaplaceRedux,\n LaplaceRedux.Curvature,\n LaplaceRedux.Data,\n]\nPrivate = 
false","category":"page"},{"location":"reference/#LaplaceRedux.Laplace","page":"Reference","title":"LaplaceRedux.Laplace","text":"Laplace\n\nConcrete type for Laplace approximation. This type is a subtype of AbstractLaplace and is used to store all the necessary information for a Laplace approximation.\n\nFields\n\nmodel::Flux.Chain: The model to be approximated.\nlikelihood::Symbol: The likelihood function to be used.\nest_params::EstimationParams: The estimation parameters.\nprior::Prior: The parameters defining prior distribution.\nposterior::Posterior: The posterior distribution.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Laplace-Tuple{Any}","page":"Reference","title":"LaplaceRedux.Laplace","text":"Laplace(model::Any; likelihood::Symbol, kwargs...)\n\nOuter constructor for Laplace approximation. This function constructs a Laplace object from a given model and likelihood function.\n\nArguments\n\nmodel::Any: The model to be approximated (a Flux.Chain).\nlikelihood::Symbol: The likelihood function to be used. Possible values are :regression and :classification.\n\nKeyword Arguments\n\nSee LaplaceParams for a description of the keyword arguments.\n\nReturns\n\nla::Laplace: The Laplace object.\n\nExamples\n\nusing Flux, LaplaceRedux\nnn = Chain(Dense(2,1))\nla = Laplace(nn, likelihood=:regression)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.LaplaceClassification","page":"Reference","title":"LaplaceRedux.LaplaceClassification","text":"MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic\n\nA mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. 
\n\nThe model is defined by the following default parameters for all MLJFlux models:\n\nbuilder: a Flux model that constructs the neural network.\nfinaliser: a Flux model that processes the output of the neural network.\noptimiser: a Flux optimiser.\nloss: a loss function that takes the predicted output and the true output as arguments.\nepochs: the number of epochs.\nbatch_size: the size of a batch.\nlambda: the regularization strength.\nalpha: the regularization mix (0 for all l2, 1 for all l1).\nrng: a random number generator.\noptimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.\nacceleration: the computational resource to use.\n\nThe model also has the following parameters, which are specific to the Laplace approximation:\n\nsubset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.\nsubnetwork_indices: the indices of the subnetworks.\nhessian_structure: the structure of the Hessian matrix, either :full or :diagonal.\nbackend: the backend to use, either :GGN or :EmpiricalFisher.\nσ: the standard deviation of the prior distribution.\nμ₀: the mean of the prior distribution.\nP₀: the covariance matrix of the prior distribution.\nlink_approx: the link approximation to use, either :probit or :plugin.\npredict_proba: a boolean that select whether to predict probabilities or not.\nret_distr: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.\nfit_prior_nsteps: the number of steps used to fit the priors.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.LaplaceRegression","page":"Reference","title":"LaplaceRedux.LaplaceRegression","text":"MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic\n\nA mutable struct representing a Laplace regression model that extends the MLJFlux.MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. 
\n\nThe model is defined by the following default parameters for all MLJFlux models:\n\nbuilder: a Flux model that constructs the neural network.\noptimiser: a Flux optimiser.\nloss: a loss function that takes the predicted output and the true output as arguments.\nepochs: the number of epochs.\nbatch_size: the size of a batch.\nlambda: the regularization strength.\nalpha: the regularization mix (0 for all l2, 1 for all l1).\nrng: a random number generator.\noptimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.\nacceleration: the computational resource to use.\n\nThe model also has the following parameters, which are specific to the Laplace approximation:\n\nsubset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.\nsubnetwork_indices: the indices of the subnetworks.\nhessian_structure: the structure of the Hessian matrix, either :full or :diagonal.\nbackend: the backend to use, either :GGN or :EmpiricalFisher.\nσ: the standard deviation of the prior distribution.\nμ₀: the mean of the prior distribution.\nP₀: the covariance matrix of the prior distribution.\nret_distr: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.\nfit_prior_nsteps: the number of steps used to fit the priors.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.fit!-Tuple{LaplaceRedux.AbstractLaplace, Any}","page":"Reference","title":"LaplaceRedux.fit!","text":"fit!(la::AbstractLaplace,data)\n\nFits the Laplace approximation for a data set. The function returns the number of observations (n_data) that were used to update the Laplace object. It does not return the updated Laplace object itself because the function modifies the input Laplace object in place (as denoted by the use of '!' in the function's name).\n\nExamples\n\nusing Flux, LaplaceRedux\nx, y = LaplaceRedux.Data.toy_data_linear()\ndata = zip(x,y)\nnn = Chain(Dense(2,1))\nla = Laplace(nn)\nfit!(la, data)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.fit!-Tuple{LaplaceRedux.AbstractLaplace, MLUtils.DataLoader}","page":"Reference","title":"LaplaceRedux.fit!","text":"Fit the Laplace approximation, with batched data.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.glm_predictive_distribution-Tuple{LaplaceRedux.AbstractLaplace, AbstractArray}","page":"Reference","title":"LaplaceRedux.glm_predictive_distribution","text":"glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)\n\nComputes the linearized GLM predictive.\n\nArguments\n\nla::AbstractLaplace: A Laplace object.\nX::AbstractArray: Input data.\n\nReturns\n\nnormal_distr A normal distribution N(fμ,fvar) approximating the predictive distribution p(y|X) given the input data X.\nfμ::AbstractArray: Mean of the predictive distribution. The output shape is column-major as in Flux.\nfvar::AbstractArray: Variance of the predictive distribution. 
The output shape is column-major as in Flux.\n\nExamples\n\n```julia-repl using Flux, LaplaceRedux using LaplaceRedux.Data: toydatalinear x, y = toydatalinear() data = zip(x,y) nn = Chain(Dense(2,1)) la = Laplace(nn; likelihood=:classification) fit!(la, data) glmpredictivedistribution(la, hcat(x...))\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.optimize_prior!-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.optimize_prior!","text":"optimize_prior!(\n la::AbstractLaplace; \n n_steps::Int=100, lr::Real=1e-1,\n λinit::Union{Nothing,Real}=nothing,\n σinit::Union{Nothing,Real}=nothing\n)\n\nOptimize the prior precision post-hoc through Empirical Bayes (marginal log-likelihood maximization).\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.posterior_covariance","page":"Reference","title":"LaplaceRedux.posterior_covariance","text":"posterior_covariance(la::AbstractLaplace, P=la.P)\n\nComputes the posterior covariance as the inverse of the posterior precision: Sigma=P^-1.\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.posterior_precision","page":"Reference","title":"LaplaceRedux.posterior_precision","text":"posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀)\n\nComputes the posterior precision P for a fitted Laplace Approximation as follows,\n\nP = sum_n=1^Nnabla_theta^2 log p(mathcalD_ntheta)_hattheta + nabla_theta^2 log p(theta)_hattheta\n\nwhere sum_n=1^Nnabla_theta^2log p(mathcalD_ntheta)_hattheta=H is the Hessian and nabla_theta^2 log p(theta)_hattheta=P_0 is the prior precision and hattheta is the MAP estimate.\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.predict-Tuple{LaplaceRedux.AbstractLaplace, AbstractArray}","page":"Reference","title":"LaplaceRedux.predict","text":"predict(la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true)\n\nComputes predictions from Bayesian neural network.\n\nArguments\n\nla::AbstractLaplace: A Laplace object.\nX::AbstractArray: Input data.\nlink_approx::Symbol=:probit: Link function approximation. Options are :probit and :plugin.\npredict_proba::Bool=true: If true (default) apply a sigmoid or a softmax function to the output of the Flux model.\nreturn_distr::Bool=false: if false (default), the function output either the direct output of the chain or pseudo-probabilities (if predict_proba= true). if true predict return a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks.\n\nReturns\n\nFor classification tasks, LaplaceRedux provides different options: if retdistr is false: - fμ::AbstractArray: Mean of the predictive distribution if link function is set to :plugin, otherwise the probit approximation. The output shape is column-major as in Flux. if retdistr is true: - a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclassification tasks. For regression tasks:\n\nnormal_distr::Distributions.Normal:the array of Normal distributions computed by glmpredictivedistribution. 
\n\nExamples\n\nusing Flux, LaplaceRedux\nusing LaplaceRedux.Data: toy_data_linear\nx, y = toy_data_linear()\ndata = zip(x,y)\nnn = Chain(Dense(2,1))\nla = Laplace(nn; likelihood=:classification)\nfit!(la, data)\npredict(la, hcat(x...))\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJModelInterface.predict-Tuple{LaplaceClassification, Any, Any}","page":"Reference","title":"MLJModelInterface.predict","text":"predict(model::LaplaceClassification, Xnew)\n\nPredicts the class labels for new data using the LaplaceClassification model.\n\nArguments\n\nmodel::LaplaceClassification: The trained LaplaceClassification model.\nfitresult: the fitresult output produced by MLJFlux.fit!\nXnew: The new data to make predictions on.\n\nReturns\n\nAn array of predicted class labels.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJModelInterface.predict-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJModelInterface.predict","text":"predict(model::LaplaceRegression, Xnew)\n\nPredict the output for new input data using a Laplace regression model.\n\nArguments\n\nmodel::LaplaceRegression: The trained Laplace regression model.\nthe fitresult output produced by MLJFlux.fit!\nXnew: The new input data.\n\nReturns\n\nThe predicted output for the new input data.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.CurvatureInterface","page":"Reference","title":"LaplaceRedux.Curvature.CurvatureInterface","text":"Base type for any curvature interface.\n\n\n\n\n\n","category":"type"},{"location":"reference/#Internal-functions","page":"Reference","title":"Internal functions","text":"","category":"section"},{"location":"reference/","page":"Reference","title":"Reference","text":"Modules = [\n LaplaceRedux,\n LaplaceRedux.Curvature,\n LaplaceRedux.Data,\n]\nPublic = false","category":"page"},{"location":"reference/#LaplaceRedux.AbstractDecomposition","page":"Reference","title":"LaplaceRedux.AbstractDecomposition","text":"Abstract type of Hessian decompositions.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.AbstractLaplace","page":"Reference","title":"LaplaceRedux.AbstractLaplace","text":"Abstract base type for all Laplace approximations in this library. All subclasses implemented are parametric.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.AbstractLaplace-Tuple{AbstractArray}","page":"Reference","title":"LaplaceRedux.AbstractLaplace","text":"(la::AbstractLaplace)(X::AbstractArray)\n\nCalling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the predict function.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.EstimationParams","page":"Reference","title":"LaplaceRedux.EstimationParams","text":"EstimationParams\n\nContainer for the parameters of a Laplace approximation. \n\nFields\n\nsubset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.\nsubnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.\nhessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.\ncurvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. 
Possible values are nothing or a concrete subtype of CurvatureInterface.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.EstimationParams-Tuple{LaplaceRedux.LaplaceParams, Any, Symbol}","page":"Reference","title":"LaplaceRedux.EstimationParams","text":"EstimationParams(params::LaplaceParams)\n\nExtracts the estimation parameters from a LaplaceParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.FullHessian","page":"Reference","title":"LaplaceRedux.FullHessian","text":"Concrete type for full Hessian structure. This is the default structure.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.HessianStructure","page":"Reference","title":"LaplaceRedux.HessianStructure","text":"Abstract type for Hessian structure.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Kron","page":"Reference","title":"LaplaceRedux.Kron","text":"Kronecker-factored approximate curvature representation for a neural network model. Each element in kfacs represents two Kronecker factors (𝐆, 𝐀), such that the full block Hessian approximation would be approximated as 𝐀⊗𝐆.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.KronDecomposed","page":"Reference","title":"LaplaceRedux.KronDecomposed","text":"KronDecomposed\n\nDecomposed Kronecker-factored approximate curvature representation for a neural network model.\n\nDecomposition is required to add the prior (diagonal matrix) to the posterior (KronDecomposed). It also has the benefits of reducing the costs for computation of inverses and log-determinants.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.KronHessian","page":"Reference","title":"LaplaceRedux.KronHessian","text":"Concrete type for Kronecker-factored Hessian structure.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.LaplaceParams","page":"Reference","title":"LaplaceRedux.LaplaceParams","text":"LaplaceParams\n\nContainer for the parameters of a Laplace approximation.\n\nFields\n\nsubset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.\nsubnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.\nhessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.\nbackend::Symbol: the backend to use. Possible values are :GGN and :Fisher.\ncurvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. 
Possible values are nothing or a concrete subtype of CurvatureInterface.\nσ::Real: the observation noise\nμ₀::Real: the prior mean\nλ::Real: the prior precision\nP₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Posterior","page":"Reference","title":"LaplaceRedux.Posterior","text":"Posterior\n\nContainer for the results of a Laplace approximation.\n\nFields\n\nμ::AbstractVector: the MAP estimate of the parameters\nH::Union{AbstractArray,AbstractDecomposition,Nothing}: the Hessian matrix\nP::Union{AbstractArray,AbstractDecomposition,Nothing}: the posterior precision matrix\nΣ::Union{AbstractArray,Nothing}: the posterior covariance matrix\nn_data::Union{Int,Nothing}: the number of data points\nn_params::Union{Int,Nothing}: the number of parameters\nn_out::Union{Int,Nothing}: the number of outputs\nloss::Real: the loss value\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Posterior-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"LaplaceRedux.Posterior","text":"Posterior(model::Any, est_params::EstimationParams)\n\nOuter constructor for Posterior object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Prior","page":"Reference","title":"LaplaceRedux.Prior","text":"Prior\n\nContainer for the prior parameters of a Laplace approximation.\n\nFields\n\nσ::Real: the observation noise\nμ₀::Real: the prior mean\nλ::Real: the prior precision\nP₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Prior-Tuple{LaplaceRedux.LaplaceParams, Any, Symbol}","page":"Reference","title":"LaplaceRedux.Prior","text":"Prior(params::LaplaceParams)\n\nExtracts the prior parameters from a LaplaceParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:*-Tuple{LaplaceRedux.KronDecomposed, Number}","page":"Reference","title":"Base.:*","text":"Multiply by a scalar by changing the eigenvalues. 
Distribute the scalar along the factors of a block.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:*-Tuple{Real, LaplaceRedux.Kron}","page":"Reference","title":"Base.:*","text":"Kronecker-factored curvature scalar scaling.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:+-Tuple{LaplaceRedux.Kron, LaplaceRedux.Kron}","page":"Reference","title":"Base.:+","text":"Kronecker-factored curvature sum.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:+-Tuple{LaplaceRedux.KronDecomposed, LinearAlgebra.Diagonal}","page":"Reference","title":"Base.:+","text":"Shift the factors by a diagonal (assumed uniform scaling)\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:+-Tuple{LaplaceRedux.KronDecomposed, Number}","page":"Reference","title":"Base.:+","text":"Shift the factors by a scalar across the diagonal.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:==-Tuple{LaplaceRedux.Kron, LaplaceRedux.Kron}","page":"Reference","title":"Base.:==","text":"Kronecker-factored curvature equality.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.getindex-Tuple{LaplaceRedux.Kron, Int64}","page":"Reference","title":"Base.getindex","text":"Get Kronecker-factored block represenation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.getindex-Tuple{LaplaceRedux.KronDecomposed, Int64}","page":"Reference","title":"Base.getindex","text":"Get i-th block of a a Kronecker-factored curvature.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.length-Tuple{LaplaceRedux.KronDecomposed}","page":"Reference","title":"Base.length","text":"Number of blocks in a Kronecker-factored curvature.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Flux.params-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"Flux.params","text":"Flux.params(model::Any, params::EstimationParams)\n\nExtracts the parameters of a model based on the subset of weights specified in the EstimationParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Flux.params-Tuple{Laplace}","page":"Reference","title":"Flux.params","text":"Flux.params(la::Laplace)\n\nOverloads the params function for a Laplace object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._H_factor-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux._H_factor","text":"_H_factor(la::AbstractLaplace)\n\nReturns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._fit!-Tuple{Laplace, LaplaceRedux.FullHessian, Any}","page":"Reference","title":"LaplaceRedux._fit!","text":"_fit!(la::Laplace, hessian_structure::FullHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)\n\nFit a Laplace approximation to the posterior distribution of a model using the full Hessian.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._fit!-Tuple{Laplace, LaplaceRedux.KronHessian, Any}","page":"Reference","title":"LaplaceRedux._fit!","text":"_fit!(la::Laplace, hessian_structure::KronHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)\n\nFit a Laplace approximation to the posterior distribution of a model using the Kronecker-factored 
Hessian.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._init_H-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux._init_H","text":"_init_H(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._weight_penalty-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux._weight_penalty","text":"_weight_penalty(la::AbstractLaplace)\n\nThe weight penalty term is a regularization term used to prevent overfitting. Weight regularization methods such as weight decay introduce a penalty to the loss function when training a neural network to encourage the network to use small weights. Smaller weights in a neural network can result in a model that is more stable and less likely to overfit the training dataset, in turn having better performance when making a prediction on new data.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.approximate-Tuple{LaplaceRedux.Curvature.CurvatureInterface, LaplaceRedux.FullHessian, Tuple}","page":"Reference","title":"LaplaceRedux.approximate","text":"approximate(curvature::CurvatureInterface, hessian_structure::FullHessian, d::Tuple; batched::Bool=false)\n\nCompute the full approximation, for either a single input-output datapoint or a batch of such. \n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.approximate-Tuple{LaplaceRedux.Curvature.CurvatureInterface, LaplaceRedux.KronHessian, Any}","page":"Reference","title":"LaplaceRedux.approximate","text":"approximate(curvature::CurvatureInterface, hessian_structure::KronHessian, data; batched::Bool=false)\n\nCompute the eigendecomposed Kronecker-factored approximate curvature as the Fisher information matrix.\n\nNote, since the network predictive distribution is used in a weighted sum, and the number of backward passes is linear in the number of target classes, e.g. 100 for CIFAR-100.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.clamp-Tuple{LinearAlgebra.Eigen}","page":"Reference","title":"LaplaceRedux.clamp","text":"Clamp eigenvalues in an eigendecomposition to be non-negative.\n\nSince the Fisher information matrix is a positive-semidefinite by construction, the (near-zero) negative eigenvalues should be neglected.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.convert_subnetwork_indices-Tuple{Vector{Vector{Int64}}, AbstractArray}","page":"Reference","title":"LaplaceRedux.convert_subnetwork_indices","text":"convertsubnetworkindices(subnetwork_indices::AbstractArray)\n\nConverts the subnetwork indices from the user given format [theta, row, column] to an Int i that corresponds to the index of that weight in the flattened array of weights.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.decompose-Tuple{LaplaceRedux.Kron}","page":"Reference","title":"LaplaceRedux.decompose","text":"decompose(K::Kron)\n\nEigendecompose Kronecker factors and turn into KronDecomposed.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.functional_variance-Tuple{Any, Any}","page":"Reference","title":"LaplaceRedux.functional_variance","text":"functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)\n\nCompute the functional variance for the GLM predictive. 
Dispatches to the appropriate method based on the Hessian structure.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.functional_variance-Tuple{Laplace, LaplaceRedux.FullHessian, Any}","page":"Reference","title":"LaplaceRedux.functional_variance","text":"functional_variance(la::Laplace, 𝐉)\n\nCompute the linearized GLM predictive variance as 𝐉ₙΣ𝐉ₙ' where 𝐉=∇f(x;θ)|θ̂ is the Jacobian evaluated at the MAP estimate and Σ = P⁻¹.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.functional_variance-Tuple{Laplace, LaplaceRedux.KronHessian, Matrix}","page":"Reference","title":"LaplaceRedux.functional_variance","text":"functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)\n\nCompute the functional variance for the GLM predictive as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ, where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of the model output, 𝐉=∇f(x;θ)|θ̂.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_loss_fun-Tuple{Symbol, Flux.Chain}","page":"Reference","title":"LaplaceRedux.get_loss_fun","text":"get_loss_fun(likelihood::Symbol)\n\nHelper function to choose loss function based on specified model likelihood.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_loss_type-Tuple{Symbol, Flux.Chain}","page":"Reference","title":"LaplaceRedux.get_loss_type","text":"get_loss_type(likelihood::Symbol)\n\nChoose loss function type based on specified model likelihood.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_map_estimate-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"LaplaceRedux.get_map_estimate","text":"get_map_estimate(model::Any, est_params::EstimationParams)\n\nHelper function to extract the MAP estimate of the parameters for the model based on the subset of weights specified in the EstimationParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_prior_mean-Tuple{Laplace}","page":"Reference","title":"LaplaceRedux.get_prior_mean","text":"get_prior_mean(la::Laplace)\n\nHelper function to extract the prior mean of the parameters from a Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.has_softmax_or_sigmoid_final_layer-Tuple{Flux.Chain}","page":"Reference","title":"LaplaceRedux.has_softmax_or_sigmoid_final_layer","text":"has_softmax_or_sigmoid_final_layer(model::Flux.Chain)\n\nCheck if the Flux model ends with a sigmoid or a softmax layer.\n\nInput: - model: the Flux Chain object that represents the neural network. Return: - has_finaliser: true if the check is positive, false otherwise.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.hessian_approximation-Tuple{LaplaceRedux.AbstractLaplace, Any}","page":"Reference","title":"LaplaceRedux.hessian_approximation","text":"hessian_approximation(la::AbstractLaplace, d; batched::Bool=false)\n\nComputes the local Hessian approximation at a single datapoint d.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.instantiate_curvature!-Tuple{LaplaceRedux.EstimationParams, Any, Symbol, Symbol}","page":"Reference","title":"LaplaceRedux.instantiate_curvature!","text":"instantiate_curvature!(params::EstimationParams, model::Any, likelihood::Symbol, backend::Symbol)\n\nInstantiates the curvature interface for a Laplace approximation. 
The curvature interface is a concrete subtype of CurvatureInterface and is used to compute the Hessian matrix. The curvature interface is stored in the curvature field of the EstimationParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.interleave-Tuple","page":"Reference","title":"LaplaceRedux.interleave","text":"Interleave elements of multiple iterables in the order provided.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.inv_square_form-Tuple{LaplaceRedux.KronDecomposed, Matrix}","page":"Reference","title":"LaplaceRedux.inv_square_form","text":"inv_square_form(K::KronDecomposed, W::Matrix)\n\nSpecial function to compute the inverse square form 𝐉𝐏⁻¹𝐉ᵀ (or 𝐖𝐊⁻¹𝐖ᵀ).\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_det_posterior_precision-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_det_posterior_precision","text":"log_det_posterior_precision(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_det_prior_precision-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_det_prior_precision","text":"log_det_prior_precision(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_det_ratio-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_det_ratio","text":"log_det_ratio(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_likelihood-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_likelihood","text":"log_likelihood(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_marginal_likelihood-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_marginal_likelihood","text":"log_marginal_likelihood(la::AbstractLaplace; P₀::Union{Nothing,UniformScaling}=nothing, σ::Union{Nothing, Real}=nothing)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.logdetblock-Tuple{Tuple{LinearAlgebra.Eigen, LinearAlgebra.Eigen}, Number}","page":"Reference","title":"LaplaceRedux.logdetblock","text":"logdetblock(block::Tuple{Eigen,Eigen}, delta::Number)\n\nLog-determinant of a block in KronDecomposed, shifted by delta on the diagonal.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.mm-Tuple{LaplaceRedux.KronDecomposed, Any}","page":"Reference","title":"LaplaceRedux.mm","text":"Matrix-multiply for the KronDecomposed Hessian approximation K and a 2-d matrix W, applying an exponent to K and transposing W before multiplication. 
Return (K^x)W^T, where x is the exponent.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.n_params-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"LaplaceRedux.n_params","text":"n_params(model::Any, params::EstimationParams)\n\nHelper function to determine the number of parameters of a Flux.Chain with Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.n_params-Tuple{Laplace}","page":"Reference","title":"LaplaceRedux.n_params","text":"LaplaceRedux.n_params(la::Laplace)\n\nOverloads the n_params function for a Laplace object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.outdim-Tuple{Flux.Chain}","page":"Reference","title":"LaplaceRedux.outdim","text":"outdim(model::Chain)\n\nHelper function to determine the output dimension of a Flux.Chain, corresponding to the number of neurons on the last layer of the NN.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.outdim-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.outdim","text":"outdim(la::AbstractLaplace)\n\nHelper function to determine the output dimension, corresponding to the number of neurons on the last layer of the NN, of a Flux.Chain with Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.prior_precision-Tuple{Laplace}","page":"Reference","title":"LaplaceRedux.prior_precision","text":"prior_precision(la::Laplace)\n\nHelper function to extract the prior precision matrix from a Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.probit-Tuple{AbstractArray, AbstractArray}","page":"Reference","title":"LaplaceRedux.probit","text":"probit(fμ::AbstractArray, fvar::AbstractArray)\n\nCompute the probit approximation of the predictive distribution.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.validate_subnetwork_indices-Tuple{Union{Nothing, Vector{Vector{Int64}}}, Any}","page":"Reference","title":"LaplaceRedux.validate_subnetwork_indices","text":"validate_subnetwork_indices(subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}, params)\n\nDetermines whether subnetwork_indices is a valid input for the specified parameters.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LinearAlgebra.det-Tuple{LaplaceRedux.KronDecomposed}","page":"Reference","title":"LinearAlgebra.det","text":"det(K::KronDecomposed)\n\nDeterminant of the KronDecomposed block-diagonal matrix, computed as the exponentiated log-determinant.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LinearAlgebra.logdet-Tuple{LaplaceRedux.KronDecomposed}","page":"Reference","title":"LinearAlgebra.logdet","text":"logdet(K::KronDecomposed)\n\nLog-determinant of the KronDecomposed block-diagonal matrix, computed from the log-determinants of the individual blocks.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.build-Tuple{LaplaceClassification, Any, Any}","page":"Reference","title":"MLJFlux.build","text":"MLJFlux.build(model::LaplaceClassification, rng, shape)\n\nBuilds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by shape.\n\nArguments\n\nmodel::LaplaceClassification: The Laplace classification model.\nrng: A random number generator to ensure reproducibility.\nshape: A tuple or array specifying the dimensions of the input and output layers.\n\nReturns\n\nThe constructed MLJFlux model, compatible with the specified input and output 
dimensions.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.build-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJFlux.build","text":"MLJFlux.build(model::LaplaceRegression, rng, shape)\n\nBuilds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by shape.\n\nArguments\n\nmodel::LaplaceRegression: The Laplace regression model.\nrng: A random number generator to ensure reproducibility.\nshape: A tuple or array specifying the dimensions of the input and output layers.\n\nReturns\n\nThe constructed MLJFlux model, compatible with the specified input and output dimensions.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.fitresult-Tuple{LaplaceClassification, Any, Any}","page":"Reference","title":"MLJFlux.fitresult","text":"MLJFlux.fitresult(model::LaplaceClassification, chain, y)\n\nComputes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data.\n\nArguments\n\nmodel::LaplaceClassification: The Laplace classification model to be evaluated.\nchain: The trained model chain.\ny: The target data, typically a vector of class labels.\n\nReturns\n\nA tuple containing:\n\nThe trained Flux chain.\nA deepcopy of the Laplace model.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.fitresult-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJFlux.fitresult","text":"MLJFlux.fitresult(model::LaplaceRegression, chain, y)\n\nComputes the fit result for a Laplace regression model, returning the model chain and the number of output variables in the target data.\n\nArguments\n\nmodel::LaplaceRegression: The Laplace regression model to be evaluated.\nchain: The trained model chain.\ny: The target data, typically a vector of class labels.\n\nReturns\n\nA tuple containing:\n\nThe trained Flux chain.\nA deepcopy of the Laplace model.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.shape-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJFlux.shape","text":"MLJFlux.shape(model::LaplaceRegression, X, y)\n\nCompute the number of features of the X input dataset and the number of variables to predict from the y output dataset.\n\nArguments\n\nmodel::LaplaceRegression: The LaplaceRegression model to fit.\nX: The input data for training.\ny: The one-hot encoded target labels for training.\n\nReturns\n\n(input size, output size)\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.train-Tuple{LaplaceClassification, Vararg{Any, 7}}","page":"Reference","title":"MLJFlux.train","text":"MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)\n\nFit the LaplaceClassification model using Flux.jl.\n\nArguments\n\nmodel::LaplaceClassification: The LaplaceClassification model.\nregularized_optimiser: the regularized optimiser to apply to the loss function.\noptimiser_state: the state of the optimiser.\nepochs: The number of epochs for training.\nverbosity: The verbosity level for training.\nX: The input data for training.\ny: The target labels for training.\n\nReturns (la, optimiser_state, history)\n\nwhere\n\nla: the fitted Laplace model.\noptimiser_state: the state of the optimiser.\nhistory: the training loss history.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.train-Tuple{LaplaceRegression, Vararg{Any, 
7}}","page":"Reference","title":"MLJFlux.train","text":"MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)\n\nFit the LaplaceRegression model using Flux.jl.\n\nArguments\n\nmodel::LaplaceRegression: The LaplaceRegression model.\nregularized_optimiser: the regularized optimiser to apply to the loss function.\noptimiser_state: thestate of the optimiser.\nepochs: The number of epochs for training.\nverbosity: The verbosity level for training.\nX: The input data for training.\ny: The target labels for training.\n\nReturns (la, optimiser_state, history )\n\nwhere\n\nla: the fitted Laplace model.\noptimiser_state: the state of the optimiser.\nhistory: the training loss history.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.@zb-Tuple{Any}","page":"Reference","title":"LaplaceRedux.@zb","text":"Macro for zero-based indexing. Example of usage: (@zb A[0]) = ...\n\n\n\n\n\n","category":"macro"},{"location":"reference/#LaplaceRedux.Curvature.EmpiricalFisher","page":"Reference","title":"LaplaceRedux.Curvature.EmpiricalFisher","text":"Constructor for curvature approximated by empirical Fisher.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Curvature.GGN","page":"Reference","title":"LaplaceRedux.Curvature.GGN","text":"Constructor for curvature approximated by Generalized Gauss-Newton.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Curvature.full_batched-Tuple{LaplaceRedux.Curvature.EmpiricalFisher, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_batched","text":"full_batched(curvature::EmpiricalFisher, d::Tuple)\n\nCompute the full empirical Fisher for batch of inputs-outputs, with the batch dimension at the end.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.full_batched-Tuple{LaplaceRedux.Curvature.GGN, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_batched","text":"full_batched(curvature::GGN, d::Tuple)\n\nCompute the full GGN for batch of inputs-outputs, with the batch dimension at the end.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.full_unbatched-Tuple{LaplaceRedux.Curvature.EmpiricalFisher, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_unbatched","text":"full_unbatched(curvature::EmpiricalFisher, d::Tuple)\n\nCompute the full empirical Fisher for a single datapoint.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.full_unbatched-Tuple{LaplaceRedux.Curvature.GGN, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_unbatched","text":"full_unbatched(curvature::GGN, d::Tuple)\n\nCompute the full GGN for a singular input-ouput datapoint. 
\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.gradients-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray, Union{Number, AbstractArray}}","page":"Reference","title":"LaplaceRedux.Curvature.gradients","text":"gradients(curvature::CurvatureInterface, X::AbstractArray, y::Number)\n\nCompute the gradients with respect to the loss function: ∇ℓ(f(x;θ),y) where f: ℝᴰ ↦ ℝᴷ.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.jacobians-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray}","page":"Reference","title":"LaplaceRedux.Curvature.jacobians","text":"jacobians(curvature::CurvatureInterface, X::AbstractArray; batched::Bool=false)\n\nComputes the Jacobian ∇f(x;θ) where f: ℝᴰ ↦ ℝᴷ.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.jacobians_batched-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray}","page":"Reference","title":"LaplaceRedux.Curvature.jacobians_batched","text":"jacobians_batched(curvature::CurvatureInterface, X::AbstractArray)\n\nCompute Jacobians of the model output w.r.t. model parameters for points in X, with batching.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.jacobians_unbatched-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray}","page":"Reference","title":"LaplaceRedux.Curvature.jacobians_unbatched","text":"jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)\n\nCompute the Jacobian of the model output w.r.t. model parameters for the point X, without batching. Here, the nn function is wrapped in an anonymous function using the () -> syntax, which allows it to be differentiated using automatic differentiation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Data.toy_data_linear","page":"Reference","title":"LaplaceRedux.Data.toy_data_linear","text":"toy_data_linear(N=100)\n\nExamples\n\ntoy_data_linear()\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.Data.toy_data_multi","page":"Reference","title":"LaplaceRedux.Data.toy_data_multi","text":"toy_data_multi(N=100)\n\nExamples\n\ntoy_data_multi()\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.Data.toy_data_non_linear","page":"Reference","title":"LaplaceRedux.Data.toy_data_non_linear","text":"toy_data_non_linear(N=100)\n\nExamples\n\ntoy_data_non_linear()\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.Data.toy_data_regression","page":"Reference","title":"LaplaceRedux.Data.toy_data_regression","text":"toy_data_regression(N=25, p=1; noise=0.3, fun::Function=f(x)=sin(2 * π * x))\n\nA helper function to generate synthetic data for regression.\n\n\n\n\n\n","category":"function"},{"location":"","page":"Home","title":"Home","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"","page":"Home","title":"Home","text":"(Image: )","category":"page"},{"location":"","page":"Home","title":"Home","text":"Documentation for LaplaceRedux.jl.","category":"page"},{"location":"#LaplaceRedux","page":"Home","title":"LaplaceRedux","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"LaplaceRedux.jl is a library written in pure Julia that can be used for effortless Bayesian Deep Learning through Laplace Approximation (LA). In the development of this package I have drawn inspiration from this Python library and its companion paper (Daxberger et al. 
2021).","category":"page"},{"location":"#Installation","page":"Home","title":"🚩 Installation","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"The stable version of this package can be installed as follows:","category":"page"},{"location":"","page":"Home","title":"Home","text":"using Pkg\nPkg.add(\"LaplaceRedux.jl\")","category":"page"},{"location":"","page":"Home","title":"Home","text":"The development version can be installed like so:","category":"page"},{"location":"","page":"Home","title":"Home","text":"using Pkg\nPkg.add(\"https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl\")","category":"page"},{"location":"#Getting-Started","page":"Home","title":"🏃 Getting Started","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"If you are new to Deep Learning in Julia or simply prefer learning through videos, check out this awesome YouTube tutorial by doggo.jl 🐶. Additionally, you can also find a video of my presentation at JuliaCon 2022 on YouTube.","category":"page"},{"location":"#Basic-Usage","page":"Home","title":"🖥️ Basic Usage","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"LaplaceRedux.jl can be used for any neural network trained in Flux.jl. Below we show basic usage examples involving two simple models for a regression and a classification task, respectively.","category":"page"},{"location":"#Regression","page":"Home","title":"Regression","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"A complete worked example for a regression model can be found in the docs. Here we jump straight to Laplace Approximation and take the pre-trained model nn as given. Then LA can be implemented as follows, where we specify the model likelihood. The plot shows the fitted values overlaid with a 95% confidence interval. As expected, predictive uncertainty quickly increases in areas that are not populated by any training data.","category":"page"},{"location":"","page":"Home","title":"Home","text":"la = Laplace(nn; likelihood=:regression)\nfit!(la, data)\noptimize_prior!(la)\nplot(la, X, y; zoom=-5, size=(500,500))","category":"page"},{"location":"","page":"Home","title":"Home","text":"(Image: )","category":"page"},{"location":"#Binary-Classification","page":"Home","title":"Binary Classification","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"Once again we jump straight to LA and refer to the docs for a complete worked example involving binary classification. In this case we need to specify likelihood=:classification. 
The plot below shows the resulting posterior predictive distributions as contours in the two-dimensional feature space: note how the Plugin Approximation on the left compares to the Laplace Approximation on the right.","category":"page"},{"location":"","page":"Home","title":"Home","text":"la = Laplace(nn; likelihood=:classification)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; n_steps=100)\n\n# Plot the posterior predictive distribution:\nzoom=0\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"","page":"Home","title":"Home","text":"(Image: )","category":"page"},{"location":"#JuliaCon-2022","page":"Home","title":"📢 JuliaCon 2022","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"This project was presented at JuliaCon 2022 in July 2022. See here for details.","category":"page"},{"location":"#Contribute","page":"Home","title":"🛠️ Contribute","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"Contributions are very much welcome! Please follow the SciML ColPrac guide. You may want to start by having a look at any open issues.","category":"page"},{"location":"#References","page":"Home","title":"🎓 References","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2021. 
“Laplace Redux-Effortless Bayesian Deep Learning.” Advances in Neural Information Processing Systems 34.","category":"page"},{"location":"resources/_resources/#Additional-Resources","page":"Additional Resources","title":"Additional Resources","text":"","category":"section"},{"location":"resources/_resources/#JuliaCon-2022","page":"Additional Resources","title":"JuliaCon 2022","text":"","category":"section"},{"location":"resources/_resources/","page":"Additional Resources","title":"Additional Resources","text":"Slides: link","category":"page"},{"location":"resources/_resources/","page":"Additional Resources","title":"Additional Resources","text":"","category":"page"}] +[{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/mlp/#Bayesian-MLP","page":"MLP Binary Classifier","title":"Bayesian MLP","text":"","category":"section"},{"location":"tutorials/mlp/#Libraries","page":"MLP Binary Classifier","title":"Libraries","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)","category":"page"},{"location":"tutorials/mlp/#Data","page":"MLP Binary Classifier","title":"Data","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"This time we use a synthetic dataset containing samples that are not linearly separable:","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"#set seed\nseed = 1234\nRandom.seed!(seed)\n# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_non_linear(400; seed = seed)\n# Shuffle the data\nn = length(ys)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nxs_train = xs[train_indices]\nxs_test = xs[test_indices]\nys_train = ys[train_indices]\nys_test = ys[test_indices]\n# bring into tabular format\nX_train = hcat(xs_train...) \nX_test = hcat(xs_test...) 
\n\ndata = zip(xs_train,ys_train)","category":"page"},{"location":"tutorials/mlp/#Model","page":"MLP Binary Classifier","title":"Model","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"For the classification task we build a neural network with weight decay composed of a single hidden layer.","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"n_hidden = 10\nD = size(X_train,1)\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) ","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"The model is trained until training loss stagnates.","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"using Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend","category":"page"},{"location":"tutorials/mlp/#Laplace-Approximation","page":"MLP Binary Classifier","title":"Laplace Approximation","text":"","category":"section"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"Laplace approximation can be implemented as follows:","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"la = Laplace(nn; likelihood=:classification, subset_of_weights=:all)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbose=true, n_steps=500)","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"# Plot the posterior distribution with a contour plot.\nzoom=0\np_plugin = plot(la, X_train, ys_train; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X_train, ys_train; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X_train, ys_train; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"Zooming out we can note that the plugin estimator produces high-confidence estimates in regions scarce of any samples. 
The Laplace approximation is much more conservative about these regions.","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"zoom=-50\np_plugin = plot(la, X_train, ys_train; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X_train, ys_train; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X_train, ys_train; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"We plot now the calibration plot to assess the level of average calibration reached by the neural network.","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"predicted_distributions= predict(la, X_test,ret_distr=true)\nCalibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"and the sharpness score","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"sharpness_classification(ys_test,vec(predicted_distributions))","category":"page"},{"location":"tutorials/mlp/","page":"MLP Binary Classifier","title":"MLP Binary Classifier","text":"(0.9277189055456709, 0.9196132560599691)","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/logit/#Bayesian-Logistic-Regression","page":"Logistic Regression","title":"Bayesian Logistic Regression","text":"","category":"section"},{"location":"tutorials/logit/#Libraries","page":"Logistic Regression","title":"Libraries","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra\ntheme(:lime)","category":"page"},{"location":"tutorials/logit/#Data","page":"Logistic Regression","title":"Data","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"We will use synthetic data with linearly separable samples:","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"# set seed\nseed= 1234\nRandom.seed!(seed)\n# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_linear(100; seed=seed)\nX = hcat(xs...) 
# bring into tabular format","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"split in a training and test set","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"# Shuffle the data\nn = length(ys)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nxs_train = xs[train_indices]\nxs_test = xs[test_indices]\nys_train = ys[train_indices]\nys_test = ys[test_indices]\n# bring into tabular format\nX_train = hcat(xs_train...) \nX_test = hcat(xs_test...) \n\ndata = zip(xs_train,ys_train)","category":"page"},{"location":"tutorials/logit/#Model","page":"Logistic Regression","title":"Model","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"Logistic regression with weight decay can be implemented in Flux.jl as a single dense (linear) layer with binary logit crossentropy loss:","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"nn = Chain(Dense(2,1))\nλ = 0.5\nsqnorm(x) = sum(abs2, x)\nweight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))\nloss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) + weight_regularization()","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"The code below simply trains the model. After about 50 training epochs training loss stagnates.","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"using Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 50\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend","category":"page"},{"location":"tutorials/logit/#Laplace-approximation","page":"Logistic Regression","title":"Laplace approximation","text":"","category":"section"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"Laplace approximation for the posterior predictive can be implemented as follows:","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"la = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; verbose=true, n_steps=500)","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"zoom = 0\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned 
(λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"(Image: )","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"Now we can test the level of calibration of the neural network. First we collect the predicted results over the test dataset","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":" predicted_distributions= predict(la, X_test,ret_distr=true)","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"1×20 Matrix{Distributions.Bernoulli{Float64}}:\n Distributions.Bernoulli{Float64}(p=0.13122) … Distributions.Bernoulli{Float64}(p=0.109559)","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"then we plot the calibration plot","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"Calibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"(Image: )","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"as we can see from the plot, although extremely accurate, the neural network does not seem to be calibrated well. This is, however, an effect of the extreme accuracy reached by the neural network which causes the lack of predictions with high uncertainty (low certainty). 
We can see this by looking at the level of sharpness for the two classes which are extremely close to 1, indicating the high level of trust that the neural network has in the predictions.","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"sharpness_classification(ys_test,vec(predicted_distributions))","category":"page"},{"location":"tutorials/logit/","page":"Logistic Regression","title":"Logistic Regression","text":"(0.9131870336577175, 0.8865055827351365)","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/prior/#Libraries","page":"A note on the prior ...","title":"Libraries","text":"","category":"section"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"note: In Progress\n","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"    This documentation is still incomplete.","category":"page"},{"location":"tutorials/prior/#A-quick-note-on-the-prior","page":"A note on the prior ...","title":"A quick note on the prior","text":"","category":"section"},{"location":"tutorials/prior/#General-Effect","page":"A note on the prior ...","title":"General Effect","text":"","category":"section"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"High prior precision rightarrow only observation noise. Low prior precision rightarrow high posterior uncertainty.","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"using LaplaceRedux.Data\nn = 150 # number of observations\nσtrue = 0.30 # true observational noise\nx, y = Data.toy_data_regression(n;noise=σtrue)\nxs = [[x] for x in x]\nX = permutedims(x)","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"data = zip(xs,y)\nn_hidden = 10\nD = size(X,1)\nΛ = [1e5, nothing, 1e-5]\nplts = []\nnns = []\nopt=Flux.Adam(1e-3)\nfor λ ∈ Λ\n nn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n ) \n loss(x, y) = Flux.Losses.mse(nn(x), y)\n # train\n epochs = 1000\n for epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n Flux.update!(opt, Flux.params(nn), gs)\n end\n end\n # laplace\n if !isnothing(λ)\n la = Laplace(nn; likelihood=:regression, λ=λ)\n fit!(la, data) \n else\n la = Laplace(nn; likelihood=:regression)\n fit!(la, data) \n optimize_prior!(la)\n end\n \n _suffix = isnothing(λ) ? 
\" (optimal)\" : \"\"\n λ = unique(diag(la.prior.P₀))[1]\n title = \"λ=$(round(λ,digits=2))$(_suffix)\"\n\n # plot \n plt = plot(la, X, y; title=title, zoom=-5)\n plts = vcat(plts..., plt)\n nns = vcat(nns..., nn)\nend\nplot(plts..., layout=(1,3), size=(1200,300))","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"(Image: )","category":"page"},{"location":"tutorials/prior/#Effect-of-Model-Size-on-Optimal-Choice","page":"A note on the prior ...","title":"Effect of Model Size on Optimal Choice","text":"","category":"section"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"For larger models, the optimal prior precision lambda as evaluated through Empirical Bayes tends to be smaller.","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"data = zip(xs,y)\nn_hiddens = [5, 10, 50]\nD = size(X,1)\nplts = []\nnns = []\nopt=Flux.Adam(1e-3)\nfor n_hidden ∈ n_hiddens\n nn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n ) \n loss(x, y) = Flux.Losses.mse(nn(x), y)\n # train\n epochs = 1000\n for epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n Flux.update!(opt, Flux.params(nn), gs)\n end\n end\n # laplace\n la = Laplace(nn; likelihood=:regression)\n fit!(la, data) \n optimize_prior!(la)\n \n λ = unique(diag(la.prior.P₀))[1]\n title = \"n_params=$(LaplaceRedux.n_params(la)),λ=$(round(λ,digits=2))\"\n\n # plot \n plt = plot(la, X, y; title=title, zoom=-5)\n plts = vcat(plts..., plt)\n nns = vcat(nns..., nn)\nend\nplot(plts..., layout=(1,3), size=(1200,300))","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"(Image: )","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"# Number of points to generate.\nxs, ys = LaplaceRedux.Data.toy_data_non_linear(200)\nX = hcat(xs...) 
# bring into tabular format\ndata = zip(xs,ys)\n\nn_hiddens = [5, 10, 50]\nD = size(X,1)\nplts = []\nnns = []\nopt=Flux.Adam(1e-3)\nfor n_hidden ∈ n_hiddens\n nn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, 1)\n ) \n loss(x, y) = Flux.Losses.mse(nn(x), y)\n # train\n epochs = 100\n for epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n Flux.update!(opt, Flux.params(nn), gs)\n end\n end\n # laplace\n la = Laplace(nn; likelihood=:classification)\n fit!(la, data) \n optimize_prior!(la)\n \n λ = unique(diag(la.prior.P₀))[1]\n title = \"n_params=$(LaplaceRedux.n_params(la)),λ=$(round(λ,digits=2))\"\n\n # plot \n plt = plot(la, X, ys; title=title, zoom=-1, clim=(0,1))\n plts = vcat(plts..., plt)\n nns = vcat(nns..., nn)\nend\nplot(plts..., layout=(1,3), size=(1200,300))","category":"page"},{"location":"tutorials/prior/","page":"A note on the prior ...","title":"A note on the prior ...","text":"(Image: )","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/regression/#Libraries","page":"MLP Regression","title":"Libraries","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Import the libraries required to run this example","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:wong)","category":"page"},{"location":"tutorials/regression/#Data","page":"MLP Regression","title":"Data","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"We first generate some synthetic data:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"using LaplaceRedux.Data\nn = 3000 # number of observations\nσtrue = 0.30 # true observational noise\nx, y = Data.toy_data_regression(n;noise=σtrue,seed=1234)\nxs = [[x] for x in x]\nX = permutedims(x)","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"and split them in a training set and a test set","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"# Shuffle the data\nRandom.seed!(1234) # Set a seed for reproducibility\nshuffle_indices = shuffle(1:n)\n\n# Define split ratios\ntrain_ratio = 0.8\ntest_ratio = 0.2\n\n# Calculate split indices\ntrain_end = Int(floor(train_ratio * n))\n\n# Split the data\ntrain_indices = shuffle_indices[1:train_end]\ntest_indices = shuffle_indices[train_end+1:end]\n\n# Create the splits\nx_train, y_train = x[train_indices], y[train_indices]\nx_test, y_test = x[test_indices], y[test_indices]\n\n# Optional: Convert to desired format\nxs_train = [[x] for x in x_train]\nxs_test = [[x] for x in x_test]\nX_train = permutedims(x_train)\nX_test = permutedims(x_test)","category":"page"},{"location":"tutorials/regression/#MLP","page":"MLP Regression","title":"MLP","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"We set up a model and loss with weight regularization:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"train_data = 
zip(xs_train,y_train)\nn_hidden = 50\nD = size(X,1)\nnn = Chain(\n Dense(D, n_hidden, tanh),\n Dense(n_hidden, 1)\n) \nloss(x, y) = Flux.Losses.mse(nn(x), y)","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"We train the model:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"using Flux.Optimise: update!, Adam\nopt = Adam(1e-3)\nepochs = 1000\navg_loss(train_data) = mean(map(d -> loss(d[1],d[2]), train_data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in train_data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(train_data)\n end\nend","category":"page"},{"location":"tutorials/regression/#Laplace-Approximation","page":"MLP Regression","title":"Laplace Approximation","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Laplace approximation can be implemented as follows:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"subset_w = :all\nla = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)\nfit!(la, train_data)\nplot(la, X_train, y_train; zoom=-5, size=(400,400))","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"(Image: )","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Next we optimize the prior precision P_0 and and observational noise sigma using Empirical Bayes:","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"optimize_prior!(la; verbose=true)\nplot(la, X_train, y_train; zoom=-5, size=(400,400))","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"loss(exp.(logP₀), exp.(logσ)) = 668.3714946472106\nLog likelihood: -618.5175117610522\nLog det ratio: 68.76532606873238\nScatter: 30.942639703584522\nloss(exp.(logP₀), exp.(logσ)) = 719.2536119935747\nLog likelihood: -673.0996963447847\nLog det ratio: 76.53255037599948\nScatter: 15.775280921580569\nloss(exp.(logP₀), exp.(logσ)) = 574.605864472924\nLog likelihood: -528.694286608232\n\n\nLog det ratio: 80.73114330857285\nScatter: 11.092012420811196\nloss(exp.(logP₀), exp.(logσ)) = 568.4433850825203\nLog likelihood: -522.4407550111031\nLog det ratio: 82.10089958560243\nScatter: 9.90436055723207\n\n\nloss(exp.(logP₀), exp.(logσ)) = 566.9485255672008\nLog likelihood: -520.9682443835385\nLog det ratio: 81.84516297272847\nScatter: 10.11539939459612\nloss(exp.(logP₀), exp.(logσ)) = 559.9852101992792\nLog likelihood: -514.0625630685765\nLog det ratio: 80.97813304453496\nScatter: 10.867161216870441\n\nloss(exp.(logP₀), exp.(logσ)) = 559.1404593114019\nLog likelihood: -513.2449017869876\nLog det ratio: 80.16026747795866\nScatter: 11.630847570869795\nloss(exp.(logP₀), exp.(logσ)) = 559.3201392562346\nLog likelihood: -513.4273312363501\nLog det ratio: 79.68892769076004\nScatter: 12.096688349008877\n\n\nloss(exp.(logP₀), exp.(logσ)) = 559.2111983983311\nLog likelihood: -513.3174948065804\nLog det ratio: 79.56631681347287\nScatter: 12.2210903700287\nloss(exp.(logP₀), exp.(logσ)) = 559.1107459310829\nLog likelihood: -513.2176579845662\nLog det ratio: 79.63946732368183\nScatter: 
12.146708569351494","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"(Image: )","category":"page"},{"location":"tutorials/regression/#Calibration-Plot","page":"MLP Regression","title":"Calibration Plot","text":"","category":"section"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Once the prior precision has been optimized it is possible to evaluate the quality of the predictive distribution obtained through a calibration plot and a test dataset (ytest, Xtest).","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"First, we apply the trained network on the test dataset (ytest, Xtest) and collect the neural network’s predicted distributions","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"predicted_distributions= predict(la, X_test,ret_distr=true)","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"600×1 Matrix{Distributions.Normal{Float64}}:\n Distributions.Normal{Float64}(μ=-0.1137533187866211, σ=0.07161056521032018)\n Distributions.Normal{Float64}(μ=0.7063850164413452, σ=0.050697938829269665)\n Distributions.Normal{Float64}(μ=-0.2211049497127533, σ=0.06876939416479119)\n Distributions.Normal{Float64}(μ=0.720299243927002, σ=0.08665125572287981)\n Distributions.Normal{Float64}(μ=-0.8338974714279175, σ=0.06464012115237727)\n Distributions.Normal{Float64}(μ=0.9910320043563843, σ=0.07452060172164382)\n Distributions.Normal{Float64}(μ=0.1507074236869812, σ=0.07316299850461126)\n Distributions.Normal{Float64}(μ=0.20875799655914307, σ=0.05507748397231652)\n Distributions.Normal{Float64}(μ=0.973572850227356, σ=0.07899004963915071)\n Distributions.Normal{Float64}(μ=0.9497100114822388, σ=0.07750126389821968)\n Distributions.Normal{Float64}(μ=0.22462180256843567, σ=0.07103664786246695)\n Distributions.Normal{Float64}(μ=-0.7654240131378174, σ=0.05501397704409917)\n Distributions.Normal{Float64}(μ=1.0029183626174927, σ=0.07619466916431794)\n ⋮\n Distributions.Normal{Float64}(μ=0.7475956678390503, σ=0.049875919157527815)\n Distributions.Normal{Float64}(μ=0.019430622458457947, σ=0.07445076746045155)\n Distributions.Normal{Float64}(μ=-0.9451781511306763, σ=0.05929712369810892)\n Distributions.Normal{Float64}(μ=-0.9813591241836548, σ=0.05844012710417755)\n Distributions.Normal{Float64}(μ=-0.6470385789871216, σ=0.055754609087554294)\n Distributions.Normal{Float64}(μ=-0.34288135170936584, σ=0.05533523375842789)\n Distributions.Normal{Float64}(μ=0.9912381172180176, σ=0.07872473667398772)\n Distributions.Normal{Float64}(μ=-0.824547290802002, σ=0.05499258101374759)\n Distributions.Normal{Float64}(μ=-0.3306621015071869, σ=0.06745251908756716)\n Distributions.Normal{Float64}(μ=0.3742436170578003, σ=0.10588913330223387)\n Distributions.Normal{Float64}(μ=0.0875578224658966, σ=0.07436153828228255)\n Distributions.Normal{Float64}(μ=-0.34871187806129456, σ=0.06742745343084512)","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"then we can plot the calibration plot of our neural model","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"Calibration_Plot(la,y_test,vec(predicted_distributions);n_bins = 20)","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP 
Regression","text":"(Image: )","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"and compute the sharpness of the predictive distribution","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"sharpness_regression(vec(predicted_distributions))","category":"page"},{"location":"tutorials/regression/","page":"MLP Regression","title":"MLP Regression","text":"0.005058067743863281","category":"page"},{"location":"tutorials/multi/#Multi-class-problem","page":"MLP Multi-Label Classifier","title":"Multi-class problem","text":"","category":"section"},{"location":"tutorials/multi/#Libraries","page":"MLP Multi-Label Classifier","title":"Libraries","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"using Pkg; Pkg.activate(\"docs\")\n# Import libraries\nusing Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux\ntheme(:lime)","category":"page"},{"location":"tutorials/multi/#Data","page":"MLP Multi-Label Classifier","title":"Data","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"using LaplaceRedux.Data\nseed = 1234\nx, y = Data.toy_data_multi(seed=seed)\nX = hcat(x...)\ny_onehot = Flux.onehotbatch(y, unique(y))\ny_onehot = Flux.unstack(y_onehot',1)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"split in training and test datasets","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"# Shuffle the data\nRandom.seed!(seed)\nn = length(y)\nindices = randperm(n)\n\n# Define the split ratio\nsplit_ratio = 0.8\nsplit_index = Int(floor(split_ratio * n))\n\n# Split the data into training and test sets\ntrain_indices = indices[1:split_index]\ntest_indices = indices[split_index+1:end]\n\nx_train = x[train_indices]\nx_test = x[test_indices]\ny_onehot_train = y_onehot[train_indices,:]\ny_onehot_test = y_onehot[test_indices,:]\n\ny_train = vec(y[train_indices,:])\ny_test = vec(y[test_indices,:])\n# bring into tabular format\nX_train = hcat(x_train...) \nX_test = hcat(x_test...) 
\n\ndata = zip(x_train,y_onehot_train)\n#data = zip(x,y_onehot)","category":"page"},{"location":"tutorials/multi/#MLP","page":"MLP Multi-Label Classifier","title":"MLP","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"We set up a model","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"n_hidden = 3\nD = size(X,1)\nout_dim = length(unique(y))\nnn = Chain(\n Dense(D, n_hidden, σ),\n Dense(n_hidden, out_dim)\n) \nloss(x, y) = Flux.Losses.logitcrossentropy(nn(x), y)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"training:","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"using Flux.Optimise: update!, Adam\nopt = Adam()\nepochs = 100\navg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))\nshow_every = epochs/10\n\nfor epoch = 1:epochs\n for d in data\n gs = gradient(Flux.params(nn)) do\n l = loss(d...)\n end\n update!(opt, Flux.params(nn), gs)\n end\n if epoch % show_every == 0\n println(\"Epoch \" * string(epoch))\n @show avg_loss(data)\n end\nend","category":"page"},{"location":"tutorials/multi/#Laplace-Approximation","page":"MLP Multi-Label Classifier","title":"Laplace Approximation","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"The Laplace approximation can be implemented as follows:","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"la = Laplace(nn; likelihood=:classification)\nfit!(la, data)\noptimize_prior!(la; verbose=true, n_steps=100)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"with either the probit approximation:","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X_test, y_test; target=target, clim=(0,1))\n push!(plt_list, plt)\nend\nplot(plt_list...)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"or the plugin approximation:","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"_labels = sort(unique(y))\nplt_list = []\nfor target in _labels\n plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)\n push!(plt_list, plt)\nend\nplot(plt_list...)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/multi/#Calibration-Plots","page":"MLP Multi-Label Classifier","title":"Calibration Plots","text":"","category":"section"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"In the case of multiclass classification tasks, we cannot plot the calibration plots directly since they can only be used in the binary classification 
case. However, we can use them to plot the calibration of the predictions for 1 class against all the others. To do so, we first have to collect the predicted categorical distributions","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"predicted_distributions= predict(la, X_test,ret_distr=true)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"1×20 Matrix{Distributions.Categorical{Float64, Vector{Float64}}}:\n Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569184, 0.196066, 0.0296796, 0.717336]) … Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569634, 0.195727, 0.0296449, 0.717665])","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"then we transform the categorical distributions into Bernoulli distributions by taking only the probability of the class of interest, for example the third one.","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"using Distributions\nbernoulli_distributions = [Bernoulli(p.p[3]) for p in vec(predicted_distributions)]","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"20-element Vector{Bernoulli{Float64}}:\n Bernoulli{Float64}(p=0.029679590887034743)\n Bernoulli{Float64}(p=0.6682373773598078)\n Bernoulli{Float64}(p=0.20912995228011141)\n Bernoulli{Float64}(p=0.20913322913224044)\n Bernoulli{Float64}(p=0.02971989045895732)\n Bernoulli{Float64}(p=0.668431087463204)\n Bernoulli{Float64}(p=0.03311710703617972)\n Bernoulli{Float64}(p=0.20912981531862682)\n Bernoulli{Float64}(p=0.11273726979027407)\n Bernoulli{Float64}(p=0.2490744632745955)\n Bernoulli{Float64}(p=0.029886357844211404)\n Bernoulli{Float64}(p=0.02965323602487074)\n Bernoulli{Float64}(p=0.1126799374664026)\n Bernoulli{Float64}(p=0.11278538625980777)\n Bernoulli{Float64}(p=0.6683139127616431)\n Bernoulli{Float64}(p=0.029644435143197145)\n Bernoulli{Float64}(p=0.11324691083703237)\n Bernoulli{Float64}(p=0.6681422555922787)\n Bernoulli{Float64}(p=0.668424345470233)\n Bernoulli{Float64}(p=0.029644891255330787)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"Now we can use Calibration_Plot to see the level of calibration of the neural network","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"plt = Calibration_Plot(la,hcat(y_onehot_test...)[3,:],bernoulli_distributions;n_bins = 10);","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"(Image: )","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"The plot is peaked around 0.7.","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"A possible reason is that class 3 is relatively easy for the model to identify from the other classes, although it remains a bit underconfident in its predictions. 
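To check this numerically, we can also inspect the bins behind the plot using the exported helper empirical_frequency_binary_classification. The sketch below is a minimal example that reuses the bernoulli_distributions and y_onehot_test variables defined above; the three outputs are assumed to be returned in the order listed in the reference entry (bin counts, observed averages, bin centers).","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"# Number of predictions per bin, observed frequency per bin and bin centers\n# (assumed output order; see the reference entry for empirical_frequency_binary_classification)\nnum_p_per_interval, emp_avg, bin_centers = empirical_frequency_binary_classification(hcat(y_onehot_test...)[3,:], bernoulli_distributions; n_bins=10)","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"Bins with very few predictions (small entries in num_p_per_interval) are exactly the regions where the calibration curve is least informative. 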
Another reason for the peak may be the lack of cases where the predicted probability is lower (e.g., around 0.5), which could indicate that the network has not encountered ambiguous or difficult-to-classify examples for this class. This once again might be because class 3 has distinct features that the model can easily learn, leading to fewer uncertain predictions, or simply a consequence of the limited dataset.","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"We can measure how sharp the neural network is by computing the sharpness score","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"sharpness_classification(hcat(y_onehot_test...)[3,:], vec(bernoulli_distributions))","category":"page"},{"location":"tutorials/multi/","page":"MLP Multi-Label Classifier","title":"MLP Multi-Label Classifier","text":"The neural network seems to be able to correctly classify the majority of samples not belonging to class 3 with relatively high confidence, but remains more uncertain when it encounters examples belonging to class 3.","category":"page"},{"location":"reference/","page":"Reference","title":"Reference","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"reference/#All-functions-and-types","page":"Reference","title":"All functions and types","text":"","category":"section"},{"location":"reference/","page":"Reference","title":"Reference","text":"","category":"page"},{"location":"reference/#Exported-functions","page":"Reference","title":"Exported functions","text":"","category":"section"},{"location":"reference/","page":"Reference","title":"Reference","text":"Modules = [\n LaplaceRedux,\n LaplaceRedux.Curvature,\n LaplaceRedux.Data,\n]\nPrivate = false","category":"page"},{"location":"reference/#LaplaceRedux.Laplace","page":"Reference","title":"LaplaceRedux.Laplace","text":"Laplace\n\nConcrete type for Laplace approximation. This type is a subtype of AbstractLaplace and is used to store all the necessary information for a Laplace approximation.\n\nFields\n\nmodel::Flux.Chain: The model to be approximated.\nlikelihood::Symbol: The likelihood function to be used.\nest_params::EstimationParams: The estimation parameters.\nprior::Prior: The parameters defining the prior distribution.\nposterior::Posterior: The posterior distribution.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Laplace-Tuple{Any}","page":"Reference","title":"LaplaceRedux.Laplace","text":"Laplace(model::Any; likelihood::Symbol, kwargs...)\n\nOuter constructor for Laplace approximation. This function constructs a Laplace object from a given model and likelihood function.\n\nArguments\n\nmodel::Any: The model to be approximated (a Flux.Chain).\nlikelihood::Symbol: The likelihood function to be used. 
Possible values are :regression and :classification.\n\nKeyword Arguments\n\nSee LaplaceParams for a description of the keyword arguments.\n\nReturns\n\nla::Laplace: The Laplace object.\n\nExamples\n\nusing Flux, LaplaceRedux\nnn = Chain(Dense(2,1))\nla = Laplace(nn, likelihood=:regression)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.LaplaceClassification","page":"Reference","title":"LaplaceRedux.LaplaceClassification","text":"MLJBase.@mlj_model mutable struct LaplaceClassification <: MLJFlux.MLJFluxProbabilistic\n\nA mutable struct representing a Laplace Classification model that extends the MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. \n\nThe model is defined by the following default parameters for all MLJFlux models:\n\nbuilder: a Flux model that constructs the neural network.\nfinaliser: a Flux model that processes the output of the neural network.\noptimiser: a Flux optimiser.\nloss: a loss function that takes the predicted output and the true output as arguments.\nepochs: the number of epochs.\nbatch_size: the size of a batch.\nlambda: the regularization strength.\nalpha: the regularization mix (0 for all l2, 1 for all l1).\nrng: a random number generator.\noptimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.\nacceleration: the computational resource to use.\n\nThe model also has the following parameters, which are specific to the Laplace approximation:\n\nsubset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.\nsubnetwork_indices: the indices of the subnetworks.\nhessian_structure: the structure of the Hessian matrix, either :full or :diagonal.\nbackend: the backend to use, either :GGN or :EmpiricalFisher.\nσ: the standard deviation of the prior distribution.\nμ₀: the mean of the prior distribution.\nP₀: the covariance matrix of the prior distribution.\nlink_approx: the link approximation to use, either :probit or :plugin.\npredict_proba: a boolean that select whether to predict probabilities or not.\nret_distr: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.\nfit_prior_nsteps: the number of steps used to fit the priors.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.LaplaceRegression","page":"Reference","title":"LaplaceRedux.LaplaceRegression","text":"MLJBase.@mlj_model mutable struct LaplaceRegression <: MLJFlux.MLJFluxProbabilistic\n\nA mutable struct representing a Laplace regression model that extends the MLJFlux.MLJFluxProbabilistic abstract type. It uses Laplace approximation to estimate the posterior distribution of the weights of a neural network. 
\n\nThe model is defined by the following default parameters for all MLJFlux models:\n\nbuilder: a Flux model that constructs the neural network.\noptimiser: a Flux optimiser.\nloss: a loss function that takes the predicted output and the true output as arguments.\nepochs: the number of epochs.\nbatch_size: the size of a batch.\nlambda: the regularization strength.\nalpha: the regularization mix (0 for all l2, 1 for all l1).\nrng: a random number generator.\noptimiser_changes_trigger_retraining: a boolean indicating whether changes in the optimiser trigger retraining.\nacceleration: the computational resource to use.\n\nThe model also has the following parameters, which are specific to the Laplace approximation:\n\nsubset_of_weights: the subset of weights to use, either :all, :last_layer, or :subnetwork.\nsubnetwork_indices: the indices of the subnetworks.\nhessian_structure: the structure of the Hessian matrix, either :full or :diagonal.\nbackend: the backend to use, either :GGN or :EmpiricalFisher.\nσ: the standard deviation of the prior distribution.\nμ₀: the mean of the prior distribution.\nP₀: the covariance matrix of the prior distribution.\nret_distr: a boolean that tells predict to either return distributions (true) objects from Distributions.jl or just the probabilities.\nfit_prior_nsteps: the number of steps used to fit the priors.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.empirical_frequency_binary_classification-Tuple{Any, Vector{Distributions.Bernoulli{Float64}}}","page":"Reference","title":"LaplaceRedux.empirical_frequency_binary_classification","text":"empirical_frequency_binary_classification(y_binary, distributions::Vector{Bernoulli{Float64}}; n_bins::Int=20)\n\nFOR BINARY CLASSIFICATION MODELS.\nGiven a calibration dataset (x_t y_t) for i 1T let p_t= H(x_t)01 be the forecasted probability. \nWe group the p_t into intervals I_j for j= 12m that form a partition of [0,1]. The function computes the observed average p_j= T^-1_j _tp_t I_j y_j in each interval I_j. \nSource: Kuleshov, Fenner, Ermon 2018\n\nInputs: \n - y_binary: the array of outputs y_t numerically coded: 1 for the target class, 0 for the null class. \n - distributions: an array of Bernoulli distributions \n - n_bins: number of equally spaced bins to use.\n\nOutputs: \n - num_p_per_interval: array with the number of probabilities falling within interval. \n - emp_avg: array with the observed empirical average per interval. \n - bin_centers: array with the centers of the bins. \n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.empirical_frequency_regression-Tuple{Any, Vector{Distributions.Normal{Float64}}}","page":"Reference","title":"LaplaceRedux.empirical_frequency_regression","text":"empirical_frequency_regression(Y_cal, distributions::Distributions.Normal, n_bins=20)\n\nDispatched version for Normal distributions FOR REGRESSION MODELS. \nGiven a calibration dataset (x_t y_t) for i 1T and an array of predicted distributions, the function calculates the empirical frequency\n\np^hat_j = y_tF_t(y_t)= p_j t= 1TT\n\nwhere T is the number of calibration points, p_j is the confidence level and F_t is the cumulative distribution function of the predicted distribution targeting y_t. \nSource: Kuleshov, Fenner, Ermon 2018\n\nInputs: \n - Y_cal: a vector of values y_t\n - distributions:a Vector{Distributions.Normal{Float64}} of distributions stacked row-wise.\n For example the output of LaplaceRedux.predict(la,Xcal). 
\n - `n_bins`: number of equally spaced bins to use.\nOutputs:\n - counts: an array containing the empirical frequencies for each quantile interval.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.extract_mean_and_variance-Union{Tuple{Array{Distributions.Normal{T}, 1}}, Tuple{T}} where T<:AbstractFloat","page":"Reference","title":"LaplaceRedux.extract_mean_and_variance","text":"extract_mean_and_variance(distr::Vector{Normal{<: AbstractFloat}})\n\nExtract the mean and the variance of each distribution and return them in two separate lists.\n\nInputs: - distributions: a Vector of Normal distributions \n\nOutputs: - means: the list of the means - variances: the list of the variances \n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.fit!-Tuple{LaplaceRedux.AbstractLaplace, Any}","page":"Reference","title":"LaplaceRedux.fit!","text":"fit!(la::AbstractLaplace,data)\n\nFits the Laplace approximation for a data set. The function returns the number of observations (n_data) that were used to update the Laplace object. It does not return the updated Laplace object itself because the function modifies the input Laplace object in place (as denoted by the use of '!' in the function's name).\n\nExamples\n\nusing Flux, LaplaceRedux\nx, y = LaplaceRedux.Data.toy_data_linear()\ndata = zip(x,y)\nnn = Chain(Dense(2,1))\nla = Laplace(nn; likelihood=:classification)\nfit!(la, data)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.fit!-Tuple{LaplaceRedux.AbstractLaplace, MLUtils.DataLoader}","page":"Reference","title":"LaplaceRedux.fit!","text":"Fit the Laplace approximation, with batched data.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.glm_predictive_distribution-Tuple{LaplaceRedux.AbstractLaplace, AbstractArray}","page":"Reference","title":"LaplaceRedux.glm_predictive_distribution","text":"glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)\n\nComputes the linearized GLM predictive.\n\nArguments\n\nla::AbstractLaplace: A Laplace object.\nX::AbstractArray: Input data.\n\nReturns\n\nnormal_distr: A normal distribution N(fμ,fvar) approximating the predictive distribution p(y|X) given the input data X.\nfμ::AbstractArray: Mean of the predictive distribution. The output shape is column-major as in Flux.\nfvar::AbstractArray: Variance of the predictive distribution. The output shape is column-major as in Flux.\n\nExamples\n\nusing Flux, LaplaceRedux\nusing LaplaceRedux.Data: toy_data_linear\nx, y = toy_data_linear()\ndata = zip(x,y)\nnn = Chain(Dense(2,1))\nla = Laplace(nn; likelihood=:classification)\nfit!(la, data)\nglm_predictive_distribution(la, hcat(x...))\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.optimize_prior!-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.optimize_prior!","text":"optimize_prior!(\n la::AbstractLaplace; \n n_steps::Int=100, lr::Real=1e-1,\n λinit::Union{Nothing,Real}=nothing,\n σinit::Union{Nothing,Real}=nothing\n)\n\nOptimize the prior precision post-hoc through Empirical Bayes (marginal log-likelihood maximization).\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.posterior_covariance","page":"Reference","title":"LaplaceRedux.posterior_covariance","text":"posterior_covariance(la::AbstractLaplace, P=la.P)\n\nComputes the posterior covariance as the inverse of the posterior precision: $\Sigma = P^{-1}$.\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.posterior_precision","page":"Reference","title":"LaplaceRedux.posterior_precision","text":"posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀)\n\nComputes the posterior precision $P$ for a fitted Laplace Approximation as follows,\n\n$P = \sum_{n=1}^{N} \nabla_{\theta}^{2} \log p(\mathcal{D}_n \mid \theta) \big|_{\hat{\theta}} + \nabla_{\theta}^{2} \log p(\theta) \big|_{\hat{\theta}}$\n\nwhere $\sum_{n=1}^{N} \nabla_{\theta}^{2} \log p(\mathcal{D}_n \mid \theta) \big|_{\hat{\theta}} = H$ is the Hessian, $\nabla_{\theta}^{2} \log p(\theta) \big|_{\hat{\theta}} = P_0$ is the prior precision, and $\hat{\theta}$ is the MAP estimate.\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.predict-Tuple{LaplaceRedux.AbstractLaplace, AbstractArray}","page":"Reference","title":"LaplaceRedux.predict","text":"predict(la::AbstractLaplace, X::AbstractArray; link_approx=:probit, predict_proba::Bool=true)\n\nComputes predictions from the Bayesian neural network.\n\nArguments\n\nla::AbstractLaplace: A Laplace object.\nX::AbstractArray: Input data.\nlink_approx::Symbol=:probit: Link function approximation. Options are :probit and :plugin.\npredict_proba::Bool=true: If true (default), apply a sigmoid or a softmax function to the output of the Flux model.\nret_distr::Bool=false: If false (default), the function outputs either the direct output of the chain or pseudo-probabilities (if predict_proba=true). If true, predict returns a Bernoulli distribution in binary classification tasks and a Categorical distribution in multi-class classification tasks.\n\nReturns\n\nFor classification tasks, LaplaceRedux provides different options:\n\nif ret_distr is false:\n\nfμ::AbstractArray: Mean of the predictive distribution if the link function is set to :plugin, otherwise the probit approximation. The output shape is column-major as in Flux.\n\nif ret_distr is true:\n\na Bernoulli distribution in binary classification tasks and a Categorical distribution in multi-class classification tasks.\n\nFor regression tasks:\n\nnormal_distr::Distributions.Normal: the array of Normal distributions computed by glm_predictive_distribution. 
\n\nExamples\n\nusing Flux, LaplaceRedux\nusing LaplaceRedux.Data: toy_data_linear\nx, y = toy_data_linear()\ndata = zip(x,y)\nnn = Chain(Dense(2,1))\nla = Laplace(nn; likelihood=:classification)\nfit!(la, data)\npredict(la, hcat(x...))\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.rescale_stddev-Union{Tuple{T}, Tuple{Array{Distributions.Normal{T}, 1}, T}} where T<:AbstractFloat","page":"Reference","title":"LaplaceRedux.rescale_stddev","text":"rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat}\n\nRescale the standard deviation of the Normal distributions received as argument and return a vector of rescaled Normal distributions. Inputs: \n - distr: a Vector of Normal distributions - s: a scale factor of type T.\n\nOutputs: \n - Vector{Normal{T}}: a Vector of rescaled Normal distributions.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.sharpness_classification-Tuple{Any, Vector{Distributions.Bernoulli{Float64}}}","page":"Reference","title":"LaplaceRedux.sharpness_classification","text":"sharpness_classification(y_binary,distributions::Distributions.Bernoulli)\n\ndispatched for Bernoulli Distributions FOR BINARY CLASSIFICATION MODELS. \nAssess the sharpness of the model by looking at the distribution of model predictions. When forecasts are sharp, most predictions are close to either 0 or 1 \nSource: Kuleshov, Fenner, Ermon 2018\n\nInputs: \n - y_binary: the array of outputs y_t numerically coded: 1 for the target class, 0 for the negative result. \n - distributions: an array of Bernoulli distributions describing the probability of of the output belonging to the target class \n Outputs: \n - mean_class_one: a scalar that measure the average prediction for the target class \n - mean_class_zero: a scalar that measure the average prediction for the null class \n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.sharpness_regression-Tuple{Vector{Distributions.Normal{Float64}}}","page":"Reference","title":"LaplaceRedux.sharpness_regression","text":"sharpness_regression(distributions::Distributions.Normal)\n\nDispatched version for Normal distributions FOR REGRESSION MODELS. \nGiven a calibration dataset (x_t y_t) for i 1T and an array of predicted distributions, the function calculates the sharpness of the predicted distributions, i.e., the average of the variances sigma^2(F_t) predicted by the forecaster for each x_t. \nsource: Kuleshov, Fenner, Ermon 2018\n\nInputs: \n - distributions: an array of normal distributions F(x_t) stacked row-wise. \nOutputs: \n - sharpness: a scalar that measure the level of sharpness of the regressor\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.sigma_scaling-Union{Tuple{T}, Tuple{Array{Distributions.Normal{T}, 1}, Vector{<:AbstractFloat}}} where T<:AbstractFloat","page":"Reference","title":"LaplaceRedux.sigma_scaling","text":"sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}) where T <: AbstractFloat\n\nCompute the value of Σ that maximize the conditional log-likelihood:\n\n m ln(Σ) +12 * Σ^-2 _i=1^i=m y_cal_i - y_mean_i ^2 σ^2_i \n\nwhere m is the number of elements in the calibration set (xcal,ycal). 
\nSource: Laves,Ihler,Fast, Kahrs, Ortmaier,2020 Inputs: \n - distr: a Vector of Normal distributions \n - y_cal: a Vector of true results.\n\nOutputs: \n - sigma: the scalar that maximize the likelihood.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJModelInterface.predict-Tuple{LaplaceClassification, Any, Any}","page":"Reference","title":"MLJModelInterface.predict","text":"predict(model::LaplaceClassification, Xnew)\n\nPredicts the class labels for new data using the LaplaceClassification model.\n\nArguments\n\nmodel::LaplaceClassification: The trained LaplaceClassification model.\nfitresult: the fitresult output produced by MLJFlux.fit!\nXnew: The new data to make predictions on.\n\nReturns\n\nAn array of predicted class labels.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJModelInterface.predict-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJModelInterface.predict","text":"predict(model::LaplaceRegression, Xnew)\n\nPredict the output for new input data using a Laplace regression model.\n\nArguments\n\nmodel::LaplaceRegression: The trained Laplace regression model.\nthe fitresult output produced by MLJFlux.fit!\nXnew: The new input data.\n\nReturns\n\nThe predicted output for the new input data.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.CurvatureInterface","page":"Reference","title":"LaplaceRedux.Curvature.CurvatureInterface","text":"Base type for any curvature interface.\n\n\n\n\n\n","category":"type"},{"location":"reference/#Internal-functions","page":"Reference","title":"Internal functions","text":"","category":"section"},{"location":"reference/","page":"Reference","title":"Reference","text":"Modules = [\n LaplaceRedux,\n LaplaceRedux.Curvature,\n LaplaceRedux.Data,\n]\nPublic = false","category":"page"},{"location":"reference/#LaplaceRedux.AbstractDecomposition","page":"Reference","title":"LaplaceRedux.AbstractDecomposition","text":"Abstract type of Hessian decompositions.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.AbstractLaplace","page":"Reference","title":"LaplaceRedux.AbstractLaplace","text":"Abstract base type for all Laplace approximations in this library. All subclasses implemented are parametric.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.AbstractLaplace-Tuple{AbstractArray}","page":"Reference","title":"LaplaceRedux.AbstractLaplace","text":"(la::AbstractLaplace)(X::AbstractArray)\n\nCalling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the predict function.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.EstimationParams","page":"Reference","title":"LaplaceRedux.EstimationParams","text":"EstimationParams\n\nContainer for the parameters of a Laplace approximation. \n\nFields\n\nsubset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.\nsubnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.\nhessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.\ncurvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. 
Possible values are nothing or a concrete subtype of CurvatureInterface.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.EstimationParams-Tuple{LaplaceRedux.LaplaceParams, Any, Symbol}","page":"Reference","title":"LaplaceRedux.EstimationParams","text":"EstimationParams(params::LaplaceParams)\n\nExtracts the estimation parameters from a LaplaceParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.FullHessian","page":"Reference","title":"LaplaceRedux.FullHessian","text":"Concrete type for full Hessian structure. This is the default structure.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.HessianStructure","page":"Reference","title":"LaplaceRedux.HessianStructure","text":"Abstract type for Hessian structure.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Kron","page":"Reference","title":"LaplaceRedux.Kron","text":"Kronecker-factored approximate curvature representation for a neural network model. Each element in kfacs represents two Kronecker factors (𝐆, 𝐀), such that the full block Hessian approximation would be approximated as 𝐀⊗𝐆.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.KronDecomposed","page":"Reference","title":"LaplaceRedux.KronDecomposed","text":"KronDecomposed\n\nDecomposed Kronecker-factored approximate curvature representation for a neural network model.\n\nDecomposition is required to add the prior (diagonal matrix) to the posterior (KronDecomposed). It also has the benefits of reducing the costs for computation of inverses and log-determinants.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.KronHessian","page":"Reference","title":"LaplaceRedux.KronHessian","text":"Concrete type for Kronecker-factored Hessian structure.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.LaplaceParams","page":"Reference","title":"LaplaceRedux.LaplaceParams","text":"LaplaceParams\n\nContainer for the parameters of a Laplace approximation.\n\nFields\n\nsubset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.\nsubnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.\nhessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.\nbackend::Symbol: the backend to use. Possible values are :GGN and :Fisher.\ncurvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. 
Possible values are nothing or a concrete subtype of CurvatureInterface.\nσ::Real: the observation noise\nμ₀::Real: the prior mean\nλ::Real: the prior precision\nP₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Posterior","page":"Reference","title":"LaplaceRedux.Posterior","text":"Posterior\n\nContainer for the results of a Laplace approximation.\n\nFields\n\nμ::AbstractVector: the MAP estimate of the parameters\nH::Union{AbstractArray,AbstractDecomposition,Nothing}: the Hessian matrix\nP::Union{AbstractArray,AbstractDecomposition,Nothing}: the posterior precision matrix\nΣ::Union{AbstractArray,Nothing}: the posterior covariance matrix\nn_data::Union{Int,Nothing}: the number of data points\nn_params::Union{Int,Nothing}: the number of parameters\nn_out::Union{Int,Nothing}: the number of outputs\nloss::Real: the loss value\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Posterior-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"LaplaceRedux.Posterior","text":"Posterior(model::Any, est_params::EstimationParams)\n\nOuter constructor for Posterior object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Prior","page":"Reference","title":"LaplaceRedux.Prior","text":"Prior\n\nContainer for the prior parameters of a Laplace approximation.\n\nFields\n\nσ::Real: the observation noise\nμ₀::Real: the prior mean\nλ::Real: the prior precision\nP₀::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Prior-Tuple{LaplaceRedux.LaplaceParams, Any, Symbol}","page":"Reference","title":"LaplaceRedux.Prior","text":"Prior(params::LaplaceParams)\n\nExtracts the prior parameters from a LaplaceParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:*-Tuple{LaplaceRedux.KronDecomposed, Number}","page":"Reference","title":"Base.:*","text":"Multiply by a scalar by changing the eigenvalues. 
Distribute the scalar along the factors of a block.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:*-Tuple{Real, LaplaceRedux.Kron}","page":"Reference","title":"Base.:*","text":"Kronecker-factored curvature scalar scaling.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:+-Tuple{LaplaceRedux.Kron, LaplaceRedux.Kron}","page":"Reference","title":"Base.:+","text":"Kronecker-factored curvature sum.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:+-Tuple{LaplaceRedux.KronDecomposed, LinearAlgebra.Diagonal}","page":"Reference","title":"Base.:+","text":"Shift the factors by a diagonal (assumed uniform scaling)\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:+-Tuple{LaplaceRedux.KronDecomposed, Number}","page":"Reference","title":"Base.:+","text":"Shift the factors by a scalar across the diagonal.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.:==-Tuple{LaplaceRedux.Kron, LaplaceRedux.Kron}","page":"Reference","title":"Base.:==","text":"Kronecker-factored curvature equality.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.getindex-Tuple{LaplaceRedux.Kron, Int64}","page":"Reference","title":"Base.getindex","text":"Get Kronecker-factored block represenation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.getindex-Tuple{LaplaceRedux.KronDecomposed, Int64}","page":"Reference","title":"Base.getindex","text":"Get i-th block of a a Kronecker-factored curvature.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Base.length-Tuple{LaplaceRedux.KronDecomposed}","page":"Reference","title":"Base.length","text":"Number of blocks in a Kronecker-factored curvature.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Flux.params-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"Flux.params","text":"Flux.params(model::Any, params::EstimationParams)\n\nExtracts the parameters of a model based on the subset of weights specified in the EstimationParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#Flux.params-Tuple{Laplace}","page":"Reference","title":"Flux.params","text":"Flux.params(la::Laplace)\n\nOverloads the params function for a Laplace object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._H_factor-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux._H_factor","text":"_H_factor(la::AbstractLaplace)\n\nReturns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._fit!-Tuple{Laplace, LaplaceRedux.FullHessian, Any}","page":"Reference","title":"LaplaceRedux._fit!","text":"_fit!(la::Laplace, hessian_structure::FullHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)\n\nFit a Laplace approximation to the posterior distribution of a model using the full Hessian.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._fit!-Tuple{Laplace, LaplaceRedux.KronHessian, Any}","page":"Reference","title":"LaplaceRedux._fit!","text":"_fit!(la::Laplace, hessian_structure::KronHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)\n\nFit a Laplace approximation to the posterior distribution of a model using the Kronecker-factored 
Hessian.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._init_H-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux._init_H","text":"_init_H(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux._weight_penalty-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux._weight_penalty","text":"_weight_penalty(la::AbstractLaplace)\n\nThe weight penalty term is a regularization term used to prevent overfitting. Weight regularization methods such as weight decay introduce a penalty to the loss function when training a neural network to encourage the network to use small weights. Smaller weights in a neural network can result in a model that is more stable and less likely to overfit the training dataset, in turn having better performance when making a prediction on new data.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.approximate-Tuple{LaplaceRedux.Curvature.CurvatureInterface, LaplaceRedux.FullHessian, Tuple}","page":"Reference","title":"LaplaceRedux.approximate","text":"approximate(curvature::CurvatureInterface, hessian_structure::FullHessian, d::Tuple; batched::Bool=false)\n\nCompute the full approximation, for either a single input-output datapoint or a batch of such. \n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.approximate-Tuple{LaplaceRedux.Curvature.CurvatureInterface, LaplaceRedux.KronHessian, Any}","page":"Reference","title":"LaplaceRedux.approximate","text":"approximate(curvature::CurvatureInterface, hessian_structure::KronHessian, data; batched::Bool=false)\n\nCompute the eigendecomposed Kronecker-factored approximate curvature as the Fisher information matrix.\n\nNote, since the network predictive distribution is used in a weighted sum, and the number of backward passes is linear in the number of target classes, e.g. 100 for CIFAR-100.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.clamp-Tuple{LinearAlgebra.Eigen}","page":"Reference","title":"LaplaceRedux.clamp","text":"Clamp eigenvalues in an eigendecomposition to be non-negative.\n\nSince the Fisher information matrix is a positive-semidefinite by construction, the (near-zero) negative eigenvalues should be neglected.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.convert_subnetwork_indices-Tuple{Vector{Vector{Int64}}, AbstractArray}","page":"Reference","title":"LaplaceRedux.convert_subnetwork_indices","text":"convertsubnetworkindices(subnetwork_indices::AbstractArray)\n\nConverts the subnetwork indices from the user given format [theta, row, column] to an Int i that corresponds to the index of that weight in the flattened array of weights.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.decompose-Tuple{LaplaceRedux.Kron}","page":"Reference","title":"LaplaceRedux.decompose","text":"decompose(K::Kron)\n\nEigendecompose Kronecker factors and turn into KronDecomposed.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.functional_variance-Tuple{Any, Any}","page":"Reference","title":"LaplaceRedux.functional_variance","text":"functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)\n\nCompute the functional variance for the GLM predictive. 
Dispatches to the appropriate method based on the Hessian structure.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.functional_variance-Tuple{Laplace, LaplaceRedux.FullHessian, Any}","page":"Reference","title":"LaplaceRedux.functional_variance","text":"functional_variance(la::Laplace,𝐉)\n\nCompute the linearized GLM predictive variance as 𝐉ₙΣ𝐉ₙ' where 𝐉=∇f(x;θ)|θ̂ is the Jacobian evaluated at the MAP estimate and Σ = P⁻¹.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.functional_variance-Tuple{Laplace, LaplaceRedux.KronHessian, Matrix}","page":"Reference","title":"LaplaceRedux.functional_variance","text":"functionalvariance(la::Laplace, hessianstructure::KronHessian, 𝐉::Matrix)\n\nCompute functional variance for the GLM predictive: as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ, where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of model output 𝐉=∇f(x;θ)|θ̂.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_loss_fun-Tuple{Symbol, Flux.Chain}","page":"Reference","title":"LaplaceRedux.get_loss_fun","text":"get_loss_fun(likelihood::Symbol)\n\nHelper function to choose loss function based on specified model likelihood.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_loss_type-Tuple{Symbol, Flux.Chain}","page":"Reference","title":"LaplaceRedux.get_loss_type","text":"get_loss_type(likelihood::Symbol)\n\nChoose loss function type based on specified model likelihood.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_map_estimate-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"LaplaceRedux.get_map_estimate","text":"get_map_estimate(model::Any, est_params::EstimationParams)\n\nHelper function to extract the MAP estimate of the parameters for the model based on the subset of weights specified in the EstimationParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.get_prior_mean-Tuple{Laplace}","page":"Reference","title":"LaplaceRedux.get_prior_mean","text":"get_prior_mean(la::Laplace)\n\nHelper function to extract the prior mean of the parameters from a Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.has_softmax_or_sigmoid_final_layer-Tuple{Flux.Chain}","page":"Reference","title":"LaplaceRedux.has_softmax_or_sigmoid_final_layer","text":"has_softmax_or_sigmoid_final_layer(model::Flux.Chain)\n\nCheck if the FLux model ends with a sigmoid or with a softmax layer\n\nInput: - model: the Flux Chain object that represent the neural network. Return: - has_finaliser: true if the check is positive, false otherwise.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.hessian_approximation-Tuple{LaplaceRedux.AbstractLaplace, Any}","page":"Reference","title":"LaplaceRedux.hessian_approximation","text":"hessian_approximation(la::AbstractLaplace, d; batched::Bool=false)\n\nComputes the local Hessian approximation at a single datapoint d.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.instantiate_curvature!-Tuple{LaplaceRedux.EstimationParams, Any, Symbol, Symbol}","page":"Reference","title":"LaplaceRedux.instantiate_curvature!","text":"instantiate_curvature!(params::EstimationParams, model::Any, likelihood::Symbol, backend::Symbol)\n\nInstantiates the curvature interface for a Laplace approximation. 
The curvature interface is a concrete subtype of CurvatureInterface and is used to compute the Hessian matrix. The curvature interface is stored in the curvature field of the EstimationParams object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.interleave-Tuple","page":"Reference","title":"LaplaceRedux.interleave","text":"Interleave elements of multiple iterables in order provided.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.inv_square_form-Tuple{LaplaceRedux.KronDecomposed, Matrix}","page":"Reference","title":"LaplaceRedux.inv_square_form","text":"function invsquareform(K::KronDecomposed, W::Matrix)\n\nSpecial function to compute the inverse square form 𝐉𝐏⁻¹𝐉ᵀ (or 𝐖𝐊⁻¹𝐖ᵀ)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_det_posterior_precision-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_det_posterior_precision","text":"log_det_posterior_precision(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_det_prior_precision-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_det_prior_precision","text":"log_det_prior_precision(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_det_ratio-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_det_ratio","text":"log_det_ratio(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_likelihood-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_likelihood","text":"log_likelihood(la::AbstractLaplace)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.log_marginal_likelihood-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.log_marginal_likelihood","text":"log_marginal_likelihood(la::AbstractLaplace; P₀::Union{Nothing,UniformScaling}=nothing, σ::Union{Nothing, Real}=nothing)\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.logdetblock-Tuple{Tuple{LinearAlgebra.Eigen, LinearAlgebra.Eigen}, Number}","page":"Reference","title":"LaplaceRedux.logdetblock","text":"logdetblock(block::Tuple{Eigen,Eigen}, delta::Number)\n\nLog-determinant of a block in KronDecomposed, shifted by delta by on the diagonal.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.mm-Tuple{LaplaceRedux.KronDecomposed, Any}","page":"Reference","title":"LaplaceRedux.mm","text":"Matrix-multuply for the KronDecomposed Hessian approximation K and a 2-d matrix W, applying an exponent to K and transposing W before multiplication. 
Return (K^x)W^T, where x is the exponent.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.n_params-Tuple{Any, LaplaceRedux.EstimationParams}","page":"Reference","title":"LaplaceRedux.n_params","text":"n_params(model::Any, params::EstimationParams)\n\nHelper function to determine the number of parameters of a Flux.Chain with Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.n_params-Tuple{Laplace}","page":"Reference","title":"LaplaceRedux.n_params","text":"LaplaceRedux.n_params(la::Laplace)\n\nOverloads the n_params function for a Laplace object.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.outdim-Tuple{Flux.Chain}","page":"Reference","title":"LaplaceRedux.outdim","text":"outdim(model::Chain)\n\nHelper function to determine the output dimension of a Flux.Chain, corresponding to the number of neurons on the last layer of the NN.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.outdim-Tuple{LaplaceRedux.AbstractLaplace}","page":"Reference","title":"LaplaceRedux.outdim","text":"outdim(la::AbstractLaplace)\n\nHelper function to determine the output dimension, corresponding to the number of neurons on the last layer of the NN, of a Flux.Chain with Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.prior_precision-Tuple{Laplace}","page":"Reference","title":"LaplaceRedux.prior_precision","text":"prior_precision(la::Laplace)\n\nHelper function to extract the prior precision matrix from a Laplace approximation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.probit-Tuple{AbstractArray, AbstractArray}","page":"Reference","title":"LaplaceRedux.probit","text":"probit(fμ::AbstractArray, fvar::AbstractArray)\n\nCompute the probit approximation of the predictive distribution.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.validate_subnetwork_indices-Tuple{Union{Nothing, Vector{Vector{Int64}}}, Any}","page":"Reference","title":"LaplaceRedux.validate_subnetwork_indices","text":"validatesubnetworkindices( subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}, params )\n\nDetermines whether subnetwork_indices is a valid input for specified parameters.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LinearAlgebra.det-Tuple{LaplaceRedux.KronDecomposed}","page":"Reference","title":"LinearAlgebra.det","text":"det(K::KronDecomposed)\n\nLog-determinant of the KronDecomposed block-diagonal matrix, as the exponentiated log-determinant.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LinearAlgebra.logdet-Tuple{LaplaceRedux.KronDecomposed}","page":"Reference","title":"LinearAlgebra.logdet","text":"logdet(K::KronDecomposed)\n\nLog-determinant of the KronDecomposed block-diagonal matrix, as the product of the determinants of the blocks\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.build-Tuple{LaplaceClassification, Any, Any}","page":"Reference","title":"MLJFlux.build","text":"MLJFlux.build(model::LaplaceClassification, rng, shape)\n\nBuilds an MLJFlux model for Laplace classification compatible with the dimensions of the input and output layers specified by shape.\n\nArguments\n\nmodel::LaplaceClassification: The Laplace classification model.\nrng: A random number generator to ensure reproducibility.\nshape: A tuple or array specifying the dimensions of the input and output layers.\n\nReturns\n\nThe constructed MLJFlux model, compatible with the specified input and output 
dimensions.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.build-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJFlux.build","text":"MLJFlux.build(model::LaplaceRegression, rng, shape)\n\nBuilds an MLJFlux model for Laplace regression compatible with the dimensions of the input and output layers specified by shape.\n\nArguments\n\nmodel::LaplaceRegression: The Laplace regression model.\nrng: A random number generator to ensure reproducibility.\nshape: A tuple or array specifying the dimensions of the input and output layers.\n\nReturns\n\nThe constructed MLJFlux model, compatible with the specified input and output dimensions.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.fitresult-Tuple{LaplaceClassification, Any, Any}","page":"Reference","title":"MLJFlux.fitresult","text":"MLJFlux.fitresult(model::LaplaceClassification, chain, y)\n\nComputes the fit result for a Laplace classification model, returning the model chain and the number of unique classes in the target data.\n\nArguments\n\nmodel::LaplaceClassification: The Laplace classification model to be evaluated.\nchain: The trained model chain.\ny: The target data, typically a vector of class labels.\n\nReturns\n\nReturns\n\nA tuple containing:\n\nThe trained Flux chain.\na deepcopy of the laplace model.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.fitresult-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJFlux.fitresult","text":"MLJFlux.fitresult(model::LaplaceRegression, chain, y)\n\nComputes the fit result for a Laplace Regression model, returning the model chain and the number of output variables in the target data.\n\nArguments\n\nmodel::LaplaceRegression: The Laplace Regression model to be evaluated.\nchain: The trained model chain.\ny: The target data, typically a vector of class labels.\n\nReturns\n\nA tuple containing:\n\nThe trained Flux chain.\na deepcopy of the laplace model.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.shape-Tuple{LaplaceRegression, Any, Any}","page":"Reference","title":"MLJFlux.shape","text":"MLJFlux.shape(model::LaplaceRegression, X, y)\n\nCompute the the number of features of the X input dataset and the number of variables to predict from the y output dataset.\n\nArguments\n\nmodel::LaplaceRegression: The LaplaceRegression model to fit.\nX: The input data for training.\ny: The target labels for training one-hot encoded.\n\nReturns\n\n(input size, output size)\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.train-Tuple{LaplaceClassification, Vararg{Any, 7}}","page":"Reference","title":"MLJFlux.train","text":"MLJFlux.train(model::LaplaceClassification, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)\n\nFit the LaplaceRegression model using Flux.jl.\n\nArguments\n\nmodel::LaplaceClassification: The LaplaceClassification model.\nregularized_optimiser: the regularized optimiser to apply to the loss function.\noptimiser_state: thestate of the optimiser.\nepochs: The number of epochs for training.\nverbosity: The verbosity level for training.\nX: The input data for training.\ny: The target labels for training.\n\nReturns (fitresult, cache, report )\n\nwhere\n\nla: the fitted Laplace model.\noptimiser_state: the state of the optimiser.\nhistory: the training loss history.\n\n\n\n\n\n","category":"method"},{"location":"reference/#MLJFlux.train-Tuple{LaplaceRegression, Vararg{Any, 
7}}","page":"Reference","title":"MLJFlux.train","text":"MLJFlux.train(model::LaplaceRegression, chain, regularized_optimiser, optimiser_state, epochs, verbosity, X, y)\n\nFit the LaplaceRegression model using Flux.jl.\n\nArguments\n\nmodel::LaplaceRegression: The LaplaceRegression model.\nregularized_optimiser: the regularized optimiser to apply to the loss function.\noptimiser_state: thestate of the optimiser.\nepochs: The number of epochs for training.\nverbosity: The verbosity level for training.\nX: The input data for training.\ny: The target labels for training.\n\nReturns (la, optimiser_state, history )\n\nwhere\n\nla: the fitted Laplace model.\noptimiser_state: the state of the optimiser.\nhistory: the training loss history.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.@zb-Tuple{Any}","page":"Reference","title":"LaplaceRedux.@zb","text":"Macro for zero-based indexing. Example of usage: (@zb A[0]) = ...\n\n\n\n\n\n","category":"macro"},{"location":"reference/#LaplaceRedux.Curvature.EmpiricalFisher","page":"Reference","title":"LaplaceRedux.Curvature.EmpiricalFisher","text":"Constructor for curvature approximated by empirical Fisher.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Curvature.GGN","page":"Reference","title":"LaplaceRedux.Curvature.GGN","text":"Constructor for curvature approximated by Generalized Gauss-Newton.\n\n\n\n\n\n","category":"type"},{"location":"reference/#LaplaceRedux.Curvature.full_batched-Tuple{LaplaceRedux.Curvature.EmpiricalFisher, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_batched","text":"full_batched(curvature::EmpiricalFisher, d::Tuple)\n\nCompute the full empirical Fisher for batch of inputs-outputs, with the batch dimension at the end.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.full_batched-Tuple{LaplaceRedux.Curvature.GGN, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_batched","text":"full_batched(curvature::GGN, d::Tuple)\n\nCompute the full GGN for batch of inputs-outputs, with the batch dimension at the end.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.full_unbatched-Tuple{LaplaceRedux.Curvature.EmpiricalFisher, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_unbatched","text":"full_unbatched(curvature::EmpiricalFisher, d::Tuple)\n\nCompute the full empirical Fisher for a single datapoint.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.full_unbatched-Tuple{LaplaceRedux.Curvature.GGN, Tuple}","page":"Reference","title":"LaplaceRedux.Curvature.full_unbatched","text":"full_unbatched(curvature::GGN, d::Tuple)\n\nCompute the full GGN for a singular input-ouput datapoint. 
\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.gradients-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray, Union{Number, AbstractArray}}","page":"Reference","title":"LaplaceRedux.Curvature.gradients","text":"gradients(curvature::CurvatureInterface, X::AbstractArray, y::Number)\n\nCompute the gradients with respect to the loss function: ∇ℓ(f(x;θ),y) where f: ℝᴰ ↦ ℝᴷ.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.jacobians-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray}","page":"Reference","title":"LaplaceRedux.Curvature.jacobians","text":"jacobians(curvature::CurvatureInterface, X::AbstractArray; batched::Bool=false)\n\nComputes the Jacobian ∇f(x;θ) where f: ℝᴰ ↦ ℝᴷ.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.jacobians_batched-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray}","page":"Reference","title":"LaplaceRedux.Curvature.jacobians_batched","text":"jacobians_batched(curvature::CurvatureInterface, X::AbstractArray)\n\nCompute Jacobians of the model output w.r.t. model parameters for points in X, with batching.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Curvature.jacobians_unbatched-Tuple{LaplaceRedux.Curvature.CurvatureInterface, AbstractArray}","page":"Reference","title":"LaplaceRedux.Curvature.jacobians_unbatched","text":"jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)\n\nCompute the Jacobian of the model output w.r.t. model parameters for the point X, without batching. Here, the nn function is wrapped in an anonymous function using the () -> syntax, which allows it to be differentiated using automatic differentiation.\n\n\n\n\n\n","category":"method"},{"location":"reference/#LaplaceRedux.Data.toy_data_linear","page":"Reference","title":"LaplaceRedux.Data.toy_data_linear","text":"toy_data_linear(N=100)\n\nExamples\n\ntoy_data_linear()\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.Data.toy_data_multi","page":"Reference","title":"LaplaceRedux.Data.toy_data_multi","text":"toy_data_multi(N=100)\n\nExamples\n\ntoy_data_multi()\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.Data.toy_data_non_linear","page":"Reference","title":"LaplaceRedux.Data.toy_data_non_linear","text":"toy_data_non_linear(N=100)\n\nExamples\n\ntoy_data_non_linear()\n\n\n\n\n\n","category":"function"},{"location":"reference/#LaplaceRedux.Data.toy_data_regression","page":"Reference","title":"LaplaceRedux.Data.toy_data_regression","text":"toy_data_regression(N=25, p=1; noise=0.3, fun::Function=f(x)=sin(2 * π * x))\n\nA helper function to generate synthetic data for regression.\n\n\n\n\n\n","category":"function"},{"location":"mlj_interface/","page":"MLJ interface","title":"MLJ interface","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"mlj_interface/#Interface-to-the-MLJ-framework","page":"MLJ interface","title":"Interface to the MLJ framework","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"","page":"Home","title":"Home","text":"(Image: )","category":"page"},{"location":"","page":"Home","title":"Home","text":"Documentation for LaplaceRedux.jl.","category":"page"},{"location":"#LaplaceRedux","page":"Home","title":"LaplaceRedux","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"LaplaceRedux.jl is a library written in pure 
Julia that can be used for effortless Bayesian Deep Learning through Laplace Approximation (LA). In the development of this package I have drawn inspiration from this Python library and its companion paper (Daxberger et al. 2021).","category":"page"},{"location":"#Installation","page":"Home","title":"🚩 Installation","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"The stable version of this package can be installed as follows:","category":"page"},{"location":"","page":"Home","title":"Home","text":"using Pkg\nPkg.add(\"LaplaceRedux.jl\")","category":"page"},{"location":"","page":"Home","title":"Home","text":"The development version can be installed like so:","category":"page"},{"location":"","page":"Home","title":"Home","text":"using Pkg\nPkg.add(\"https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl\")","category":"page"},{"location":"#Getting-Started","page":"Home","title":"🏃 Getting Started","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"If you are new to Deep Learning in Julia or simply prefer learning through videos, check out this awesome YouTube tutorial by doggo.jl 🐶. Additionally, you can also find a video of my presentation at JuliaCon 2022 on YouTube.","category":"page"},{"location":"#Basic-Usage","page":"Home","title":"🖥️ Basic Usage","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"LaplaceRedux.jl can be used for any neural network trained in Flux.jl. Below we show basic usage examples involving two simple models for a regression and a classification task, respectively.","category":"page"},{"location":"#Regression","page":"Home","title":"Regression","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"A complete worked example for a regression model can be found in the docs. Here we jump straight to Laplace Approximation and take the pre-trained model nn as given. Then LA can be implemented as follows, where we specify the model likelihood. The plot shows the fitted values overlaid with a 95% confidence interval. As expected, predictive uncertainty quickly increases in areas that are not populated by any training data.","category":"page"},{"location":"","page":"Home","title":"Home","text":"la = Laplace(nn; likelihood=:regression)\nfit!(la, data)\noptimize_prior!(la)\nplot(la, X, y; zoom=-5, size=(500,500))","category":"page"},{"location":"","page":"Home","title":"Home","text":"(Image: )","category":"page"},{"location":"#Binary-Classification","page":"Home","title":"Binary Classification","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"Once again we jump straight to LA and refer to the docs for a complete worked example involving binary classification. In this case we need to specify likelihood=:classification. 
The plot below shows the resulting posterior predictive distributions as contours in the two-dimensional feature space: note how the Plugin Approximation on the left compares to the Laplace Approximation on the right.","category":"page"},{"location":"","page":"Home","title":"Home","text":"la = Laplace(nn; likelihood=:classification)\nfit!(la, data)\nla_untuned = deepcopy(la) # saving for plotting\noptimize_prior!(la; n_steps=100)\n\n# Plot the posterior predictive distribution:\nzoom=0\np_plugin = plot(la, X, ys; title=\"Plugin\", link_approx=:plugin, clim=(0,1))\np_untuned = plot(la_untuned, X, ys; title=\"LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))\", clim=(0,1), zoom=zoom)\np_laplace = plot(la, X, ys; title=\"LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))\", clim=(0,1), zoom=zoom)\nplot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))","category":"page"},{"location":"","page":"Home","title":"Home","text":"(Image: )","category":"page"},{"location":"#JuliaCon-2022","page":"Home","title":"📢 JuliaCon 2022","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"This project was presented at JuliaCon 2022 in July 2022. See here for details.","category":"page"},{"location":"#Contribute","page":"Home","title":"🛠️ Contribute","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"Contributions are very much welcome! Please follow the SciML ColPrac guide. You may want to start by having a look at any open issues.","category":"page"},{"location":"#References","page":"Home","title":"🎓 References","text":"","category":"section"},{"location":"","page":"Home","title":"Home","text":"Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2021. “Laplace Redux-Effortless Bayesian Deep Learning.” Advances in Neural Information Processing Systems 34.","category":"page"},{"location":"resources/_resources/#Additional-Resources","page":"Additional Resources","title":"Additional Resources","text":"","category":"section"},{"location":"resources/_resources/#JuliaCon-2022","page":"Additional Resources","title":"JuliaCon 2022","text":"","category":"section"},{"location":"resources/_resources/","page":"Additional Resources","title":"Additional Resources","text":"Slides: link","category":"page"},{"location":"resources/_resources/","page":"Additional Resources","title":"Additional Resources","text":"","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"CurrentModule = LaplaceRedux","category":"page"},{"location":"tutorials/calibration/#Uncertainty-Calibration","page":"Calibrated forecasts","title":"Uncertainty Calibration","text":"","category":"section"},{"location":"tutorials/calibration/#The-issue-of-calibrated-uncertainty-distributions","page":"Calibrated forecasts","title":"The issue of calibrated uncertainty distributions","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"Bayesian methods offer a general framework for quantifying uncertainty. However, due to model misspecification and the use of approximate inference techniques,uncertainty estimates are often inaccurate: for example, a 90% credible interval may not contain the true outcome 90% of the time, in such cases the model is said to be miscalibrated. 
This problem arises due to the limitations of the model itself: a predictor may not be sufficiently expressive to assign the right probability to every credible interval, just as it may not be able to always assign the right label to a datapoint. Miscalibrated credible intervals reduce the trustworthiness of the forecaster because they lead to a false sense of precision and either overconfidence or underconfidence in the results.","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"A forecaster is said to be perfectly calibrated if a 90% credible interval contains the true outcome approximately 90% of the time. Perfect calibration however cannot be achieved with limited data, because with limited data comes inherent statistical fluctuations that can cause the estimated credible intervals to deviate from the ideal coverage probability. Furthermore, a finite sample of collected data points cannot eliminate completely the influence of the possible misjudged prior probabilities. On top of these issues, which stem directly from Bayes’ theorem, with Bayesian neural networks there is also the problems introduced by the approximate inference method adopted to compute the posterior distribution of the weights.","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"To introduce the concept of Average Calibration and the Calibration Plots, we will follow closely the paper Accurate Uncertainties for Deep Learning Using Calibrated Regression, written by Volodymyr Kuleshov, Nathan Fenner and Stefano Ermon, although with some small differences. We will hightlight these differences in the following paragraphs whenever they appear. We present here the theoretical basis necessary to understand the issue of calibration and we refer to the tutorials for the coding examples.","category":"page"},{"location":"tutorials/calibration/#Notation","page":"Calibrated forecasts","title":"Notation","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"We are given a labeled dataset x_t y_t in X times Y for t = 1 2 T of i.i.d. realizations of random variables X Y sim P, where P is the data distribution. Given x_t, a forecaster H X rightarrow (Y rightarrow 0 1) outputs at each step t a CDF F_t(y) targeting the label y_t. When Y is continuous, F_t is a cumulative probability distribution (CDF). We will use F^1_t 0 1 Y to denote the quantile function F^1_t (p) = infy p F_t(y).","category":"page"},{"location":"tutorials/calibration/#Calibration-in-the-Regression-case","page":"Calibrated forecasts","title":"Calibration in the Regression case","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"In the regression case, we say that the forecaster H is (on average) calibrated if","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"fracsum_t=1^T mathbb1 y_t leq F_t^-1(p) T rightarrow p quad textfor allquad p in 01","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"as T rightarrow infty.  In other words, the empirical and the predicted CDFs should match as the dataset size goes to infinity. 
Perfect Calibration is a sufficient condition for average calibration, the opposite however is not necessarily true: a model can be average calibrated but not perfectly calibrated. From now on when we talk about calibration we will implicitly talk about average calibration rather than perfect calibration.","category":"page"},{"location":"tutorials/calibration/#Sharpness","page":"Calibrated forecasts","title":"Sharpness","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"Calibration by itself is not sufficient to produce a useful forecast. For example, it is easy to see that if we use for the forecast the marginal distribution F(y) = mathbbP(Y y), without considering the input feature X, the forecast will be calibrated but still not accurate. In order to be useful, forecasts must also be sharp, which (in a regression context) means that the confidence intervals should all be as tight as possible around a single value. More formally, we want the variance var(F_t) of the random variable whose CDF is F_t to be small. As a sharpness score of the forecaster, Kuleshov et al. proposed the average predicted variance","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"sharpness(F_1 dots F_T) = frac1T sum_t=1^T varF_t","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"the smaller the sharpness, the tighter will be the confidence intervals on average.","category":"page"},{"location":"tutorials/calibration/#Calibration-Plots","page":"Calibrated forecasts","title":"Calibration Plots","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"To check the level of calibration, Kuleshov et al. proposed a calibration plot that displays the true frequency of points in each confidence interval relative to the predicted fraction of points in that interval. More formally, we choose m confidence levels 0 p_1 p_2 p_m 1; for each threshold p_j , and compute the empirical frequency","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"hatp_j = frac y_tF_t(y_t) leq p_t t= 12dotsT T","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"To visualize the level of average calibration, we plot (p_jhatp_j) _j=1^M; A forecaster that is calibrated will correspond to a straight line on the plot that goes from 00 to 11 . o measure the level of miscalibration, we can compute the area between the diagonal line (the line of perfect calibration) and the calibration curve produced by the forecaster. This area represents the degree to which the predicted probabilities deviate from the actual observed frequencies. Alternatively, the original paper suggests using the calibration error as a numerical score to describe the quality of forecast calibration: cal(F_1F_2dotsF_Ty_T) = sum_j=1^m w_j (p_j -hatp_j)^2 where the scalars w_j are tunable weights. Both methods — using the area between the calibration curve and the diagonal line or using the calibration error — are equivalent in that they both provide a numerical measure of miscalibration. 
The calibration error can be seen as a discretized approximation of the area, where the weights 𝑤_𝑗 adjust for the distribution of samples across different bins.","category":"page"},{"location":"tutorials/calibration/#Post-training-calibration","page":"Calibrated forecasts","title":"Post-training calibration","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"As we have said previously, uncertainty estimates obtained by deep BNNs tend to be miscalibrated. We introduced the support to a post-training technique for regression problems presented in Recalibration of Aleatoric and Epistemic Regression Uncertainty in Medical Imaging by Max-Heinrich Laves, Sontje Ihler, Jacob F. Fast, Lüder A. Kahrs and Tobias Ortmaier and usually referred to as sigma-scaling. Using a Gaussian model, the technique consist in scaling the predicted standard deviation sigma with a scalar value s to recalibrate the probability density function","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"p(yx haty(x) hatσ^2(x)) = mathbbN( y haty(x)(s cdot hatσ(x))^2 )","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"This results in the following minimization objective:","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"L_G(s) = m log(s) + frac12s^2 sum_i=1^m (hatσ^(i)_θ)^2 y^(i) haty_theta^(i)^2","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"In general, this equation can be optimized respect to s with fixed values for the parameters theta using gradient descent in a second phase over a separate calibration set. However, for the case of a gaussian distribution, the analytical solution is known and takes the closed form","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"s = pm sqrtfrac1m sum_i=1^m (hatσ^(i)_θ)^2 y^(i) haty_theta^(i)^2","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"Once the scalar s is computed, all we have to do to obtain better calibrated predictions is to multiply the predicted standard deviation with the scalar. The main difference from the other supported technique for optimization of the uncertainty estimates, which is based on Empirical Bayes, is that sigma-scaling tries to optimize directly the posterior predictive distribution using a separate calibration dataset. Empirical Bayes instead involves estimating the parameters of the prior distribution by maximizing the marginal likelihood of the observed data, effectively finding a prior that best explains the data without needing a separate calibration dataset. 
Empirical Bayes uses the observed data itself to indirectly improve the calibration of uncertainty estimates by refining the prior, whereas sigma-scaling focuses on a direct post-hoc adjustment to the posterior predictive uncertainty.","category":"page"},{"location":"tutorials/calibration/#Calibration-in-the-Binary-Classification-case","page":"Calibrated forecasts","title":"Calibration in the Binary Classification case","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"In binary classification, we have Y = 0 1, and we say that H is calibrated if","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"fracsum_t=1^T y_tmathbb1H(x_t)=psum_t=1^Tmathbb1H(x_t)=p rightarrow p quad textfor all quad pin 01","category":"page"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"as T rightarrow infty. For simplicity, we have denoted H(x_t) as the probability of the event y_t=1. Once again, perfect calibration is a sufficient condition for calibration.","category":"page"},{"location":"tutorials/calibration/#Sharpness-2","page":"Calibrated forecasts","title":"Sharpness","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"We can assess sharpness by looking at the distribution of model predictions. When forecasts are sharp, most predicted probabilities for the correct class are close to 1; unsharp forecasters make predictions closer to 05.","category":"page"},{"location":"tutorials/calibration/#Calibration-Plots-2","page":"Calibrated forecasts","title":"Calibration Plots","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"Given a dataset (x_t y_t)^T_t=1, let p_t = H(x_t) 0 1 be the forecasted probability. We group the p_t into intervals I_j for j = 1 2 m that form a partition of 0 1. A calibration curve plots the predicted average $ pj = T^{−1}*j *{t:pt∈ Ij} pt $ in each interval I_j against the observed empirical average p_j = T^1_j sum_tp_t I_jy_t where T_j = t p_t I_j. Perfect calibration corresponds once again to a straight line.","category":"page"},{"location":"tutorials/calibration/#Multiclass-Case","page":"Calibrated forecasts","title":"Multiclass Case","text":"","category":"section"},{"location":"tutorials/calibration/","page":"Calibrated forecasts","title":"Calibrated forecasts","text":"For multiclass classification tasks the above technique can be extended by plotting each class versus all the remaining classes considered as one.","category":"page"}] } diff --git a/dev/tutorials/calibration.qmd b/dev/tutorials/calibration.qmd new file mode 100644 index 00000000..8c9a3470 --- /dev/null +++ b/dev/tutorials/calibration.qmd @@ -0,0 +1,97 @@ + + +``` @meta +CurrentModule = LaplaceRedux +``` + +# Uncertainty Calibration + +# The issue of calibrated uncertainty distributions + + +Bayesian methods offer a general framework for quantifying uncertainty. However, due to model misspecification and the use of approximate inference techniques,uncertainty estimates are often inaccurate: for example, a 90% credible interval may not contain the true outcome 90% of the time, in such cases the model is said to be miscalibrated. 
This problem arises due to the limitations of the model itself: a predictor may not be sufficiently expressive to assign the right probability to every credible interval, just as it may not be able to always assign the right label to a datapoint. Miscalibrated credible intervals reduce the trustworthiness of the forecaster because they lead to a false sense of precision and either overconfidence or underconfidence in the results. + +A forecaster is said to be perfectly calibrated if a 90% credible interval contains the true outcome approximately 90% of the time. Perfect calibration however cannot be achieved with limited data, because with limited data comes inherent statistical fluctuations that can cause the estimated credible intervals to deviate from the ideal coverage probability. Furthermore, a finite sample of collected data points cannot eliminate completely the influence of the possible misjudged prior probabilities. +On top of these issues, which stem directly from Bayes' theorem, with Bayesian neural networks there is also the problems introduced by the approximate inference method adopted to compute the posterior distribution of the weights. + +To introduce the concept of Average Calibration and the Calibration Plots, we will follow closely the paper [Accurate Uncertainties for Deep Learning Using Calibrated Regression](https://arxiv.org/abs/1807.00263), written by Volodymyr Kuleshov, Nathan Fenner and Stefano Ermon, although with some small differences. We will hightlight these differences in the following paragraphs whenever they appear. We present here the theoretical basis necessary to understand the issue of calibration and we refer to the tutorials for the coding examples. + +## Notation +We are given a labeled dataset $x_t, y_t \in X \times Y$ for $t = 1, 2, ..., T$ of i.i.d. realizations of random variables $X, Y \sim P$, where $P$ is the data distribution. +Given $x_t$, a forecaster $H : X \rightarrow (Y \rightarrow [0, 1])$ outputs at each step $t$ a CDF $F_t(y)$ targeting the label $y_t$. When $Y$ is continuous, $F_t$ is a cumulative probability distribution (CDF). We will use $F^{−1}_t: [0, 1] → Y$ to denote the quantile function $F^{−1}_t (p) = inf\{y : p ≤ F_t(y)\}$. + +## Calibration in the Regression case +In the regression case, we say that the forecaster H is (on average) calibrated if + +$$\frac{\sum_{t=1}^T \mathbb{1} \{ y_t \leq F_t^{-1}(p) \} }{T} \rightarrow p \quad \text{for all}\quad p \in [0,1]$$ + +as $T \rightarrow \infty$.\ In other words, the empirical and the predicted CDFs should match as the dataset size goes to infinity. +Perfect Calibration is a sufficient condition for average calibration, the opposite however is not necessarily true: a model can be average calibrated but not perfectly calibrated.\ +From now on when we talk about calibration we will implicitly talk about average calibration rather than perfect calibration. + + +### Sharpness +Calibration by itself is not sufficient to produce a useful forecast. For example, it is easy to see that if we use for the forecast the marginal distribution $F(y) = \mathbb{P}(Y ≤ y)$, without considering the input feature $X$, the forecast will be calibrated but still not accurate. In order to be useful, forecasts must also be sharp, which (in a regression context) means that the confidence intervals should all be as tight as possible around a single value. 
More formally, we want the variance $var(F_t)$ of the random variable whose CDF is $F_t$ to be small.\ +As a sharpness score of the forecaster, Kuleshov et al. proposed the average predicted variance + +$$sharpness(F_{1} ,\dots, F_T) = \frac{1}{T} \sum_{t=1}^T var{F_t}$$ \ +the smaller the sharpness, the tighter will be the confidence intervals on average. + +### Calibration Plots + +To check the level of calibration, Kuleshov et al. proposed a calibration plot that displays the true frequency of points in each confidence interval relative to the predicted fraction of points in that interval.\ +More formally, we choose $m$ confidence levels $0 ≤ p_1 < p_2 < . . . < p_m ≤ 1$; for each threshold $p_j$ , and compute the empirical frequency + +$$\hat{p}_j = \frac{|\{ y_t|F_t(y_t) \leq p_t, t= 1,2,\dots,T \} |}{T}.$$ + +To visualize the level of average calibration, we plot $\{(p_j,\hat{p_j}) \}_{j=1}^M$; A forecaster that is calibrated will correspond to a straight line on the plot that goes from $\{0,0\}$ to $\{1,1\}$ . \ +o measure the level of miscalibration, we can compute the area between the diagonal line (the line of perfect calibration) and the calibration curve produced by the forecaster. This area represents the degree to which the predicted probabilities deviate from the actual observed frequencies. +Alternatively, the original paper suggests using the calibration error as a numerical score to describe the quality of forecast calibration: +$$cal(F_1,F_2,\dots,F_T,y_T) = \sum_{j=1}^m w_j (p_j -\hat{p}_j)^2$$ +where the scalars $w_j$ are tunable weights. +Both methods — using the area between the calibration curve and the diagonal line or using the calibration error — are equivalent in that they both provide a numerical measure of miscalibration. The calibration error can be seen as a discretized approximation of the area, where the weights $𝑤_𝑗$ adjust for the distribution of samples across different bins. + + +### Post-training calibration +As we have said previously, uncertainty estimates obtained by deep BNNs tend to be miscalibrated. We introduced the support to a post-training technique for regression problems presented in [Recalibration of Aleatoric and Epistemic Regression Uncertainty in Medical Imaging](https://arxiv.org/abs/2104.12376) +by Max-Heinrich Laves, Sontje Ihler, Jacob F. Fast, Lüder A. Kahrs and Tobias Ortmaier and usually referred to as sigma-scaling. Using a Gaussian model, the technique consist in scaling the predicted standard deviation $\sigma$ with a scalar value $s$ to recalibrate the probability density function + +$$p(y|x; \hat{y}(x), \hat{σ}^2(x)) = \mathbb{N}( y; \hat{y}(x),(s \cdot \hat{σ}(x))^2 ).$$ + +This results in the following minimization objective: + +$$L_G(s) = m \log(s) + \frac{1}{2}s^{−2} \sum_{i=1}^m (\hat{σ}^{(i)}_θ)^{−2} || y^{(i)} − \hat{y}_{\theta}^{(i)}||^2.$$ + +In general, this equation can be optimized respect to $s$ with fixed values for the parameters $\theta$ using gradient descent in a second phase over a separate calibration set. However, for the case of a gaussian distribution, the analytical solution is known and takes the closed form + +$$s = \pm \sqrt{\frac{1}{m} \sum_{i=1}^m (\hat{σ}^{(i)}_θ)^{−2} || y^{(i)} − \hat{y}_{\theta}^{(i)}||^2}.$$ + +Once the scalar $s$ is computed, all we have to do to obtain better calibrated predictions is to multiply the predicted standard deviation with the scalar. 
+The main difference from the other supported technique for optimization of the uncertainty estimates, which is based on Empirical Bayes, is that sigma-scaling tries to optimize directly the posterior predictive distribution using a separate calibration dataset. Empirical Bayes instead involves estimating the parameters of the prior distribution by maximizing the marginal likelihood of the observed data, effectively finding a prior that best explains the data without needing a separate calibration dataset. Empirical Bayes uses the observed data itself to indirectly improve the calibration of uncertainty estimates by refining the prior, whereas sigma-scaling focuses on a direct post-hoc adjustment to the posterior predictive uncertainty. + + +## Calibration in the Binary Classification case + + +In binary classification, we have $Y = {0, 1}$, and we say that H is calibrated if + +$$\frac{\sum_{t=1}^{T} y_t\mathbb{1}\{H(x_t)=p\}}{\sum_{t=1}^{T}\mathbb{1}\{H(x_t)=p\}} \rightarrow p \quad \text{for all} \quad p\in [0,1]$$ + +as $T \rightarrow \infty$. For simplicity, we have denoted $H(x_t)$ as the probability of the event $y_t=1$. Once again, perfect calibration is a sufficient condition for calibration. + + +### Sharpness +We can assess sharpness by looking at the distribution of model predictions. When forecasts are sharp, most +predicted probabilities for the correct class are close to $1$; unsharp forecasters make predictions closer to $0.5$. + + +## Calibration Plots + Given a dataset ${(x_t, y_t)}^T_t=1$, let $p_t = H(x_t) ∈ [0, 1]$ be the forecasted probability. We group the $p_t$ into intervals $I_j$ for +$j = 1, 2, ..., m$ that form a partition of $[0, 1]$. \ +A calibration curve plots the predicted average $ p_j = T^{−1}_j \sum_{t:p_t∈ I_j} p_t $ in each interval $I_j$ against the observed empirical average +$$p_j = T^{−1}_j \sum_{t:p_t ∈ I_j}y_t,$$ where $T_j = |{t : p_t ∈ I_j}|$. Perfect calibration corresponds once again to a straight line. + + + +## Multiclass Case +For multiclass classification tasks the above technique can be extended by plotting each class versus all the remaining classes considered as one. diff --git a/dev/tutorials/calibration/index.html b/dev/tutorials/calibration/index.html new file mode 100644 index 00000000..6133a86d --- /dev/null +++ b/dev/tutorials/calibration/index.html @@ -0,0 +1,2 @@ + +Calibrated forecasts · LaplaceRedux.jl

Uncertainty Calibration

The issue of calibrated uncertainty distributions

Bayesian methods offer a general framework for quantifying uncertainty. However, due to model misspecification and the use of approximate inference techniques, uncertainty estimates are often inaccurate: for example, a 90% credible interval may not contain the true outcome 90% of the time. In such cases the model is said to be miscalibrated. This problem arises due to the limitations of the model itself: a predictor may not be sufficiently expressive to assign the right probability to every credible interval, just as it may not be able to always assign the right label to a datapoint. Miscalibrated credible intervals reduce the trustworthiness of the forecaster because they lead to a false sense of precision and either overconfidence or underconfidence in the results.

A forecaster is said to be perfectly calibrated if a 90% credible interval contains the true outcome approximately 90% of the time. Perfect calibration, however, cannot be achieved with limited data, because limited data comes with inherent statistical fluctuations that can cause the estimated credible intervals to deviate from the ideal coverage probability. Furthermore, a finite sample of collected data points cannot completely eliminate the influence of possibly misjudged prior probabilities. On top of these issues, which stem directly from Bayes’ theorem, Bayesian neural networks also suffer from the problems introduced by the approximate inference method adopted to compute the posterior distribution of the weights.

To introduce the concept of Average Calibration and the Calibration Plots, we will follow closely the paper Accurate Uncertainties for Deep Learning Using Calibrated Regression, written by Volodymyr Kuleshov, Nathan Fenner and Stefano Ermon, although with some small differences. We will highlight these differences in the following paragraphs whenever they appear. We present here the theoretical basis necessary to understand the issue of calibration and refer to the tutorials for the coding examples.

Notation

We are given a labeled dataset $(x_t, y_t) \in X \times Y$ for $t = 1, 2, ..., T$ of i.i.d. realizations of random variables $X, Y \sim P$, where $P$ is the data distribution. Given $x_t$, a forecaster $H : X \rightarrow (Y \rightarrow [0, 1])$ outputs at each step $t$ a distribution $F_t(y)$ targeting the label $y_t$. When $Y$ is continuous, $F_t$ is a cumulative distribution function (CDF). We will use $F^{−1}_t: [0, 1] → Y$ to denote the quantile function $F^{−1}_t (p) = \inf\{y : p ≤ F_t(y)\}$.
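
For continuous predictive distributions the quantile function is simply the (generalized) inverse of the CDF. A quick illustrative check with a Distributions.jl object (used here purely as an example, not as part of LaplaceRedux):

using Distributions

F = Normal(0.0, 1.0)        # a predictive CDF F_t
quantile(F, 0.9)            # F_t⁻¹(0.9) ≈ 1.2816
cdf(F, quantile(F, 0.9))    # recovers 0.9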

Calibration in the Regression case

In the regression case, we say that the forecaster H is (on average) calibrated if

\[\frac{\sum_{t=1}^T \mathbb{1} \{ y_t \leq F_t^{-1}(p) \} }{T} \rightarrow p \quad \text{for all}\quad p \in [0,1]\]

as $T \rightarrow \infty$. In other words, the empirical and the predicted CDFs should match as the dataset size goes to infinity. Perfect calibration is a sufficient condition for average calibration; the converse is not necessarily true: a model can be average calibrated without being perfectly calibrated. From now on, when we talk about calibration we will implicitly mean average calibration rather than perfect calibration.
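
To make this definition concrete, the left-hand side can be estimated on a finite dataset directly from the predictive distributions. The following minimal sketch is illustrative only (the names coverage, dists and ys are placeholders, not LaplaceRedux functions), assuming the predictive distributions are available as Distributions.jl objects:

using Statistics, Distributions

# Fraction of observations falling below their predicted p-quantile;
# for a calibrated forecaster this should approach p as T grows.
coverage(dists, ys, p) = mean(ys .<= quantile.(dists, p))

# Toy example with Gaussian predictive distributions:
dists = [Normal(0.0, 1.0), Normal(0.5, 0.8), Normal(-1.0, 1.2)]
ys = [0.3, 0.1, -0.9]
coverage(dists, ys, 0.9)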

Sharpness

Calibration by itself is not sufficient to produce a useful forecast. For example, it is easy to see that if we use the marginal distribution $F(y) = \mathbb{P}(Y ≤ y)$ as the forecast, without considering the input feature $X$, the forecast will be calibrated but still not accurate. In order to be useful, forecasts must also be sharp, which (in a regression context) means that the confidence intervals should be as tight as possible around a single value. More formally, we want the variance $var(F_t)$ of the random variable whose CDF is $F_t$ to be small. As a sharpness score of the forecaster, Kuleshov et al. proposed the average predicted variance

\[sharpness(F_{1} ,\dots, F_T) = \frac{1}{T} \sum_{t=1}^T var{F_t}\]

the smaller the sharpness score, the tighter the confidence intervals will be on average.
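
If the predictive distributions are represented as Distributions.jl objects, this score reduces to an average of predictive variances. A minimal sketch (the name sharpness and the vector dists are illustrative placeholders, not a LaplaceRedux API):

using Statistics, Distributions

# Sharpness as defined above: the average predictive variance.
sharpness(dists) = mean(var.(dists))

sharpness([Normal(0.0, 0.5), Normal(1.0, 0.7)])  # (0.25 + 0.49) / 2 = 0.37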

Calibration Plots

To check the level of calibration, Kuleshov et al. proposed a calibration plot that displays the true frequency of points in each confidence interval against the predicted fraction of points in that interval. More formally, we choose $m$ confidence levels $0 ≤ p_1 < p_2 < \dots < p_m ≤ 1$ and, for each threshold $p_j$, compute the empirical frequency

\[\hat{p}_j = \frac{|\{ y_t \mid F_t(y_t) \leq p_j,\; t= 1,2,\dots,T \}|}{T}.\]

To visualize the level of average calibration, we plot $\{(p_j,\hat{p}_j) \}_{j=1}^m$; a calibrated forecaster will correspond to a straight line on the plot going from $(0,0)$ to $(1,1)$. To measure the level of miscalibration, we can compute the area between the diagonal line (the line of perfect calibration) and the calibration curve produced by the forecaster. This area represents the degree to which the predicted probabilities deviate from the actual observed frequencies. Alternatively, the original paper suggests using the calibration error as a numerical score to describe the quality of forecast calibration: $cal(F_1, y_1, \dots, F_T, y_T) = \sum_{j=1}^m w_j (p_j -\hat{p}_j)^2$ where the scalars $w_j$ are tunable weights. Both approaches (the area between the calibration curve and the diagonal, and the calibration error) provide a numerical measure of miscalibration. The calibration error can be seen as a discretized approximation of the area, where the weights $w_j$ adjust for the distribution of samples across different bins.
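
A minimal sketch of how these quantities could be computed for a regression forecaster is shown below; dists and ys are placeholder vectors of predictive distributions and observed targets, and the weights default to uniform. This illustrates the formulas above and is not the plotting utility shipped with LaplaceRedux.

using Statistics, Distributions

# Empirical frequencies p̂_j for a grid of confidence levels p_j.
function calibration_curve(dists, ys; levels=0.1:0.1:0.9)
    phat = [mean(cdf.(dists, ys) .<= p) for p in levels]
    return collect(levels), phat
end

# Weighted calibration error; uniform weights by default.
calibration_error(p, phat; w=fill(1 / length(p), length(p))) = sum(w .* (p .- phat) .^ 2)

dists = [Normal(0.0, 1.0), Normal(0.5, 0.8), Normal(-1.0, 1.2)]
ys = [0.3, 0.1, -0.9]
p, phat = calibration_curve(dists, ys)
calibration_error(p, phat)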

Post-training calibration

As we have said previously, uncertainty estimates obtained by deep BNNs tend to be miscalibrated. We introduced support for a post-training technique for regression problems presented in Recalibration of Aleatoric and Epistemic Regression Uncertainty in Medical Imaging by Max-Heinrich Laves, Sontje Ihler, Jacob F. Fast, Lüder A. Kahrs and Tobias Ortmaier, usually referred to as sigma-scaling. Using a Gaussian model, the technique consists of scaling the predicted standard deviation $\sigma$ by a scalar value $s$ to recalibrate the probability density function

\[p(y|x; \hat{y}(x), \hat{σ}^2(x)) = \mathbb{N}( y; \hat{y}(x),(s \cdot \hat{σ}(x))^2 ).\]

This results in the following minimization objective:

\[L_G(s) = m \log(s) + \frac{1}{2}s^{−2} \sum_{i=1}^m (\hat{σ}^{(i)}_θ)^{−2} || y^{(i)} − \hat{y}_{\theta}^{(i)}||^2.\]

In general, this equation can be optimized with respect to $s$, with the parameters $\theta$ held fixed, using gradient descent in a second phase over a separate calibration set. However, for the case of a Gaussian distribution the analytical solution is known and takes the closed form

\[s = \pm \sqrt{\frac{1}{m} \sum_{i=1}^m (\hat{σ}^{(i)}_θ)^{−2} || y^{(i)} − \hat{y}_{\theta}^{(i)}||^2}.\]

Once the scalar $s$ is computed, all we have to do to obtain better calibrated predictions is to multiply the predicted standard deviation by this scalar. The main difference from the other supported technique for optimizing the uncertainty estimates, which is based on Empirical Bayes, is that sigma-scaling tries to optimize the posterior predictive distribution directly using a separate calibration dataset. Empirical Bayes instead estimates the parameters of the prior distribution by maximizing the marginal likelihood of the observed data, effectively finding a prior that best explains the data without needing a separate calibration dataset. Empirical Bayes uses the observed data itself to indirectly improve the calibration of uncertainty estimates by refining the prior, whereas sigma-scaling focuses on a direct post-hoc adjustment to the posterior predictive uncertainty.
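
To illustrate the closed form above, the rescaling can be written in a few lines. The names below (sigma_scale, yhat, sigmahat, y) are placeholders standing for the predicted means, predicted standard deviations and targets of a held-out calibration set; this is a sketch, not the LaplaceRedux implementation.

using Statistics

# Closed-form sigma-scaling: s^2 is the average squared residual
# normalized by the predicted variance.
function sigma_scale(yhat, sigmahat, y)
    s = sqrt(mean(abs2.(y .- yhat) ./ sigmahat .^ 2))
    return s, s .* sigmahat   # scale factor and recalibrated standard deviations
end

yhat = [1.0, 2.0, 0.5]        # predicted means on a calibration set
sigmahat = [0.1, 0.2, 0.1]    # predicted standard deviations
y = [1.2, 1.9, 0.4]           # observed targets
s, sigma_cal = sigma_scale(yhat, sigmahat, y)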

Calibration in the Binary Classification case

In binary classification, we have $Y = \{0, 1\}$, and we say that $H$ is calibrated if

\[\frac{\sum_{t=1}^{T} y_t\mathbb{1}\{H(x_t)=p\}}{\sum_{t=1}^{T}\mathbb{1}\{H(x_t)=p\}} \rightarrow p \quad \text{for all} \quad p\in [0,1]\]

as $T \rightarrow \infty$. For simplicity, we write $H(x_t)$ for the forecasted probability of the event $y_t=1$. Once again, perfect calibration is a sufficient condition for average calibration.

Sharpness

We can assess sharpness by looking at the distribution of model predictions. When forecasts are sharp, most predicted probabilities for the correct class are close to $1$; unsharp forecasters make predictions closer to $0.5$.
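
One simple way to quantify this idea is to average the confidence assigned to the true class, separately for each class. The sketch below illustrates the notion with placeholder names; it is not necessarily how sharpness_classification is implemented in LaplaceRedux.

using Statistics

# Average confidence in the true class, split by class; values close to 1
# indicate sharp forecasts, values near 0.5 indicate unsharp ones.
class_sharpness(ys, probs) = (mean(probs[ys .== 1]), mean(1 .- probs[ys .== 0]))

class_sharpness([1, 0, 1, 0], [0.9, 0.2, 0.8, 0.1])  # (0.85, 0.85)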

Calibration Plots

Given a dataset $\{(x_t, y_t)\}_{t=1}^T$, let $p_t = H(x_t) ∈ [0, 1]$ be the forecasted probability. We group the $p_t$ into intervals $I_j$ for $j = 1, 2, \dots, m$ that form a partition of $[0, 1]$. A calibration curve plots the predicted average $\bar{p}_j = T_j^{-1} \sum_{t: p_t \in I_j} p_t$ in each interval $I_j$ against the observed empirical average $\hat{p}_j = T_j^{-1} \sum_{t: p_t \in I_j} y_t$, where $T_j = |\{t : p_t \in I_j\}|$. Perfect calibration corresponds once again to a straight line.
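
A minimal sketch of this binning procedure (uniform bins, placeholder names, no LaplaceRedux internals) could look as follows:

using Statistics

# Group forecasted probabilities into bins and compare, in each bin, the
# average prediction with the empirical frequency of the positive class.
function binary_calibration_curve(ys, probs; n_bins=10)
    edges = range(0, 1; length=n_bins + 1)
    pred, obs = Float64[], Float64[]
    for j in 1:n_bins
        inbin = (probs .>= edges[j]) .& (j == n_bins ? probs .<= edges[j + 1] : probs .< edges[j + 1])
        any(inbin) || continue
        push!(pred, mean(probs[inbin]))
        push!(obs, mean(ys[inbin]))
    end
    return pred, obs
end

binary_calibration_curve([1, 0, 1, 1, 0], [0.9, 0.2, 0.7, 0.8, 0.4]; n_bins=5)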

Multiclass Case

For multiclass classification tasks the above technique can be extended by plotting each class versus all the remaining classes considered as one.
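
A sketch of this one-versus-rest reduction for a single class k is shown below; the function and variable names are illustrative placeholders. The multiclass tutorial in these docs performs a similar reduction on the categorical distributions returned by predict.

# Reduce per-class probability vectors to a binary problem for class k.
one_vs_rest(probs_per_class, ys, k) = ([p[k] for p in probs_per_class], Int.(ys .== k))

probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.3, 0.5]]
ys = [1, 2, 3]
p_k, y_k = one_vs_rest(probs, ys, 3)   # probabilities and indicators for class 3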

diff --git a/dev/tutorials/logit.qmd b/dev/tutorials/logit.qmd index 0b3643a8..bdafa613 100644 --- a/dev/tutorials/logit.qmd +++ b/dev/tutorials/logit.qmd @@ -18,10 +18,38 @@ theme(:lime) We will use synthetic data with linearly separable samples: ```{julia} +# set seed +seed= 1234 +Random.seed!(seed) # Number of points to generate. -xs, ys = LaplaceRedux.Data.toy_data_linear(100) +xs, ys = LaplaceRedux.Data.toy_data_linear(100; seed=seed) X = hcat(xs...) # bring into tabular format -data = zip(xs,ys) +``` + +split in a training and test set + +```{julia} +# Shuffle the data +n = length(ys) +indices = randperm(n) + +# Define the split ratio +split_ratio = 0.8 +split_index = Int(floor(split_ratio * n)) + +# Split the data into training and test sets +train_indices = indices[1:split_index] +test_indices = indices[split_index+1:end] + +xs_train = xs[train_indices] +xs_test = xs[test_indices] +ys_train = ys[train_indices] +ys_test = ys[test_indices] +# bring into tabular format +X_train = hcat(xs_train...) +X_test = hcat(xs_test...) + +data = zip(xs_train,ys_train) ``` @@ -74,6 +102,7 @@ optimize_prior!(la; verbose=true, n_steps=500) The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right). ```{julia} +#| output: true zoom = 0 p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1)) p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom) @@ -81,3 +110,25 @@ p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400)) ``` +Now we can test the level of calibration of the neural network. +First we collect the predicted results over the test dataset + +```{julia} +#| output: true + predicted_distributions= predict(la, X_test,ret_distr=true) +``` + +then we plot the calibration plot +```{julia} +#| output: true +Calibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10) +``` + +as we can see from the plot, although extremely accurate, the neural network does not seem to be calibrated well. This is, however, an effect of the extreme accuracy reached by the neural network which causes the lack of predictions with high uncertainty (low certainty). We can see this by looking at the level of sharpness for the two classes which are extremely close to 1, indicating the high level of trust that the neural network has in the predictions. + +```{julia} +#| output: true +sharpness_classification(ys_test,vec(predicted_distributions)) +``` + + diff --git a/dev/tutorials/logit/index.html b/dev/tutorials/logit/index.html index 6197bb4f..0d0bbef5 100644 --- a/dev/tutorials/logit/index.html +++ b/dev/tutorials/logit/index.html @@ -1,11 +1,33 @@ -Logistic Regression · LaplaceRedux.jl

Bayesian Logistic Regression

Libraries

using Pkg; Pkg.activate("docs")
+Logistic Regression · LaplaceRedux.jl

Bayesian Logistic Regression

Libraries

using Pkg; Pkg.activate("docs")
 # Import libraries
 using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra
-theme(:lime)

Data

We will use synthetic data with linearly separable samples:

# Number of points to generate.
-xs, ys = LaplaceRedux.Data.toy_data_linear(100)
-X = hcat(xs...) # bring into tabular format
-data = zip(xs,ys)

Model

Logistic regression with weight decay can be implemented in Flux.jl as a single dense (linear) layer with binary logit crossentropy loss:

nn = Chain(Dense(2,1))
+theme(:lime)

Data

We will use synthetic data with linearly separable samples:

# set seed
+seed= 1234
+Random.seed!(seed)
+# Number of points to generate.
+xs, ys = LaplaceRedux.Data.toy_data_linear(100; seed=seed)
+X = hcat(xs...) # bring into tabular format

We split the data into a training and a test set:

# Shuffle the data
+n = length(ys)
+indices = randperm(n)
+
+# Define the split ratio
+split_ratio = 0.8
+split_index = Int(floor(split_ratio * n))
+
+# Split the data into training and test sets
+train_indices = indices[1:split_index]
+test_indices = indices[split_index+1:end]
+
+xs_train = xs[train_indices]
+xs_test = xs[test_indices]
+ys_train = ys[train_indices]
+ys_test = ys[test_indices]
+# bring into tabular format
+X_train = hcat(xs_train...) 
+X_test = hcat(xs_test...) 
+
+data = zip(xs_train,ys_train)

Model

Logistic regression with weight decay can be implemented in Flux.jl as a single dense (linear) layer with binary logit crossentropy loss:

nn = Chain(Dense(2,1))
 λ = 0.5
 sqnorm(x) = sum(abs2, x)
 weight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))
@@ -33,4 +55,5 @@
 p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1))
 p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
 p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
-plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))
+plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))

Now we can test the calibration of the neural network. First, we collect the predictions over the test dataset:

 predicted_distributions= predict(la, X_test,ret_distr=true)
1×20 Matrix{Distributions.Bernoulli{Float64}}:
+ Distributions.Bernoulli{Float64}(p=0.13122)  …  Distributions.Bernoulli{Float64}(p=0.109559)

Then we produce the calibration plot:

Calibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)

As we can see from the plot, although extremely accurate, the neural network does not seem to be well calibrated. This is, however, an effect of the extreme accuracy reached by the network, which leaves almost no predictions with high uncertainty (low certainty). We can see this by looking at the sharpness levels for the two classes, which are extremely close to 1, indicating the high level of confidence that the neural network has in its predictions.

sharpness_classification(ys_test,vec(predicted_distributions))
(0.9131870336577175, 0.8865055827351365)
diff --git a/dev/tutorials/logit_files/figure-commonmark/cell-10-output-1.svg b/dev/tutorials/logit_files/figure-commonmark/cell-10-output-1.svg
new file mode 100644
index 00000000..7668448d
[new SVG figure; 52 lines of markup omitted]
diff --git a/dev/tutorials/logit_files/figure-commonmark/cell-7-output-1.svg b/dev/tutorials/logit_files/figure-commonmark/cell-7-output-1.svg
new file mode 100644
index 00000000..a8b17423
[new SVG figure; 572 lines of markup omitted]
diff --git a/dev/tutorials/logit_files/figure-commonmark/cell-8-output-1.svg b/dev/tutorials/logit_files/figure-commonmark/cell-8-output-1.svg
new file mode 100644
index 00000000..0636acdb
[new SVG figure; 605 lines of markup omitted]
diff --git a/dev/tutorials/mlp.qmd b/dev/tutorials/mlp.qmd
index d9d0b192..59f9ad62 100644
--- a/dev/tutorials/mlp.qmd
+++ b/dev/tutorials/mlp.qmd
@@ -18,18 +18,43 @@ theme(:lime)
 This time we use a synthetic dataset
containing samples that are not linearly separable: ```{julia} +#set seed +seed = 1234 +Random.seed!(seed) # Number of points to generate. -xs, ys = LaplaceRedux.Data.toy_data_non_linear(200) -X = hcat(xs...) # bring into tabular format -data = zip(xs,ys) +xs, ys = LaplaceRedux.Data.toy_data_non_linear(400; seed = seed) +# Shuffle the data +n = length(ys) +indices = randperm(n) + +# Define the split ratio +split_ratio = 0.8 +split_index = Int(floor(split_ratio * n)) + +# Split the data into training and test sets +train_indices = indices[1:split_index] +test_indices = indices[split_index+1:end] + +xs_train = xs[train_indices] +xs_test = xs[test_indices] +ys_train = ys[train_indices] +ys_test = ys[test_indices] +# bring into tabular format +X_train = hcat(xs_train...) +X_test = hcat(xs_test...) + +data = zip(xs_train,ys_train) ``` + + + ## Model For the classification task we build a neural network with weight decay composed of a single hidden layer. ```{julia} n_hidden = 10 -D = size(X,1) +D = size(X_train,1) nn = Chain( Dense(D, n_hidden, σ), Dense(n_hidden, 1) @@ -78,9 +103,9 @@ The plot below shows the resulting posterior predictive surface for the plugin e # Plot the posterior distribution with a contour plot. zoom=0 -p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1)) -p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom) -p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom) +p_plugin = plot(la, X_train, ys_train; title="Plugin", link_approx=:plugin, clim=(0,1)) +p_untuned = plot(la_untuned, X_train, ys_train; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom) +p_laplace = plot(la, X_train, ys_train; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom) plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400)) ``` @@ -90,10 +115,22 @@ Zooming out we can note that the plugin estimator produces high-confidence estim #| output: true zoom=-50 -p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1)) -p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom) -p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom) +p_plugin = plot(la, X_train, ys_train; title="Plugin", link_approx=:plugin, clim=(0,1)) +p_untuned = plot(la_untuned, X_train, ys_train; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom) +p_laplace = plot(la, X_train, ys_train; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom) plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400)) ``` +We plot now the calibration plot to assess the level of average calibration reached by the neural network. + +```{julia} +#| output: true +predicted_distributions= predict(la, X_test,ret_distr=true) +Calibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10) +``` +and the sharpness score +```{julia} +#| output: true +sharpness_classification(ys_test,vec(predicted_distributions)) +``` \ No newline at end of file diff --git a/dev/tutorials/mlp/index.html b/dev/tutorials/mlp/index.html index 10ba528c..1b163c6d 100644 --- a/dev/tutorials/mlp/index.html +++ b/dev/tutorials/mlp/index.html @@ -1,12 +1,34 @@ -MLP Binary Classifier · LaplaceRedux.jl

Bayesian MLP

Libraries

using Pkg; Pkg.activate("docs")
+MLP Binary Classifier · LaplaceRedux.jl

Bayesian MLP

Libraries

using Pkg; Pkg.activate("docs")
 # Import libraries
 using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra
-theme(:lime)

Data

This time we use a synthetic dataset containing samples that are not linearly separable:

# Number of points to generate.
-xs, ys = LaplaceRedux.Data.toy_data_non_linear(200)
-X = hcat(xs...) # bring into tabular format
-data = zip(xs,ys)

Model

For the classification task we build a neural network with weight decay composed of a single hidden layer.

n_hidden = 10
-D = size(X,1)
+theme(:lime)

Data

This time we use a synthetic dataset containing samples that are not linearly separable:

#set seed
+seed = 1234
+Random.seed!(seed)
+# Number of points to generate.
+xs, ys = LaplaceRedux.Data.toy_data_non_linear(400; seed = seed)
+# Shuffle the data
+n = length(ys)
+indices = randperm(n)
+
+# Define the split ratio
+split_ratio = 0.8
+split_index = Int(floor(split_ratio * n))
+
+# Split the data into training and test sets
+train_indices = indices[1:split_index]
+test_indices = indices[split_index+1:end]
+
+xs_train = xs[train_indices]
+xs_test = xs[test_indices]
+ys_train = ys[train_indices]
+ys_test = ys[test_indices]
+# bring into tabular format
+X_train = hcat(xs_train...) 
+X_test = hcat(xs_test...) 
+
+data = zip(xs_train,ys_train)

Model

For the classification task we build a neural network with weight decay composed of a single hidden layer.

n_hidden = 10
+D = size(X_train,1)
 nn = Chain(
     Dense(D, n_hidden, σ),
     Dense(n_hidden, 1)
@@ -33,11 +55,12 @@
 la_untuned = deepcopy(la)   # saving for plotting
 optimize_prior!(la; verbose=true, n_steps=500)

The plot below shows the resulting posterior predictive surface for the plugin estimator (left) and the Laplace approximation (right).

# Plot the posterior distribution with a contour plot.
 zoom=0
-p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1))
-p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
-p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
+p_plugin = plot(la, X_train, ys_train; title="Plugin", link_approx=:plugin, clim=(0,1))
+p_untuned = plot(la_untuned, X_train, ys_train; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
+p_laplace = plot(la, X_train, ys_train; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
 plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))

Zooming out, we can see that the plugin estimator produces high-confidence estimates in regions with scarcely any samples. The Laplace approximation is much more conservative about these regions.

zoom=-50
-p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1))
-p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
-p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
-plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))

+p_plugin = plot(la, X_train, ys_train; title="Plugin", link_approx=:plugin, clim=(0,1))
+p_untuned = plot(la_untuned, X_train, ys_train; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
+p_laplace = plot(la, X_train, ys_train; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
+plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))

We now produce the calibration plot to assess the level of average calibration reached by the neural network.

predicted_distributions= predict(la, X_test,ret_distr=true)
+Calibration_Plot(la,ys_test,vec(predicted_distributions);n_bins = 10)

and the sharpness score

sharpness_classification(ys_test,vec(predicted_distributions))
(0.9277189055456709, 0.9196132560599691)
diff --git a/dev/tutorials/mlp_files/figure-commonmark/cell-10-output-1.svg b/dev/tutorials/mlp_files/figure-commonmark/cell-10-output-1.svg new file mode 100644 index 00000000..9ebc9597 --- /dev/null +++ b/dev/tutorials/mlp_files/figure-commonmark/cell-10-output-1.svg @@ -0,0 +1,52 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dev/tutorials/mlp_files/figure-commonmark/cell-7-output-1.svg b/dev/tutorials/mlp_files/figure-commonmark/cell-7-output-1.svg index 05967f3c..a5b54a86 100644 --- a/dev/tutorials/mlp_files/figure-commonmark/cell-7-output-1.svg +++ b/dev/tutorials/mlp_files/figure-commonmark/cell-7-output-1.svg @@ -1,912 +1,1284 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dev/tutorials/mlp_files/figure-commonmark/cell-8-output-1.svg b/dev/tutorials/mlp_files/figure-commonmark/cell-8-output-1.svg index f91bb77b..5bfb3c48 100644 --- a/dev/tutorials/mlp_files/figure-commonmark/cell-8-output-1.svg +++ b/dev/tutorials/mlp_files/figure-commonmark/cell-8-output-1.svg @@ -1,1032 +1,1398 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
[SVG plot markup diff omitted: regenerated figure, no recoverable text]
diff --git a/dev/tutorials/mlp_files/figure-commonmark/cell-9-output-1.svg b/dev/tutorials/mlp_files/figure-commonmark/cell-9-output-1.svg
new file mode 100644
index 00000000..1cf4d90e
[SVG plot markup omitted: new figure added, no recoverable text]
diff --git a/dev/tutorials/multi.qmd b/dev/tutorials/multi.qmd
index d0159fb4..6195a0ae 100644
--- a/dev/tutorials/multi.qmd
+++ b/dev/tutorials/multi.qmd
@@ -12,18 +12,51 @@ theme(:lime)
 ```{julia}
 using LaplaceRedux.Data
-x, y = Data.toy_data_multi()
+seed = 1234
+x, y = Data.toy_data_multi(seed=seed)
 X = hcat(x...)
-y_train = Flux.onehotbatch(y, unique(y))
-y_train = Flux.unstack(y_train',1)
+y_onehot = Flux.onehotbatch(y, unique(y))
+y_onehot = Flux.unstack(y_onehot',1)
 ```
+
+Split the data into training and test sets:
+
+```{julia}
+# Shuffle the data
+Random.seed!(seed)
+n = length(y)
+indices = randperm(n)
+
+# Define the split ratio
+split_ratio = 0.8
+split_index = Int(floor(split_ratio * n))
+
+# Split the data into training and test sets
+train_indices = indices[1:split_index]
+test_indices = indices[split_index+1:end]
+
+x_train = x[train_indices]
+x_test = x[test_indices]
+y_onehot_train = y_onehot[train_indices,:]
+y_onehot_test = y_onehot[test_indices,:]
+
+y_train = vec(y[train_indices,:])
+y_test = vec(y[test_indices,:])
+# bring into tabular format
+X_train = hcat(x_train...)
+X_test = hcat(x_test...)
+
+data = zip(x_train,y_onehot_train)
+#data = zip(x,y_onehot)
+```
+
 ## MLP
 
 We set up a model
 
 ```{julia}
-data = zip(x,y_train)
 n_hidden = 3
 D = size(X,1)
 out_dim = length(unique(y))
@@ -67,26 +100,67 @@ fit!(la, data)
 optimize_prior!(la; verbose=true, n_steps=100)
 ```
+
+with either the probit approximation:
+
 ```{julia}
 #| output: true
 _labels = sort(unique(y))
 plt_list = []
 for target in _labels
-    plt = plot(la, X, y; target=target, clim=(0,1))
+    plt = plot(la, X_test, y_test; target=target, clim=(0,1))
     push!(plt_list, plt)
 end
 plot(plt_list...)
 ```
+
+or the plugin approximation:
+
 ```{julia}
 #| output: true
 _labels = sort(unique(y))
 plt_list = []
 for target in _labels
-    plt = plot(la, X, y; target=target, clim=(0,1), link_approx=:plugin)
+    plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)
     push!(plt_list, plt)
 end
 plot(plt_list...)
-```
\ No newline at end of file
+```
+
+## Calibration Plots
+
+In the case of multiclass classification tasks, we cannot produce calibration plots directly, since they are only defined for binary classification. However, we can use them to check the calibration of the predictions for one class against all the others. To do so, we first have to collect the predicted categorical distributions:
+
+```{julia}
+#| output: true
+predicted_distributions= predict(la, X_test,ret_distr=true)
+```
+
+Then we transform the categorical distributions into Bernoulli distributions by keeping only the probability of the class of interest, for example the third one:
+
+```{julia}
+#| output: true
+using Distributions
+bernoulli_distributions = [Bernoulli(p.p[3]) for p in vec(predicted_distributions)]
+```
+
+Now we can use `Calibration_Plot` to assess how well calibrated the neural network is:
+
+```{julia}
+#| output: true
+plt = Calibration_Plot(la,hcat(y_onehot_test...)[3,:],bernoulli_distributions;n_bins = 10);
+```
+
+The plot peaks around 0.7.
+
+A possible reason is that class 3 is relatively easy for the model to distinguish from the other classes, although the model remains a bit underconfident in its predictions.
+Another reason for the peak may be the lack of cases where the predicted probability is lower (e.g., around 0.5), which could indicate that the network has not encountered ambiguous or difficult-to-classify examples for this class. Once again, this might be because class 3 has distinct features that the model can easily learn, leading to fewer uncertain predictions, or it may be a consequence of the limited dataset.
+
+We can measure how sharp the neural network is by computing the sharpness score:
+
+```{julia}
+#| output: true
+sharpness_classification(hcat(y_onehot_test...)[3,:],vec(bernoulli_distributions))
+```
+
+The neural network seems to be able to correctly classify the majority of samples not belonging to class 3 with relatively high confidence, but it remains more uncertain when it encounters examples belonging to class 3.
\ No newline at end of file
diff --git a/dev/tutorials/multi/index.html b/dev/tutorials/multi/index.html
index 148d9081..bdcc93b5 100644
--- a/dev/tutorials/multi/index.html
+++ b/dev/tutorials/multi/index.html
@@ -1,13 +1,38 @@
-MLP Multi-Label Classifier · LaplaceRedux.jl

Multi-class problem

Libraries

using Pkg; Pkg.activate("docs")
+MLP Multi-Label Classifier · LaplaceRedux.jl

Multi-class problem

Libraries

using Pkg; Pkg.activate("docs")
 # Import libraries
 using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux
 theme(:lime)

Data

using LaplaceRedux.Data
-x, y = Data.toy_data_multi()
+seed = 1234
+x, y = Data.toy_data_multi(seed=seed)
 X = hcat(x...)
-y_train = Flux.onehotbatch(y, unique(y))
-y_train = Flux.unstack(y_train',1)

MLP

We set up a model

data = zip(x,y_train)
-n_hidden = 3
+y_onehot = Flux.onehotbatch(y, unique(y))
+y_onehot = Flux.unstack(y_onehot',1)

Split the data into training and test sets:

# Shuffle the data
+Random.seed!(seed)
+n = length(y)
+indices = randperm(n)
+
+# Define the split ratio
+split_ratio = 0.8
+split_index = Int(floor(split_ratio * n))
+
+# Split the data into training and test sets
+train_indices = indices[1:split_index]
+test_indices = indices[split_index+1:end]
+
+x_train = x[train_indices]
+x_test = x[test_indices]
+y_onehot_train = y_onehot[train_indices,:]
+y_onehot_test = y_onehot[test_indices,:]
+
+y_train = vec(y[train_indices,:])
+y_test = vec(y[test_indices,:])
+# bring into tabular format
+X_train = hcat(x_train...) 
+X_test = hcat(x_test...) 
+
+data = zip(x_train,y_onehot_train)
+#data = zip(x,y_onehot)

MLP

We set up a model

n_hidden = 3
 D = size(X,1)
 out_dim = length(unique(y))
 nn = Chain(
@@ -33,16 +58,38 @@
     end
 end

Laplace Approximation

The Laplace approximation can be implemented as follows:

la = Laplace(nn; likelihood=:classification)
 fit!(la, data)
-optimize_prior!(la; verbose=true, n_steps=100)
_labels = sort(unique(y))
+optimize_prior!(la; verbose=true, n_steps=100)

with either the probit approximation:

_labels = sort(unique(y))
 plt_list = []
 for target in _labels
-    plt = plot(la, X, y; target=target, clim=(0,1))
+    plt = plot(la, X_test, y_test; target=target, clim=(0,1))
     push!(plt_list, plt)
 end
-plot(plt_list...)

_labels = sort(unique(y))
+plot(plt_list...)

or the plugin approximation:

_labels = sort(unique(y))
 plt_list = []
 for target in _labels
-    plt = plot(la, X, y; target=target, clim=(0,1), link_approx=:plugin)
+    plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)
     push!(plt_list, plt)
 end
-plot(plt_list...)

+plot(plt_list...)
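
The two link approximations differ only in how the Gaussian posterior over the network output is pushed through the link function. The snippet below is a minimal sketch of that idea for the binary case, with made-up values for the latent mean and variance; it is an illustration, not the code path used internally by LaplaceRedux.

```julia
# Sketch of the two link approximations in the binary case (illustrative only;
# μ and v are hypothetical values for the latent mean and variance of f(x)).
logistic(z) = 1 / (1 + exp(-z))

μ, v = 1.2, 0.8

p_plugin = logistic(μ)            # plugin: use the MAP output, ignore the variance
κ = 1 / sqrt(1 + π * v / 8)       # probit (MacKay) scaling factor
p_probit = logistic(κ * μ)        # probit: pulled towards 0.5 as the variance grows
```

Because the probit approximation takes the posterior variance into account, its probabilities are pulled towards 0.5 in regions of high uncertainty, which is why the two sets of plots above typically differ most far away from the training data.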

Calibration Plots

In the case of multiclass classification tasks, we cannot produce calibration plots directly, since they are only defined for binary classification. However, we can use them to check the calibration of the predictions for one class against all the others. To do so, we first have to collect the predicted categorical distributions:

predicted_distributions= predict(la, X_test,ret_distr=true)
1×20 Matrix{Distributions.Categorical{Float64, Vector{Float64}}}:
+ Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569184, 0.196066, 0.0296796, 0.717336])  …  Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569634, 0.195727, 0.0296449, 0.717665])

Then we transform the categorical distributions into Bernoulli distributions by keeping only the probability of the class of interest, for example the third one:

using Distributions
+bernoulli_distributions = [Bernoulli(p.p[3]) for p in vec(predicted_distributions)]
20-element Vector{Bernoulli{Float64}}:
+ Bernoulli{Float64}(p=0.029679590887034743)
+ Bernoulli{Float64}(p=0.6682373773598078)
+ Bernoulli{Float64}(p=0.20912995228011141)
+ Bernoulli{Float64}(p=0.20913322913224044)
+ Bernoulli{Float64}(p=0.02971989045895732)
+ Bernoulli{Float64}(p=0.668431087463204)
+ Bernoulli{Float64}(p=0.03311710703617972)
+ Bernoulli{Float64}(p=0.20912981531862682)
+ Bernoulli{Float64}(p=0.11273726979027407)
+ Bernoulli{Float64}(p=0.2490744632745955)
+ Bernoulli{Float64}(p=0.029886357844211404)
+ Bernoulli{Float64}(p=0.02965323602487074)
+ Bernoulli{Float64}(p=0.1126799374664026)
+ Bernoulli{Float64}(p=0.11278538625980777)
+ Bernoulli{Float64}(p=0.6683139127616431)
+ Bernoulli{Float64}(p=0.029644435143197145)
+ Bernoulli{Float64}(p=0.11324691083703237)
+ Bernoulli{Float64}(p=0.6681422555922787)
+ Bernoulli{Float64}(p=0.668424345470233)
+ Bernoulli{Float64}(p=0.029644891255330787)

Now we can use Calibration_Plot to assess how well calibrated the neural network is:

plt = Calibration_Plot(la,hcat(y_onehot_test...)[3,:],bernoulli_distributions;n_bins = 10);
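
For reference, the reliability curve that Calibration_Plot visualises can be approximated by hand: bin the predicted probabilities and compare, within each bin, the average prediction with the observed frequency of class 3. The sketch below is only an illustration of that idea (it reuses bernoulli_distributions and y_onehot_test from above); it is not the package's implementation.

```julia
# Hand-rolled reliability check (illustrative sketch, not the package code).
using Statistics: mean

probs  = [d.p for d in bernoulli_distributions]   # predicted P(class 3)
labels = hcat(y_onehot_test...)[3, :]             # 1 if the true class is 3
edges  = range(0, 1; length=11)                   # 10 equally spaced bins

for i in 1:10
    # last bin excludes p == 1 exactly, which is fine for a sketch
    in_bin = findall(p -> edges[i] <= p < edges[i+1], probs)
    isempty(in_bin) && continue
    println("bin $i: mean predicted = $(round(mean(probs[in_bin]); digits=2)), ",
            "observed frequency = $(round(mean(labels[in_bin]); digits=2))")
end
```

A perfectly calibrated model would have the two numbers agree in every bin.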

The plot peaks around 0.7.

A possible reason is that class 3 is relatively easy for the model to distinguish from the other classes, although the model remains a bit underconfident in its predictions. Another reason for the peak may be the lack of cases where the predicted probability is lower (e.g., around 0.5), which could indicate that the network has not encountered ambiguous or difficult-to-classify examples for this class. Once again, this might be because class 3 has distinct features that the model can easily learn, leading to fewer uncertain predictions, or it may be a consequence of the limited dataset.

We can measure how sharp the neural network is by computing the sharpness score:

sharpness_classification(hcat(y_onehot_test...)[3,:],vec(bernoulli_distributions))


The neural network seems to be able to correctly classify the majority of samples not belonging to class 3 with relatively high confidence, but it remains more uncertain when it encounters examples belonging to class 3.
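
For intuition, sharpness summarises how concentrated the predicted Bernoulli distributions are, separately for the two groups of test points. A comparable quantity can be computed directly as below; this is a conceptual sketch and not necessarily the exact formula implemented by sharpness_classification.

```julia
# Conceptual sketch: average confidence assigned to the correct side, split by
# whether the true class is 3 or not (not necessarily the exact definition
# used by sharpness_classification).
using Statistics: mean

labels = hcat(y_onehot_test...)[3, :]
probs  = [d.p for d in bernoulli_distributions]

mean_conf_class3 = mean(probs[labels .== 1])        # samples that are class 3
mean_conf_rest   = mean(1 .- probs[labels .== 0])   # samples that are not class 3
```

A high value for the second quantity and a lower one for the first matches the behaviour described above.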

diff --git a/dev/tutorials/multi_files/figure-commonmark/cell-12-output-1.svg b/dev/tutorials/multi_files/figure-commonmark/cell-12-output-1.svg
new file mode 100644
index 00000000..462326cf
[SVG plot markup omitted: new figure added, no recoverable text]
diff --git a/dev/tutorials/multi_files/figure-commonmark/cell-8-output-1.svg b/dev/tutorials/multi_files/figure-commonmark/cell-8-output-1.svg
index 96d98bbc..5c2be9e8 100644
[SVG plot markup diff omitted: regenerated figure, no recoverable text]
diff --git a/dev/tutorials/multi_files/figure-commonmark/cell-9-output-1.svg b/dev/tutorials/multi_files/figure-commonmark/cell-9-output-1.svg
new file mode 100644
index 00000000..cab18e17
[SVG plot markup omitted: new figure added, no recoverable text]
diff --git a/dev/tutorials/prior/index.html b/dev/tutorials/prior/index.html
index 69238828..3d5abc28 100644
--- a/dev/tutorials/prior/index.html
+++ b/dev/tutorials/prior/index.html
@@ -1,5 +1,5 @@
-A note on the prior ... · LaplaceRedux.jl

Libraries

using Pkg; Pkg.activate("docs")
+A note on the prior ... · LaplaceRedux.jl

Libraries

using Pkg; Pkg.activate("docs")
 # Import libraries
 using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra
In Progress

    This documentation is still incomplete.

A quick note on the prior

General Effect

High prior precision $\rightarrow$ the predictive uncertainty reduces to the observation noise. Low prior precision $\rightarrow$ high posterior (and hence predictive) uncertainty.
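
To make this concrete (a standard property of the Laplace approximation, stated here only for intuition): the posterior precision is the Hessian of the loss plus the prior precision, so for a linearised network with Jacobian $J_\star$ at a test point the predictive variance is roughly

$$\operatorname{Var}[y_\star] \;\approx\; J_\star \big(\nabla^2 \mathcal{L} + P_0 I\big)^{-1} J_\star^\top \;+\; \sigma^2 .$$

As $P_0 \rightarrow \infty$ the first (posterior) term vanishes and only the observation noise $\sigma^2$ remains; as $P_0 \rightarrow 0$ the posterior term, and with it the total predictive uncertainty, grows.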

using LaplaceRedux.Data
 n = 150       # number of observations
@@ -122,4 +122,4 @@
     plts = vcat(plts..., plt)
     nns = vcat(nns..., nn)
 end
-plot(plts..., layout=(1,3), size=(1200,300))

+plot(plts..., layout=(1,3), size=(1200,300))

diff --git a/dev/tutorials/prior_files/figure-commonmark/cell-4-output-1.svg b/dev/tutorials/prior_files/figure-commonmark/cell-4-output-1.svg
index 6c64e376..857bb7b0 100644
[SVG plot markup diff omitted: regenerated figure, no recoverable text]
diff --git a/dev/tutorials/prior_files/figure-commonmark/cell-5-output-1.svg b/dev/tutorials/prior_files/figure-commonmark/cell-5-output-1.svg
index 13f81d2c..290ae1b1 100644
[SVG plot markup diff omitted: regenerated figure, no recoverable text]
diff --git a/dev/tutorials/prior_files/figure-commonmark/cell-6-output-1.svg b/dev/tutorials/prior_files/figure-commonmark/cell-6-output-1.svg
index 035a6e00..afa9d368 100644
[SVG plot markup diff omitted: regenerated figure, no recoverable text]
diff --git a/dev/tutorials/regression.qmd b/dev/tutorials/regression.qmd
index f87ebfe2..6cce7c06 100644
--- a/dev/tutorials/regression.qmd
+++ b/dev/tutorials/regression.qmd
@@ -12,25 +12,54 @@ using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux
 theme(:wong)
 ```
+
 ## Data
 
 We first generate some synthetic data:
 
 ```{julia}
 using LaplaceRedux.Data
-n = 300       # number of observations
+n = 3000      # number of observations
 σtrue = 0.30  # true observational noise
-x, y = Data.toy_data_regression(n;noise=σtrue)
+x, y = Data.toy_data_regression(n;noise=σtrue,seed=1234)
 xs = [[x] for x in x]
 X = permutedims(x)
 ```
+
+and split them into a training set and a test set:
+
+```{julia}
+# Shuffle the data
+Random.seed!(1234)  # Set a seed for reproducibility
+shuffle_indices = shuffle(1:n)
+
+# Define split ratios
+train_ratio = 0.8
+test_ratio = 0.2
+
+# Calculate split indices
+train_end = Int(floor(train_ratio * n))
+
+# Split the data
+train_indices = shuffle_indices[1:train_end]
+test_indices = shuffle_indices[train_end+1:end]
+
+# Create the splits
+x_train, y_train = x[train_indices], y[train_indices]
+x_test, y_test = x[test_indices], y[test_indices]
+
+# Optional: Convert to desired format
+xs_train = [[x] for x in x_train]
+xs_test = [[x] for x in x_test]
+X_train = permutedims(x_train)
+X_test = permutedims(x_test)
+```
+
 ## MLP
 
 We set up a model and loss with weight regularization:
 
 ```{julia}
-data = zip(xs,y)
+train_data = zip(xs_train,y_train)
 n_hidden = 50
 D = size(X,1)
 nn = Chain(
@@ -46,11 +75,11 @@ We train the model:
 using Flux.Optimise: update!, Adam
 opt = Adam(1e-3)
 epochs = 1000
-avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
+avg_loss(train_data) = mean(map(d -> loss(d[1],d[2]), train_data))
 show_every = epochs/10
 
 for epoch = 1:epochs
-  for d in data
+  for d in train_data
     gs = gradient(Flux.params(nn)) do
       l = loss(d...)
     end
@@ -58,7 +87,7 @@ for epoch = 1:epochs
   end
   if epoch % show_every == 0
     println("Epoch " * string(epoch))
-    @show avg_loss(data)
+    @show avg_loss(train_data)
   end
 end
 ```
@@ -72,8 +101,8 @@ Laplace approximation can be implemented as follows:
 
 subset_w = :all
 la = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)
-fit!(la, data)
-plot(la, X, y; zoom=-5, size=(400,400))
+fit!(la, train_data)
+plot(la, X_train, y_train; zoom=-5, size=(400,400))
 ```
 
 Next we optimize the prior precision $P_0$ and observational noise $\sigma$ using Empirical Bayes:
 
@@ -82,6 +111,30 @@ Next we optimize the prior precision $P_0$ and observational noise $\sigma$
 #| output: true
 optimize_prior!(la; verbose=true)
-plot(la, X, y; zoom=-5, size=(400,400))
+plot(la, X_train, y_train; zoom=-5, size=(400,400))
 ```
 
+## Calibration Plot
+Once the prior precision has been optimized, we can evaluate the quality of the predictive distribution on a test dataset (y_test, X_test) using a calibration plot.
+
+First, we apply the trained network to the test dataset (y_test, X_test) and collect the neural network's predicted distributions:
+```{julia}
+#| output: true
+predicted_distributions= predict(la, X_test,ret_distr=true)
+```
+
+Then we can draw the calibration plot of our model:
+
+```{julia}
+#| output: true
+Calibration_Plot(la,y_test,vec(predicted_distributions);n_bins = 20)
+```
+
+and compute the sharpness of the predictive distribution:
+
+```{julia}
+#| output: true
+sharpness_regression(vec(predicted_distributions))
+```
+
diff --git a/dev/tutorials/regression/index.html b/dev/tutorials/regression/index.html
index f17433c3..f1f3fe09 100644
--- a/dev/tutorials/regression/index.html
+++ b/dev/tutorials/regression/index.html
@@ -1,13 +1,36 @@
-MLP Regression · LaplaceRedux.jl

Libraries

Import the libraries required to run this example

using Pkg; Pkg.activate("docs")
+MLP Regression · LaplaceRedux.jl

Libraries

Import the libraries required to run this example

using Pkg; Pkg.activate("docs")
 # Import libraries
 using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux
 theme(:wong)

Data

We first generate some synthetic data:

using LaplaceRedux.Data
-n = 300       # number of observations
+n = 3000       # number of observations
 σtrue = 0.30  # true observational noise
-x, y = Data.toy_data_regression(n;noise=σtrue)
+x, y = Data.toy_data_regression(n;noise=σtrue,seed=1234)
 xs = [[x] for x in x]
-X = permutedims(x)

MLP

We set up a model and loss with weight regularization:

data = zip(xs,y)
+X = permutedims(x)

and split them into a training set and a test set:

# Shuffle the data
+Random.seed!(1234)  # Set a seed for reproducibility
+shuffle_indices = shuffle(1:n)
+
+# Define split ratios
+train_ratio = 0.8
+test_ratio = 0.2
+
+# Calculate split indices
+train_end = Int(floor(train_ratio * n))
+
+# Split the data
+train_indices = shuffle_indices[1:train_end]
+test_indices = shuffle_indices[train_end+1:end]
+
+# Create the splits
+x_train, y_train = x[train_indices], y[train_indices]
+x_test, y_test = x[test_indices], y[test_indices]
+
+# Optional: Convert to desired format
+xs_train = [[x] for x in x_train]
+xs_test = [[x] for x in x_test]
+X_train = permutedims(x_train)
+X_test = permutedims(x_test)

MLP

We set up a model and loss with weight regularization:

train_data = zip(xs_train,y_train)
 n_hidden = 50
 D = size(X,1)
 nn = Chain(
@@ -17,11 +40,11 @@
 loss(x, y) = Flux.Losses.mse(nn(x), y)

We train the model:

using Flux.Optimise: update!, Adam
 opt = Adam(1e-3)
 epochs = 1000
-avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
+avg_loss(train_data) = mean(map(d -> loss(d[1],d[2]), train_data))
 show_every = epochs/10
 
 for epoch = 1:epochs
-  for d in data
+  for d in train_data
     gs = gradient(Flux.params(nn)) do
       l = loss(d...)
     end
@@ -29,53 +52,82 @@
   end
   if epoch % show_every == 0
     println("Epoch " * string(epoch))
-    @show avg_loss(data)
+    @show avg_loss(train_data)
   end
 end

Laplace Approximation

Laplace approximation can be implemented as follows:

subset_w = :all
 la = Laplace(nn; likelihood=:regression, subset_of_weights=subset_w)
-fit!(la, data)
-plot(la, X, y; zoom=-5, size=(400,400))

Next we optimize the prior precision $P_0$ and observational noise $\sigma$ using Empirical Bayes:

optimize_prior!(la; verbose=true)
-plot(la, X, y; zoom=-5, size=(400,400))
loss(exp.(logP₀), exp.(logσ)) = 104.78561546028183
-Log likelihood: -70.48742092717352
-Log det ratio: 41.1390695290454
-Scatter: 27.45731953717124
-loss(exp.(logP₀), exp.(logσ)) = 104.9736282327825
-Log likelihood: -74.85481357633174
-Log det ratio: 46.59827618892447
-Scatter: 13.639353123977058
-loss(exp.(logP₀), exp.(logσ)) = 84.38222356291794
-Log likelihood: -54.86985627702764
-Log det ratio: 49.92347667032635
-Scatter: 9.101257901454279
+fit!(la, train_data)
+plot(la, X_train, y_train; zoom=-5, size=(400,400))

Next we optimize the prior precision $P_0$ and observational noise $\sigma$ using Empirical Bayes:

optimize_prior!(la; verbose=true)
+plot(la, X_train, y_train; zoom=-5, size=(400,400))
loss(exp.(logP₀), exp.(logσ)) = 668.3714946472106
+Log likelihood: -618.5175117610522
+Log det ratio: 68.76532606873238
+Scatter: 30.942639703584522
+loss(exp.(logP₀), exp.(logσ)) = 719.2536119935747
+Log likelihood: -673.0996963447847
+Log det ratio: 76.53255037599948
+Scatter: 15.775280921580569
+loss(exp.(logP₀), exp.(logσ)) = 574.605864472924
+Log likelihood: -528.694286608232
+
+
+Log det ratio: 80.73114330857285
+Scatter: 11.092012420811196
+loss(exp.(logP₀), exp.(logσ)) = 568.4433850825203
+Log likelihood: -522.4407550111031
+Log det ratio: 82.10089958560243
+Scatter: 9.90436055723207
+
 
-loss(exp.(logP₀), exp.(logσ)) = 84.53493863039972
-Log likelihood: -55.013137224636
-Log det ratio: 51.43622180356522
-Scatter: 7.607381007962245
-loss(exp.(logP₀), exp.(logσ)) = 83.95921598606084
-Log likelihood: -54.41492266831395
-Log det ratio: 51.794520967146354
-Scatter: 7.294065668347427
-loss(exp.(logP₀), exp.(logσ)) = 83.03505059021086
-Log likelihood: -53.50540374805591
-Log det ratio: 51.574749787874794
-Scatter: 7.484543896435117
+loss(exp.(logP₀), exp.(logσ)) = 566.9485255672008
+Log likelihood: -520.9682443835385
+Log det ratio: 81.84516297272847
+Scatter: 10.11539939459612
+loss(exp.(logP₀), exp.(logσ)) = 559.9852101992792
+Log likelihood: -514.0625630685765
+Log det ratio: 80.97813304453496
+Scatter: 10.867161216870441
 
-loss(exp.(logP₀), exp.(logσ)) = 82.97840036025443
-Log likelihood: -53.468475394115416
-Log det ratio: 51.17273666609066
-Scatter: 7.847113266187348
-loss(exp.(logP₀), exp.(logσ)) = 82.98550025321256
-Log likelihood: -53.48508828283467
-Log det ratio: 50.81442045868749
-Scatter: 8.186403482068298
-loss(exp.(logP₀), exp.(logσ)) = 82.9584040552644
-Log likelihood: -53.45989630330948
-Log det ratio: 50.59063282947659
-Scatter: 8.406382674433235
+loss(exp.(logP₀), exp.(logσ)) = 559.1404593114019
+Log likelihood: -513.2449017869876
+Log det ratio: 80.16026747795866
+Scatter: 11.630847570869795
+loss(exp.(logP₀), exp.(logσ)) = 559.3201392562346
+Log likelihood: -513.4273312363501
+Log det ratio: 79.68892769076004
+Scatter: 12.096688349008877
 
 
-loss(exp.(logP₀), exp.(logσ)) = 82.94465052328141
-Log likelihood: -53.44600301956443
-Log det ratio: 50.500079294094405
-Scatter: 8.497215713339543

+loss(exp.(logP₀), exp.(logσ)) = 559.2111983983311
+Log likelihood: -513.3174948065804
+Log det ratio: 79.56631681347287
+Scatter: 12.2210903700287
+loss(exp.(logP₀), exp.(logσ)) = 559.1107459310829
+Log likelihood: -513.2176579845662
+Log det ratio: 79.63946732368183
+Scatter: 12.146708569351494
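
The logged quantities fit together as the negative log marginal likelihood that Empirical Bayes minimises; the printed numbers are consistent with the decomposition

$$\text{loss} \;=\; -\log\text{likelihood} \;+\; \tfrac{1}{2}\big(\text{log det ratio} + \text{scatter}\big),$$

e.g. for the very first logged step, $618.52 + \tfrac{1}{2}(68.77 + 30.94) \approx 668.37$. The log-determinant term penalises an overly flexible posterior while the scatter term penalises weights far from the prior mean, and Empirical Bayes trades both off against the data fit when tuning $P_0$ and $\sigma$.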

Calibration Plot

Once the prior precision has been optimized, we can evaluate the quality of the predictive distribution on a test dataset (y_test, X_test) using a calibration plot.

First, we apply the trained network to the test dataset (y_test, X_test) and collect the neural network's predicted distributions:

predicted_distributions= predict(la, X_test,ret_distr=true)
600×1 Matrix{Distributions.Normal{Float64}}:
+ Distributions.Normal{Float64}(μ=-0.1137533187866211, σ=0.07161056521032018)
+ Distributions.Normal{Float64}(μ=0.7063850164413452, σ=0.050697938829269665)
+ Distributions.Normal{Float64}(μ=-0.2211049497127533, σ=0.06876939416479119)
+ Distributions.Normal{Float64}(μ=0.720299243927002, σ=0.08665125572287981)
+ Distributions.Normal{Float64}(μ=-0.8338974714279175, σ=0.06464012115237727)
+ Distributions.Normal{Float64}(μ=0.9910320043563843, σ=0.07452060172164382)
+ Distributions.Normal{Float64}(μ=0.1507074236869812, σ=0.07316299850461126)
+ Distributions.Normal{Float64}(μ=0.20875799655914307, σ=0.05507748397231652)
+ Distributions.Normal{Float64}(μ=0.973572850227356, σ=0.07899004963915071)
+ Distributions.Normal{Float64}(μ=0.9497100114822388, σ=0.07750126389821968)
+ Distributions.Normal{Float64}(μ=0.22462180256843567, σ=0.07103664786246695)
+ Distributions.Normal{Float64}(μ=-0.7654240131378174, σ=0.05501397704409917)
+ Distributions.Normal{Float64}(μ=1.0029183626174927, σ=0.07619466916431794)
+ ⋮
+ Distributions.Normal{Float64}(μ=0.7475956678390503, σ=0.049875919157527815)
+ Distributions.Normal{Float64}(μ=0.019430622458457947, σ=0.07445076746045155)
+ Distributions.Normal{Float64}(μ=-0.9451781511306763, σ=0.05929712369810892)
+ Distributions.Normal{Float64}(μ=-0.9813591241836548, σ=0.05844012710417755)
+ Distributions.Normal{Float64}(μ=-0.6470385789871216, σ=0.055754609087554294)
+ Distributions.Normal{Float64}(μ=-0.34288135170936584, σ=0.05533523375842789)
+ Distributions.Normal{Float64}(μ=0.9912381172180176, σ=0.07872473667398772)
+ Distributions.Normal{Float64}(μ=-0.824547290802002, σ=0.05499258101374759)
+ Distributions.Normal{Float64}(μ=-0.3306621015071869, σ=0.06745251908756716)
+ Distributions.Normal{Float64}(μ=0.3742436170578003, σ=0.10588913330223387)
+ Distributions.Normal{Float64}(μ=0.0875578224658966, σ=0.07436153828228255)
+ Distributions.Normal{Float64}(μ=-0.34871187806129456, σ=0.06742745343084512)

Then we can draw the calibration plot of our model:

Calibration_Plot(la,y_test,vec(predicted_distributions);n_bins = 20)

and compute the sharpness of the predictive distribution:

sharpness_regression(vec(predicted_distributions))
0.005058067743863281
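
For intuition, the sharpness of a regression model is commonly summarised by the average variance of its predictive distributions (smaller is sharper), and calibration by the empirical coverage of central prediction intervals; the value above is consistent with the mean of the squared σ's printed earlier. The sketch below computes both quantities directly from the predicted Normal distributions as an illustration; it is not necessarily identical to the package's implementation.

```julia
# Concept sketch (illustrative, not necessarily the package's exact formulas).
using Statistics: mean, var, quantile

dists = vec(predicted_distributions)

# sharpness ≈ average predictive variance (smaller = sharper)
mean(var.(dists))

# calibration ≈ empirical coverage of central prediction intervals
for level in (0.5, 0.8, 0.95)
    lo = quantile.(dists, (1 - level) / 2)
    hi = quantile.(dists, 1 - (1 - level) / 2)
    println("nominal $(level) → empirical $(round(mean(lo .<= y_test .<= hi); digits=2))")
end
```

A well-calibrated model has empirical coverage close to the nominal level at every interval width.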
diff --git a/dev/tutorials/regression_files/figure-commonmark/cell-10-output-1.svg b/dev/tutorials/regression_files/figure-commonmark/cell-10-output-1.svg
new file mode 100644
index 00000000..f854d3d1
[SVG plot markup omitted: new figure added, no recoverable text]
diff --git a/dev/tutorials/regression_files/figure-commonmark/cell-7-output-1.svg b/dev/tutorials/regression_files/figure-commonmark/cell-7-output-1.svg
new file mode 100644
index 00000000..f9858849
[SVG plot markup omitted: new figure added, no recoverable text]
diff --git a/dev/tutorials/regression_files/figure-commonmark/cell-8-output-5.svg b/dev/tutorials/regression_files/figure-commonmark/cell-8-output-5.svg
new file mode 100644
index 00000000..f250d2fb
[SVG plot markup omitted: new figure added, no recoverable text]
diff --git a/dev/tutorials/regression_files/figure-commonmark/cell-8-output-6.svg b/dev/tutorials/regression_files/figure-commonmark/cell-8-output-6.svg
new file mode 100644
index 00000000..e8952350
[SVG plot markup omitted: new figure added, no recoverable text]
diff --git a/dev/tutorials/regression_files/figure-commonmark/miscalibration.svg b/dev/tutorials/regression_files/figure-commonmark/miscalibration.svg
new file mode 100644
index 00000000..9ca6b50d
[SVG plot markup omitted: new figure added, no recoverable text]