`posteriors.ekf.dense_fisher.build(log_likelihood, lr, transition_cov=0.0, per_sample=False, init_cov=1.0)`

Builds a transform to implement an extended Kalman Filter update.

EKF applies an online update to a Gaussian posterior over the parameters.

The approximate Bayesian update is based on the linearization
$$
\log p(θ \mid y) ≈ \log p(θ) + ε g(μ)^T (θ - μ) + \frac{1}{2} ε (θ - μ)^T F(μ) (θ - μ)
$$
where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log-likelihood at \(μ\) and \(F(μ)\) is the empirical Fisher information matrix at \(μ\) for data \(y\).

For more information on extended Kalman filtering as well as an equivalence to (online) natural gradient descent see Ollivier, 2019.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_likelihood | LogProbFn | Function that takes parameters and input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Inverse temperature of the update, which behaves like a learning rate. | required |
| transition_cov | Tensor \| float | Covariance of the transition noise, used to additively inflate the covariance before the update. | 0.0 |
| per_sample | bool | If True, log_likelihood is assumed to return a vector of log-likelihoods for each sample in the batch. If False, it is assumed to return a scalar log-likelihood for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_likelihood to be per-sample directly. | False |
| init_cov | Tensor \| float | Initial covariance of the Normal distribution. Can be a torch.Tensor or scalar. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | EKF transform instance. |

Source code in posteriors/ekf/dense_fisher.py (lines 18–66)
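As a usage illustration (not part of the docstring), here is a minimal sketch of the online EKF loop; the model, dataloader and observation noise are hypothetical placeholders, and the log-likelihood is written per-sample so that `per_sample=True` avoids the internal `torch.func.vmap` call:

```python
import torch
import posteriors

model = torch.nn.Linear(10, 1)           # hypothetical model
params = dict(model.named_parameters())  # parameters as a TensorTree

def log_likelihood(params, batch):
    # Per-sample Gaussian log-likelihood; the predictions are returned as aux.
    x, y = batch
    preds = torch.func.functional_call(model, params, x)
    log_liks = torch.distributions.Normal(preds, 1.0).log_prob(y).squeeze(-1)
    return log_liks, preds

transform = posteriors.ekf.dense_fisher.build(log_likelihood, lr=1e-2, per_sample=True)
state = transform.init(params)

for batch in dataloader:  # hypothetical iterable of (x, y) batches
    state = transform.update(state, batch)

# state.params holds the posterior mean and state.cov the dense covariance.
param_samples = posteriors.ekf.dense_fisher.sample(state, torch.Size([100]))
```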
`posteriors.ekf.dense_fisher.EKFDenseState`

Bases: NamedTuple

State encoding a Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| cov | Tensor | Covariance matrix of the Normal distribution. |
| log_likelihood | Tensor | Log-likelihood of the data given the parameters. |
| aux | Any | Auxiliary information from the log_likelihood call. |

Source code in posteriors/ekf/dense_fisher.py (lines 69–83)
`posteriors.ekf.dense_fisher.init(params, init_cov=1.0)`

Initialise Multivariate Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Initial mean of the Normal distribution. | required |
| init_cov | Tensor \| float | Initial covariance matrix of the Multivariate Normal distribution. If it is a float, it is defined as an identity matrix scaled by that float. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| EKFDenseState | Initial EKFDenseState. |

Source code in posteriors/ekf/dense_fisher.py (lines 86–104)
`posteriors.ekf.dense_fisher.update(state, batch, log_likelihood, lr, transition_cov=0.0, per_sample=False, inplace=False)`

Applies an extended Kalman Filter update to the Multivariate Normal distribution.

The approximate Bayesian update is based on the linearization
$$
\log p(θ \mid y) ≈ \log p(θ) + ε g(μ)^T (θ - μ) + \frac{1}{2} ε (θ - μ)^T F(μ) (θ - μ)
$$
where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log-likelihood at \(μ\) and \(F(μ)\) is the empirical Fisher information matrix at \(μ\) for data \(y\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | EKFDenseState | Current state. | required |
| batch | Any | Input data to log_likelihood. | required |
| log_likelihood | LogProbFn | Function that takes parameters and input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Inverse temperature of the update, which behaves like a learning rate. | required |
| transition_cov | Tensor \| float | Covariance of the transition noise, used to additively inflate the covariance before the update. | 0.0 |
| per_sample | bool | If True, log_likelihood is assumed to return a vector of log-likelihoods for each sample in the batch. If False, it is assumed to return a scalar log-likelihood for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_likelihood to be per-sample directly. | False |
| inplace | bool | Whether to update the state parameters in-place. | False |

Returns:

| Type | Description |
| --- | --- |
| EKFDenseState | Updated EKFDenseState. |

Source code in posteriors/ekf/dense_fisher.py (lines 107–176)
`posteriors.ekf.dense_fisher.sample(state, sample_shape=torch.Size([]))`

Sample(s) from the Multivariate Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | EKFDenseState | State encoding mean and covariance. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Multivariate Normal distribution. |

Source code in posteriors/ekf/dense_fisher.py (lines 179–200)
`posteriors.ekf.diag_fisher.build(log_likelihood, lr, transition_sd=0.0, per_sample=False, init_sds=1.0)`

Builds a transform to implement an extended Kalman Filter update.

EKF applies an online update to a (diagonal) Gaussian posterior over the parameters.

The approximate Bayesian update is based on the linearization
$$
\log p(θ \mid y) ≈ \log p(θ) + ε g(μ)^T (θ - μ) + \frac{1}{2} ε (θ - μ)^T F_d(μ) (θ - μ)
$$
where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log-likelihood at \(μ\) and \(F_d(μ)\) is the diagonal empirical Fisher information matrix at \(μ\) for data \(y\). Completing the square regains a diagonal Normal distribution over the parameters.

For more information on extended Kalman filtering as well as an equivalence to (online) natural gradient descent see Ollivier, 2019.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_likelihood | LogProbFn | Function that takes parameters and input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Inverse temperature of the update, which behaves like a learning rate. | required |
| transition_sd | float | Standard deviation of the transition noise, used to additively inflate the diagonal covariance before the update. | 0.0 |
| per_sample | bool | If True, log_likelihood is assumed to return a vector of log-likelihoods for each sample in the batch. If False, it is assumed to return a scalar log-likelihood for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_likelihood to be per-sample directly. | False |
| init_sds | TensorTree \| float | Initial square-root diagonal of the covariance matrix of the Normal distribution. Can be a tree matching params or a scalar. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | Diagonal EKF transform instance. |

Source code in posteriors/ekf/diag_fisher.py (lines 17–67)
`posteriors.ekf.diag_fisher.EKFDiagState`

Bases: NamedTuple

State encoding a diagonal Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| sd_diag | TensorTree | Square-root diagonal of the covariance matrix of the Normal distribution. |
| log_likelihood | Tensor | Log-likelihood of the data given the parameters. |
| aux | Any | Auxiliary information from the log_likelihood call. |

Source code in posteriors/ekf/diag_fisher.py (lines 70–84)
`posteriors.ekf.diag_fisher.init(params, init_sds=1.0)`

Initialise diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Initial mean of the Normal distribution. | required |
| init_sds | TensorTree \| float | Initial square-root diagonal of the covariance matrix of the Normal distribution. Can be a tree matching params or a scalar. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| EKFDiagState | Initial EKFDiagState. |

Source code in posteriors/ekf/diag_fisher.py (lines 87–107)
`posteriors.ekf.diag_fisher.update(state, batch, log_likelihood, lr, transition_sd=0.0, per_sample=False, inplace=False)`

Applies an extended Kalman Filter update to the diagonal Normal distribution.

The approximate Bayesian update is based on the linearization
$$
\log p(θ \mid y) ≈ \log p(θ) + ε g(μ)^T (θ - μ) + \frac{1}{2} ε (θ - μ)^T F_d(μ) (θ - μ)
$$
where \(μ\) is the mean of the prior distribution, \(ε\) is the learning rate (or equivalently the likelihood inverse temperature), \(g(μ)\) is the gradient of the log-likelihood at \(μ\) and \(F_d(μ)\) is the diagonal empirical Fisher information matrix at \(μ\) for data \(y\). Completing the square regains a diagonal Normal distribution over the parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | EKFDiagState | Current state. | required |
| batch | Any | Input data to log_likelihood. | required |
| log_likelihood | LogProbFn | Function that takes parameters and input batch and returns the log-likelihood value as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Inverse temperature of the update, which behaves like a learning rate. | required |
| transition_sd | float | Standard deviation of the transition noise, used to additively inflate the diagonal covariance before the update. | 0.0 |
| per_sample | bool | If True, log_likelihood is assumed to return a vector of log-likelihoods for each sample in the batch. If False, it is assumed to return a scalar log-likelihood for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_likelihood to be per-sample directly. | False |
| inplace | bool | Whether to update the state parameters in-place. | False |

Returns:

| Type | Description |
| --- | --- |
| EKFDiagState | Updated EKFDiagState. |

Source code in posteriors/ekf/diag_fisher.py (lines 110–180)
`posteriors.ekf.diag_fisher.sample(state, sample_shape=torch.Size([]))`

Sample(s) from the diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | EKFDiagState | State encoding mean and standard deviations. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Normal distribution. |

Source code in posteriors/ekf/diag_fisher.py (lines 183–195)
- `ekf.dense_fisher` applies an online Bayesian update based on a Taylor approximation of the log-likelihood. Uses the empirical Fisher information matrix as a positive-definite alternative to the Hessian. Natural gradient descent equivalence following Ollivier, 2019.
- `ekf.diag_fisher` same as `ekf.dense_fisher` but uses the diagonal of the empirical Fisher information matrix instead.
- `laplace.dense_fisher` calculates the empirical Fisher information matrix and uses it to approximate the posterior precision, i.e. a Laplace approximation.
- `laplace.dense_ggn` calculates the Generalised Gauss-Newton matrix, which is equivalent to the non-empirical Fisher in most neural network settings - see Martens, 2020.
- `laplace.diag_fisher` same as `laplace.dense_fisher` but uses the diagonal of the empirical Fisher information matrix instead.
- `laplace.diag_ggn` same as `laplace.dense_ggn` but uses the diagonal of the Generalised Gauss-Newton matrix instead.

All Laplace transforms leave the parameters unmodified. Comprehensive details on Laplace approximations can be found in Daxberger et al, 2021.

- `sgmcmc.sgld` implements stochastic gradient Langevin dynamics (SGLD) from Welling and Teh, 2011.
- `sgmcmc.sghmc` implements the stochastic gradient Hamiltonian Monte Carlo (SGHMC) algorithm from Chen et al, 2014 (without momenta resampling).
- `sgmcmc.sgnht` implements the stochastic gradient Nosé-Hoover thermostat (SGNHT) algorithm from Ding et al, 2014 (SGHMC with adaptive friction coefficient).

For an overview and unifying framework for SGMCMC methods, see Ma et al, 2015.

- `vi.diag` implements a diagonal Gaussian variational distribution. Expects a `torchopt` optimizer for handling the minimization of the NELBO. Also find `vi.diag.nelbo` for simply calculating the NELBO with respect to a `log_posterior` and diagonal Gaussian distribution.

A review of variational inference can be found in Blei et al, 2017.

- `optim` wrapper for `torch.optim` optimizers within the unified posteriors API that allows for easy swapping with UQ methods (see the sketch below).
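The shared build / init / update interface is what makes that swapping possible. A minimal sketch (following the `optim` and `torchopt` examples later in this reference); `log_posterior`, `params` and `dataloader` are hypothetical placeholders:

```python
import posteriors

# Any of the modules above expose the same pattern; changing method = changing one line.
transform = posteriors.laplace.diag_fisher.build(log_posterior)
# transform = posteriors.ekf.diag_fisher.build(log_likelihood, lr=1e-2)
# transform = posteriors.sgmcmc.sgld.build(log_posterior, lr=1e-3)

state = transform.init(params)
for batch in dataloader:
    state = transform.update(state, batch)
```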
`posteriors.laplace.dense_fisher.build(log_posterior, per_sample=False, init_prec=0.0)`

Builds a transform for dense empirical Fisher information Laplace approximation.

The empirical Fisher is defined here as:
$$
F(θ) = \sum_i ∇_θ \log p(y_i, θ \mid x_i) \, ∇_θ \log p(y_i, θ \mid x_i)^T
$$
where \(p(y_i, θ \mid x_i)\) is the joint model distribution (equivalent to the posterior up to proportionality) with parameters \(θ\), inputs \(x_i\) and labels \(y_i\).

More info on empirical Fisher matrices can be found in Martens, 2020 and their use within a Laplace approximation in Daxberger et al, 2021.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| per_sample | bool | If True, log_posterior is assumed to return a vector of log posteriors for each sample in the batch. If False, it is assumed to return a scalar log posterior for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_posterior to be per-sample directly. | False |
| init_prec | Tensor \| float | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | Empirical Fisher information Laplace approximation transform instance. |

Source code in posteriors/laplace/dense_fisher.py (lines 17–54)
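As an illustrative sketch (assuming a hypothetical per-sample `log_likelihood` helper, a `log_prior` function and MAP parameters `map_params` obtained beforehand), the Laplace transform is typically run once over the data after a MAP fit and then sampled:

```python
import torch
import posteriors

def log_posterior(params, batch):
    # Unnormalised log joint of the batch: summed log-likelihood plus log prior.
    log_liks, aux = log_likelihood(params, batch)   # hypothetical helper returning per-sample values
    return log_liks.sum() + log_prior(params), aux  # hypothetical log_prior

transform = posteriors.laplace.dense_fisher.build(log_posterior)
state = transform.init(map_params)      # parameters from a previous MAP fit

for batch in dataloader:                # one pass accumulates the Fisher into state.prec
    state = transform.update(state, batch)

param_samples = posteriors.laplace.dense_fisher.sample(state, torch.Size([50]))
```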
`posteriors.laplace.dense_fisher.DenseLaplaceState`

Bases: NamedTuple

State encoding a Normal distribution over parameters, with a dense precision matrix.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| prec | Tensor | Precision matrix of the Normal distribution. |
| aux | Any | Auxiliary information from the log_posterior call. |

Source code in posteriors/laplace/dense_fisher.py (lines 57–69)
`posteriors.laplace.dense_fisher.init(params, init_prec=0.0)`

Initialise Normal distribution over parameters with a dense precision matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. | required |
| init_prec | Tensor \| float | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| DenseLaplaceState | Initial DenseLaplaceState. |

Source code in posteriors/laplace/dense_fisher.py (lines 72–93)
`posteriors.laplace.dense_fisher.update(state, batch, log_posterior, per_sample=False, inplace=False)`

Adds the empirical Fisher information matrix, summed over the given batch, to the precision.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DenseLaplaceState | Current state. | required |
| batch | Any | Input data to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| per_sample | bool | If True, log_posterior is assumed to return a vector of log posteriors for each sample in the batch. If False, it is assumed to return a scalar log posterior for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_posterior to be per-sample directly. | False |
| inplace | bool | If True, the state is updated in place; otherwise a new state is returned. | False |

Returns:

| Type | Description |
| --- | --- |
| DenseLaplaceState | Updated DenseLaplaceState. |

Source code in posteriors/laplace/dense_fisher.py (lines 96–134)
`posteriors.laplace.dense_fisher.sample(state, sample_shape=torch.Size([]))`

Sample from the Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DenseLaplaceState | State encoding mean and precision matrix. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Normal distribution. |

Source code in posteriors/laplace/dense_fisher.py (lines 137–160)
`posteriors.laplace.dense_ggn.build(forward, outer_log_likelihood, init_prec=0.0)`

Builds a transform for a Generalized Gauss-Newton (GGN) Laplace approximation.

Equivalent to the (non-empirical) Fisher information matrix when `outer_log_likelihood` is exponential family with natural parameter equal to the output from `forward`.

`forward` should output auxiliary information (or `torch.tensor([])`), `outer_log_likelihood` should not.

The GGN is defined as
$$
G(θ) = J_f(θ) H_l(z) J_f(θ)^T
$$
where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) is a loss (negative log-likelihood) that maps the output of \(f\) to a scalar output.

More info on Fisher and GGN matrices can be found in Martens, 2020 and their use within a Laplace approximation in Daxberger et al, 2021.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| forward | ForwardFn | Function that takes parameters and input batch and returns a forward value (e.g. logits), not reduced over the batch, as well as auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the input batch and returns a scalar log-likelihood, with no auxiliary information. | required |
| init_prec | TensorTree \| float | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | GGN Laplace approximation transform instance. |

Source code in posteriors/laplace/dense_ggn.py (lines 21–65)
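To illustrate the forward / outer_log_likelihood split, a sketch assuming a hypothetical classification model and that outer_log_likelihood receives the forward output together with the batch:

```python
import torch
import posteriors

def forward(params, batch):
    # Per-sample logits, not reduced over the batch, plus (empty) auxiliary information.
    x, _ = batch
    logits = torch.func.functional_call(model, params, x)  # hypothetical model
    return logits, torch.tensor([])

def outer_log_likelihood(logits, batch):
    # Scalar log-likelihood of the labels given the forward output; no aux returned.
    _, y = batch
    return torch.distributions.Categorical(logits=logits).log_prob(y).sum()

transform = posteriors.laplace.dense_ggn.build(forward, outer_log_likelihood)
state = transform.init(map_params)   # hypothetical MAP parameters
for batch in dataloader:             # hypothetical dataloader
    state = transform.update(state, batch)
```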
`posteriors.laplace.dense_ggn.DenseLaplaceState`

Bases: NamedTuple

State encoding a Normal distribution over parameters, with a dense precision matrix.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| prec | Tensor | Precision matrix of the Normal distribution. |
| aux | Any | Auxiliary information from the log_posterior call. |

Source code in posteriors/laplace/dense_ggn.py (lines 68–80)
`posteriors.laplace.dense_ggn.init(params, init_prec=0.0)`

Initialise Normal distribution over parameters with a dense precision matrix.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. | required |
| init_prec | Tensor \| float | Initial precision matrix. If it is a float, it is defined as an identity matrix scaled by that float. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| DenseLaplaceState | Initial DenseLaplaceState. |

Source code in posteriors/laplace/dense_ggn.py (lines 83–104)
`posteriors.laplace.dense_ggn.update(state, batch, forward, outer_log_likelihood, inplace=False)`

Adds the GGN matrix computed over the given batch to the precision.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DenseLaplaceState | Current state. | required |
| batch | Any | Input data to model. | required |
| forward | ForwardFn | Function that takes parameters and input batch and returns a forward value (e.g. logits), not reduced over the batch, as well as auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the input batch and returns a scalar log-likelihood, with no auxiliary information. | required |
| inplace | bool | If True, the state is updated in place; otherwise a new state is returned. | False |

Returns:

| Type | Description |
| --- | --- |
| DenseLaplaceState | Updated DenseLaplaceState. |

Source code in posteriors/laplace/dense_ggn.py (lines 107–148)
`posteriors.laplace.dense_ggn.sample(state, sample_shape=torch.Size([]))`

Sample from the Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DenseLaplaceState | State encoding mean and precision matrix. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Normal distribution. |

Source code in posteriors/laplace/dense_ggn.py (lines 151–174)
`posteriors.laplace.diag_fisher.build(log_posterior, per_sample=False, init_prec_diag=0.0)`

Builds a transform for diagonal empirical Fisher information Laplace approximation.

The empirical Fisher is defined here as:
$$
F(θ) = \sum_i ∇_θ \log p(y_i, θ \mid x_i) \, ∇_θ \log p(y_i, θ \mid x_i)^T
$$
where \(p(y_i, θ \mid x_i)\) is the joint model distribution (equivalent to the posterior up to proportionality) with parameters \(θ\), inputs \(x_i\) and labels \(y_i\).

More info on empirical Fisher matrices can be found in Martens, 2020 and their use within a Laplace approximation in Daxberger et al, 2021.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| per_sample | bool | If True, log_posterior is assumed to return a vector of log posteriors for each sample in the batch. If False, it is assumed to return a scalar log posterior for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_posterior to be per-sample directly. | False |
| init_prec_diag | TensorTree \| float | Initial diagonal precision matrix. Can be a tree matching params or a scalar. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | Diagonal empirical Fisher information Laplace approximation transform instance. |

Source code in posteriors/laplace/diag_fisher.py (lines 17–53)
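A sketch of a per-sample log posterior for the diagonal variant (the model, log_prior, num_data, map_params and dataloader are hypothetical); writing it per-sample and passing `per_sample=True` avoids the internal vmap:

```python
import torch
import posteriors

def log_posterior(params, batch):
    # Vector of per-sample log joint evaluations; logits are returned as aux.
    x, y = batch
    logits = torch.func.functional_call(model, params, x)            # hypothetical model
    log_liks = torch.distributions.Categorical(logits=logits).log_prob(y)
    return log_liks + log_prior(params) / num_data, logits           # hypothetical log_prior, num_data

transform = posteriors.laplace.diag_fisher.build(log_posterior, per_sample=True)
state = transform.init(map_params)
for batch in dataloader:
    state = transform.update(state, batch)

param_samples = posteriors.laplace.diag_fisher.sample(state, torch.Size([50]))
```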
`posteriors.laplace.diag_fisher.DiagLaplaceState`

Bases: NamedTuple

State encoding a diagonal Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| prec_diag | TensorTree | Diagonal of the precision matrix of the Normal distribution. |
| aux | Any | Auxiliary information from the log_posterior call. |

Source code in posteriors/laplace/diag_fisher.py (lines 56–67)
`posteriors.laplace.diag_fisher.init(params, init_prec_diag=0.0)`

Initialise diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. | required |
| init_prec_diag | TensorTree \| float | Initial diagonal precision matrix. Can be a tree matching params or a scalar. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| DiagLaplaceState | Initial DiagLaplaceState. |

Source code in posteriors/laplace/diag_fisher.py (lines 70–90)
`posteriors.laplace.diag_fisher.update(state, batch, log_posterior, per_sample=False, inplace=False)`

Adds the diagonal empirical Fisher information matrix, summed over the given batch, to the diagonal precision.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DiagLaplaceState | Current state. | required |
| batch | Any | Input data to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| per_sample | bool | If True, log_posterior is assumed to return a vector of log posteriors for each sample in the batch. If False, it is assumed to return a scalar log posterior for the whole batch, in which case torch.func.vmap is called; this is typically slower than writing log_posterior to be per-sample directly. | False |
| inplace | bool | If True, the state is updated in place; otherwise a new state is returned. | False |

Returns:

| Type | Description |
| --- | --- |
| DiagLaplaceState | Updated DiagLaplaceState. |

Source code in posteriors/laplace/diag_fisher.py (lines 93–136)
`posteriors.laplace.diag_fisher.sample(state, sample_shape=torch.Size([]))`

Sample from the diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DiagLaplaceState | State encoding mean and diagonal precision. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Normal distribution. |

Source code in posteriors/laplace/diag_fisher.py (lines 139–152)
`posteriors.laplace.diag_ggn.build(forward, outer_log_likelihood, init_prec_diag=0.0)`

Builds a transform for a diagonal Generalized Gauss-Newton (GGN) Laplace approximation.

Equivalent to the diagonal of the (non-empirical) Fisher information matrix when `outer_log_likelihood` is exponential family with natural parameter equal to the output from `forward`.

`forward` should output auxiliary information (or `torch.tensor([])`), `outer_log_likelihood` should not.

The GGN is defined as
$$
G(θ) = J_f(θ) H_l(z) J_f(θ)^T
$$
where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) is a loss (negative log-likelihood) that maps the output of \(f\) to a scalar output.

More info on Fisher and GGN matrices can be found in Martens, 2020 and their use within a Laplace approximation in Daxberger et al, 2021.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| forward | ForwardFn | Function that takes parameters and input batch and returns a forward value (e.g. logits), not reduced over the batch, as well as auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the input batch and returns a scalar log-likelihood, with no auxiliary information. | required |
| init_prec_diag | TensorTree \| float | Initial diagonal precision matrix. Can be a tree matching params or a scalar. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | Diagonal GGN Laplace approximation transform instance. |

Source code in posteriors/laplace/diag_ggn.py (lines 21–64)
`posteriors.laplace.diag_ggn.DiagLaplaceState`

Bases: NamedTuple

State encoding a diagonal Normal distribution over parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. |
| prec_diag | TensorTree | Diagonal of the precision matrix of the Normal distribution. |
| aux | Any | Auxiliary information from the log_posterior call. |

Source code in posteriors/laplace/diag_ggn.py (lines 67–78)
`posteriors.laplace.diag_ggn.init(params, init_prec_diag=0.0)`

Initialise diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Mean of the Normal distribution. | required |
| init_prec_diag | TensorTree \| float | Initial diagonal precision matrix. Can be a tree matching params or a scalar. | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| DiagLaplaceState | Initial DiagLaplaceState. |

Source code in posteriors/laplace/diag_ggn.py (lines 81–101)
`posteriors.laplace.diag_ggn.update(state, batch, forward, outer_log_likelihood, inplace=False)`

Adds the diagonal GGN matrix, summed over the given batch, to the diagonal precision.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DiagLaplaceState | Current state. | required |
| batch | Any | Input data to model. | required |
| forward | ForwardFn | Function that takes parameters and input batch and returns a forward value (e.g. logits), not reduced over the batch, as well as auxiliary information. | required |
| outer_log_likelihood | OuterLogProbFn | Function that takes the output of forward and the input batch and returns a scalar log-likelihood, with no auxiliary information. | required |
| inplace | bool | If True, the state is updated in place; otherwise a new state is returned. | False |

Returns:

| Type | Description |
| --- | --- |
| DiagLaplaceState | Updated DiagLaplaceState. |

Source code in posteriors/laplace/diag_ggn.py (lines 104–150)
`posteriors.laplace.diag_ggn.sample(state, sample_shape=torch.Size([]))`

Sample from the diagonal Normal distribution over parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | DiagLaplaceState | State encoding mean and diagonal precision. | required |
| sample_shape | Size | Shape of the desired samples. | Size([]) |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | Sample(s) from the Normal distribution. |

Source code in posteriors/laplace/diag_ggn.py (lines 153–166)
`posteriors.optim.build(loss_fn, optimizer, **kwargs)`

Builds an optimizer transform from torch.optim.

```python
transform = build(loss_fn, torch.optim.Adam, lr=0.1)
state = transform.init(params)

for batch in dataloader:
    state = transform.update(state, batch)
```

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_fn | LogProbFn | Function that takes the parameters and input batch and returns the loss, of the form loss, aux = loss_fn(params, batch). | required |
| optimizer | Type[Optimizer] | Optimizer class from torch.optim. | required |
| **kwargs | Any | Keyword arguments to pass to the optimizer class. | {} |

Returns:

| Type | Description |
| --- | --- |
| Transform | torch.optim optimizer transform instance. |

Source code in posteriors/optim.py (lines 10–36)
`posteriors.optim.OptimState`

Bases: NamedTuple

State of an optimizer from torch.optim.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Parameters to be optimized. |
| optimizer | Optimizer | torch.optim optimizer instance. |
| loss | Tensor | Loss value. |
| aux | Any | Auxiliary information from the loss function call. |

Source code in posteriors/optim.py (lines 39–52)
`posteriors.optim.init(params, optimizer_cls, *args, **kwargs)`

Initialise a torch.optim optimizer state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Parameters to be optimized. | required |
| optimizer_cls | Type[Optimizer] | Optimizer class from torch.optim. | required |
| *args | Any | Positional arguments to pass to the optimizer class. | () |
| **kwargs | Any | Keyword arguments to pass to the optimizer class. | {} |

Returns:

| Type | Description |
| --- | --- |
| OptimState | Initial OptimState. |

Source code in posteriors/optim.py (lines 55–76)
`posteriors.optim.update(state, batch, loss_fn, inplace=True)`

Perform a single update step of a torch.optim optimizer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | OptimState | Current optimizer state. | required |
| batch | TensorTree | Input data to loss_fn. | required |
| loss_fn | LogProbFn | Function that takes the parameters and input batch and returns the loss, of the form loss, aux = loss_fn(params, batch). | required |
| inplace | bool | Whether to update the parameters in place. inplace=False is not supported for posteriors.optim. | True |

Returns:

| Type | Description |
| --- | --- |
| OptimState | Updated OptimState. |

Source code in posteriors/optim.py (lines 79–107)
`posteriors.sgmcmc.sghmc.build(log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, momenta=None)`

Builds SGHMC transform.

Algorithm from Chen et al, 2014, for learning rate \(\epsilon\) and temperature \(T\).

Targets \(p_T(θ, m) \propto \exp( (\log p(θ) - \frac{1}{2σ^2} m^T m) / T)\) with temperature \(T\).

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data and variable batch size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Learning rate. | required |
| alpha | float | Friction coefficient. | 0.01 |
| beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
| sigma | float | Standard deviation of momenta target distribution. | 1.0 |
| temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
| momenta | TensorTree \| float \| None | Initial momenta. Can be a tree matching params or a scalar. Defaults to random iid samples from N(0, 1). | None |

Returns:

| Type | Description |
| --- | --- |
| Transform | SGHMC transform instance. |

Source code in posteriors/sgmcmc/sghmc.py (lines 12–64)
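A minimal sampling-loop sketch (log_posterior, params, dataloader and the burn-in / thinning settings are hypothetical):

```python
import posteriors

transform = posteriors.sgmcmc.sghmc.build(log_posterior, lr=1e-2, alpha=0.01)
state = transform.init(params)

samples = []
for step, batch in enumerate(dataloader):
    state = transform.update(state, batch)
    if step >= burn_in and step % thin == 0:  # hypothetical burn_in and thin
        samples.append(state.params)          # collect posterior samples of the parameters
```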
`posteriors.sgmcmc.sghmc.SGHMCState`

Bases: NamedTuple

State encoding params and momenta for SGHMC.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Parameters. |
| momenta | TensorTree | Momenta for each parameter. |
| log_posterior | Tensor | Log posterior evaluation. |
| aux | Any | Auxiliary information from the log_posterior call. |

Source code in posteriors/sgmcmc/sghmc.py (lines 67–80)
`posteriors.sgmcmc.sghmc.init(params, momenta=None)`

Initialise momenta for SGHMC.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Parameters for which to initialise. | required |
| momenta | TensorTree \| float \| None | Initial momenta. Can be a tree matching params or a scalar. Defaults to random iid samples from N(0, 1). | None |

Returns:

| Type | Description |
| --- | --- |
| SGHMCState | Initial SGHMCState containing momenta. |

Source code in posteriors/sgmcmc/sghmc.py (lines 83–105)
`posteriors.sgmcmc.sghmc.update(state, batch, log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, inplace=False)`

Updates parameters and momenta for SGHMC.

Update rule from Chen et al, 2014, for learning rate \(\epsilon\) and temperature \(T\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | SGHMCState | SGHMCState containing params and momenta. | required |
| batch | Any | Data batch to be sent to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Learning rate. | required |
| alpha | float | Friction coefficient. | 0.01 |
| beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
| sigma | float | Standard deviation of momenta target distribution. | 1.0 |
| temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
| inplace | bool | Whether to modify state in place. | False |

Returns:

| Type | Description |
| --- | --- |
| SGHMCState | Updated state (which contains pointers to the input state tensors if inplace=True). |

Source code in posteriors/sgmcmc/sghmc.py (lines 108–175)
`posteriors.sgmcmc.sgld.build(log_posterior, lr, beta=0.0, temperature=1.0)`

Builds SGLD transform.

Algorithm from Welling and Teh, 2011:
$$
θ_{t+1} = θ_t + ε \nabla \log p(θ_t, \text{batch}) + N(0, ε (2 - ε β) T \mathbb{I})
$$
for learning rate \(\epsilon\) and temperature \(T\).

Targets \(p_T(θ) \propto \exp( \log p(θ) / T)\) with temperature \(T\).

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data and variable batch size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Learning rate. | required |
| beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
| temperature | float | Temperature of the sampling distribution. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Transform | SGLD transform (posteriors.types.Transform instance). |

Source code in posteriors/sgmcmc/sgld.py (lines 11–48)
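One common way (shown as a hedged sketch, not prescribed by the docstring) to construct the log posterior and temperature in tandem is to normalise the log posterior per datum and set the temperature to 1/N, so the tempered target recovers the full-data posterior while gradients stay well scaled for any batch size; the model, log_prior and dataset size are hypothetical:

```python
import torch
import posteriors

num_data = 10_000  # hypothetical dataset size

def log_posterior(params, batch):
    # Per-datum normalised log posterior: mean log-likelihood + log prior / N.
    x, y = batch
    preds = torch.func.functional_call(model, params, x)      # hypothetical model
    log_lik = torch.distributions.Normal(preds, 1.0).log_prob(y).mean()
    return log_lik + log_prior(params) / num_data, preds       # hypothetical log_prior

# With the log posterior scaled by 1/N, temperature = 1/N targets the untempered posterior.
transform = posteriors.sgmcmc.sgld.build(log_posterior, lr=1e-3, temperature=1 / num_data)
```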
`posteriors.sgmcmc.sgld.SGLDState`

Bases: NamedTuple

State encoding params for SGLD.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Parameters. |
| log_posterior | Tensor | Log posterior evaluation. |
| aux | Any | Auxiliary information from the log_posterior call. |

Source code in posteriors/sgmcmc/sgld.py (lines 51–62)
`posteriors.sgmcmc.sgld.init(params)`

Initialise SGLD.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Parameters for which to initialise. | required |

Returns:

| Type | Description |
| --- | --- |
| SGLDState | Initial SGLDState. |

Source code in posteriors/sgmcmc/sgld.py (lines 65–75)
`posteriors.sgmcmc.sgld.update(state, batch, log_posterior, lr, beta=0.0, temperature=1.0, inplace=False)`

Updates parameters for SGLD.

Update rule from Welling and Teh, 2011:
$$
θ_{t+1} = θ_t + ε \nabla \log p(θ_t, \text{batch}) + N(0, ε (2 - ε β) T \mathbb{I})
$$
for learning rate \(\epsilon\) and temperature \(T\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | SGLDState | SGLDState containing params. | required |
| batch | Any | Data batch to be sent to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Learning rate. | required |
| beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
| temperature | float | Temperature of the sampling distribution. | 1.0 |
| inplace | bool | Whether to modify state in place. | False |

Returns:

| Type | Description |
| --- | --- |
| SGLDState | Updated state (which contains pointers to the input state tensors if inplace=True). |

Source code in posteriors/sgmcmc/sgld.py (lines 78–127)
`posteriors.sgmcmc.sgnht.build(log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, momenta=None, xi=None)`

Builds SGNHT transform.

Algorithm from Ding et al, 2014, for learning rate \(\epsilon\), temperature \(T\) and parameter dimension \(d\).

Targets \(p_T(θ, m, ξ) \propto \exp( (\log p(θ) - \frac{1}{2σ^2} m^T m - \frac{d}{2}(ξ - α)^2) / T)\).

The log posterior and temperature are recommended to be constructed in tandem to ensure robust scaling for a large amount of data and variable batch size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Learning rate. | required |
| alpha | float | Friction coefficient. | 0.01 |
| beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
| sigma | float | Standard deviation of momenta target distribution. | 1.0 |
| temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
| momenta | TensorTree \| float \| None | Initial momenta. Can be a tree matching params or a scalar. Defaults to random iid samples from N(0, 1). | None |
| xi | float | Initial value for the scalar thermostat ξ. | None |

Returns:

| Type | Description |
| --- | --- |
| Transform | SGNHT transform instance. |

Source code in posteriors/sgmcmc/sgnht.py (lines 13–67)
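Usage mirrors the other SGMCMC transforms; a brief sketch with hypothetical log_posterior, params and dataloader:

```python
import posteriors

transform = posteriors.sgmcmc.sgnht.build(log_posterior, lr=1e-2, alpha=0.01)
state = transform.init(params)

for batch in dataloader:
    state = transform.update(state, batch)  # the thermostat adapts the friction on the fly
```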
`posteriors.sgmcmc.sgnht.SGNHTState`

Bases: NamedTuple

State encoding params and momenta for SGNHT.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Parameters. |
| momenta | TensorTree | Momenta for each parameter. |
| log_posterior | Tensor | Log posterior evaluation. |
| aux | Any | Auxiliary information from the log_posterior call. |

Source code in posteriors/sgmcmc/sgnht.py (lines 70–84)
`posteriors.sgmcmc.sgnht.init(params, momenta=None, xi=0.01)`

Initialise momenta for SGNHT.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Parameters for which to initialise. | required |
| momenta | TensorTree \| float \| None | Initial momenta. Can be a tree matching params or a scalar. Defaults to random iid samples from N(0, 1). | None |
| xi | float \| Tensor | Initial value for the scalar thermostat ξ. | 0.01 |

Returns:

| Type | Description |
| --- | --- |
| SGNHTState | Initial SGNHTState containing momenta. |

Source code in posteriors/sgmcmc/sgnht.py (lines 87–114)
`posteriors.sgmcmc.sgnht.update(state, batch, log_posterior, lr, alpha=0.01, beta=0.0, sigma=1.0, temperature=1.0, inplace=False)`

Updates parameters, momenta and xi for SGNHT.

Update rule from Ding et al, 2014, for learning rate \(\epsilon\) and temperature \(T\).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | SGNHTState | SGNHTState containing params, momenta and xi. | required |
| batch | Any | Data batch to be sent to log_posterior. | required |
| log_posterior | LogProbFn | Function that takes parameters and input batch and returns the log posterior value (which can be unnormalised) as well as auxiliary information, e.g. from the model call. | required |
| lr | float | Learning rate. | required |
| alpha | float | Friction coefficient. | 0.01 |
| beta | float | Gradient noise coefficient (estimated variance). | 0.0 |
| sigma | float | Standard deviation of momenta target distribution. | 1.0 |
| temperature | float | Temperature of the joint parameter + momenta distribution. | 1.0 |
| inplace | bool | Whether to modify state in place. | False |

Returns:

| Type | Description |
| --- | --- |
| SGNHTState | Updated SGNHTState (which contains pointers to the input state tensors if inplace=True). |

Source code in posteriors/sgmcmc/sgnht.py (lines 117–189)
`posteriors.torchopt.build(loss_fn, optimizer)`

Build a TorchOpt optimizer transformation.

Make sure to use the lower case functional optimizers, e.g. `torchopt.adam()`.

```python
transform = build(loss_fn, torchopt.adam(lr=0.1))
state = transform.init(params)

for batch in dataloader:
    state = transform.update(state, batch)
```

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_fn | LogProbFn | Loss function. | required |
| optimizer | GradientTransformation | TorchOpt functional optimizer. Make sure to use the lower case version, e.g. torchopt.adam(). | required |

Returns:

| Type | Description |
| --- | --- |
| Transform | TorchOpt optimizer transform instance. |

Source code in posteriors/torchopt.py (lines 11–37)
`posteriors.torchopt.TorchOptState`

Bases: NamedTuple

State of a TorchOpt optimizer.

Contains the parameters, the optimizer state for the TorchOpt optimizer, loss value, and auxiliary information.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| params | TensorTree | Parameters to be optimized. |
| opt_state | OptState | TorchOpt optimizer state. |
| loss | Tensor | Loss value. |
| aux | Any | Auxiliary information from the loss function call. |

Source code in posteriors/torchopt.py (lines 40–56)
`posteriors.torchopt.init(params, optimizer)`

Initialise a TorchOpt optimizer.

Make sure to use the lower case functional optimizers, e.g. `torchopt.adam()`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | TensorTree | Parameters to be optimized. | required |
| optimizer | GradientTransformation | TorchOpt functional optimizer. Make sure to use the lower case version, e.g. torchopt.adam(). | required |

Returns:

| Type | Description |
| --- | --- |
| TorchOptState | Initial TorchOptState. |

Source code in posteriors/torchopt.py (lines 59–76)
`posteriors.torchopt.update(state, batch, loss_fn, optimizer, inplace=False)`

Update the TorchOpt optimizer state.

Make sure to use the lower case functional optimizers, e.g. `torchopt.adam()`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | TorchOptState | Current state. | required |
| batch | TensorTree | Batch of data. | required |
| loss_fn | LogProbFn | Loss function. | required |
| optimizer | GradientTransformation | TorchOpt functional optimizer. Make sure to use the lower case version, e.g. torchopt.adam(). | required |
| inplace | bool | Whether to update the state in place. | False |

Returns:

| Type | Description |
| --- | --- |
| TorchOptState | Updated TorchOptState. |

Source code in posteriors/torchopt.py (lines 79–113)
`posteriors.tree_utils.tree_size(tree)`

Returns the total number of elements in a PyTree: not the number of leaves, but the total number of elements across all tensors in the tree.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | TensorTree | A PyTree of tensors. | required |

Returns:

| Type | Description |
| --- | --- |
| int | Number of elements in the PyTree. |

Source code in posteriors/tree_utils.py (lines 8–23)
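A small illustrative example (the parameter tree is hypothetical):

```python
import torch
from posteriors.tree_utils import tree_size

params = {"weight": torch.zeros(3, 4), "bias": torch.zeros(3)}
print(tree_size(params))  # 15: total elements across all tensors, not the number of leaves
```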
`posteriors.tree_utils.tree_extract(tree, f)`

Extracts values from a PyTree where f returns True. Leaves where f returns False are replaced with empty tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | TensorTree | A PyTree. | required |
| f | Callable[[tensor], bool] | A function that takes a PyTree element and returns True or False. | required |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A PyTree with the same structure as tree, keeping the leaves where f returns True. |

Source code in posteriors/tree_utils.py (lines 26–40)
`posteriors.tree_utils.tree_insert(full_tree, sub_tree, f=lambda _: True)`

Inserts sub_tree into full_tree where full_tree tensors evaluate f to True. Both PyTrees must have the same structure.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| full_tree | TensorTree | A PyTree to insert sub_tree into. | required |
| sub_tree | TensorTree | A PyTree to insert into full_tree. | required |
| f | Callable[[tensor], bool] | A function that takes a PyTree element and returns True or False. Defaults to lambda _: True, i.e. insert on all leaves. | lambda _: True |

Returns:

| Type | Description |
| --- | --- |
| TensorTree | A PyTree with sub_tree inserted into full_tree. |

Source code in posteriors/tree_utils.py (lines 43–64)
posteriors.tree_utils.tree_insert_(full_tree, sub_tree, f=lambda _: True)
+
+𝞡Inserts sub_tree into full_tree in-place where full_tree tensors evaluate +f to True. Both PyTrees must have the same structure.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
full_tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree to insert sub_tree into. + |
+ + required + | +
sub_tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree to insert into full_tree. + |
+ + required + | +
f |
+
+ Callable[[tensor], bool]
+ |
+
+
+
+ A function that takes a PyTree element and returns True or False. +Defaults to lambda _: True. I.e. insert on all leaves. + |
+
+ lambda _: True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ TensorTree
+ |
+
+
+
+ A pointer to full_tree with sub_tree inserted. + |
+
posteriors/tree_utils.py
67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 |
|
posteriors.tree_utils.extract_requires_grad(tree)
+
+𝞡Extracts only parameters that require gradients.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree of tensors. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ TensorTree
+ |
+
+
+
+ A PyTree of tensors that require gradients. + |
+
posteriors/tree_utils.py
92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 |
|
posteriors.tree_utils.insert_requires_grad(full_tree, sub_tree)
+
+𝞡Inserts sub_tree into full_tree where full_tree tensors requires_grad. +Both PyTrees must have the same structure.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
full_tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree to insert sub_tree into. + |
+ + required + | +
sub_tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree to insert into full_tree. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ TensorTree
+ |
+
+
+
+ A PyTree with sub_tree inserted into full_tree. + |
+
posteriors/tree_utils.py
104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 |
|
posteriors.tree_utils.insert_requires_grad_(full_tree, sub_tree)
+
+𝞡Inserts sub_tree into full_tree in-place where full_tree tensors requires_grad. +Both PyTrees must have the same structure.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
full_tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree to insert sub_tree into. + |
+ + required + | +
sub_tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree to insert into full_tree. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ TensorTree
+ |
+
+
+
+ A pointer to full_tree with sub_tree inserted. + |
+
posteriors/tree_utils.py
118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 |
|
posteriors.tree_utils.extract_requires_grad_and_func(tree, func, inplace=False)
+
+𝞡Extracts only parameters that require gradients and converts a function +that takes the full parameter tree (in its first argument) +into one that takes the subtree.
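+A hedged sketch of the intended usage, assuming model, log_posterior and batch already exist and log_posterior(params, batch) returns (value, aux):
+from posteriors.tree_utils import extract_requires_grad_and_func
+
+params = dict(model.named_parameters())  # some entries may have requires_grad=False
+sub_params, sub_fn = extract_requires_grad_and_func(params, log_posterior)
+log_post_val, aux = sub_fn(sub_params, batch)  # only differentiable params are passed in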
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tree |
+
+ TensorTree
+ |
+
+
+
+ A PyTree of tensors. + |
+ + required + | +
func |
+
+ Callable
+ |
+
+
+
+ A function that takes tree in its first argument. + |
+ + required + | +
inplace |
+
+ bool
+ |
+
+
+
+ Whether to modify the tree in-place or not when the new function +is called. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Tuple[TensorTree, Callable]
+ |
+
+
+
+ A PyTree of tensors that require gradients and a modified func that takes the +subtree structure rather than full tree in its first argument. + |
+
posteriors/tree_utils.py
132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 |
|
posteriors.tree_utils.inplacify(func)
+
+𝞡Converts a function that takes a tensor as its first argument +into one that takes the same arguments but modifies the first argument +tensor in-place with the output of the function.
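+A small illustrative sketch using torch.add, which returns a new tensor:
+import torch
+from posteriors.tree_utils import inplacify
+
+add_ = inplacify(torch.add)
+t = torch.ones(3)
+add_(t, 1.0)  # t now holds torch.add(t, 1.0), written in-place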
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
func |
+
+ Callable
+ |
+
+
+
+ A function that takes a tensor as its first argument and returns +a modified version of said tensor. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Callable
+ |
+
+
+
+ A function that takes a tensor as its first argument and modifies it +in-place. + |
+
posteriors/tree_utils.py
159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 |
|
posteriors.tree_utils.tree_map_inplacify_(func, tree, *rests, is_leaf=None, none_is_leaf=False, namespace='')
+
+𝞡Applies a pure function to each tensor in a PyTree in-place.
+Like optree.tree_map_ +but takes a pure function as input (and replaces its first argument with its +output in-place) rather than a side-effect function.
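+For example (illustrative only):
+import torch
+from posteriors.tree_utils import tree_map_inplacify_
+
+tree = {"a": torch.ones(3), "b": torch.ones(2)}
+tree_map_inplacify_(lambda t: t * 2, tree)  # leaves of tree are doubled in-place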
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
func |
+
+ Callable
+ |
+
+
+
+ A function that takes a tensor as its first argument and returns +a modified version of said tensor. + |
+ + required + | +
tree |
+
+ pytree
+ |
+
+
+
+ A pytree to be mapped over, with each leaf providing the first
+positional argument to function func. |
+ + required + | +
rests |
+
+ tuple of pytree
+ |
+
+
+
+ A tuple of pytrees, each of which has the same
+structure as tree. |
+
+ ()
+ |
+
is_leaf |
+
+ callable
+ |
+
+
+
+ An optionally specified function that will be
+called at each flattening step. It should return a boolean, with
+True stopping the traversal and treating the whole subtree as a leaf. |
+
+ None
+ |
+
none_is_leaf |
+
+ bool
+ |
+
+
+
+ Whether to treat None as a leaf. |
+
+ False
+ |
+
namespace |
+
+ str
+ |
+
+
+
+ The registry namespace used for custom pytree node
+types. (default: '', i.e. the global namespace) |
+
+ ''
+ |
+
Returns:
+Type | +Description | +
---|---|
+ TensorTree
+ |
+
+
+
+ The original tree with its leaves modified in-place by func. |
+
posteriors/tree_utils.py
180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 |
|
posteriors.tree_utils.flexi_tree_map(func, tree, *rests, inplace=False, is_leaf=None, none_is_leaf=False, namespace='')
+
+𝞡Applies a pure function to each tensor in a PyTree, with inplace argument.
+out_tensor = func(tensor, *rest_tensors)
+
where out_tensor
is of the same shape as tensor
.
+Therefore
out_tree = flexi_tree_map(func, tree, *rests, inplace=True)
+
will return out_tree,
a pointer to the original tree
with leaves (tensors)
+modified in place.
+If inplace=False
, flexi_tree_map
is equivalent to optree.tree_map
+and returns a new tree.
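+A minimal sketch contrasting the two modes:
+import torch
+from posteriors.tree_utils import flexi_tree_map
+
+tree = {"a": torch.ones(3), "b": torch.ones(2)}
+new_tree = flexi_tree_map(lambda t: t * 0.5, tree, inplace=False)  # tree unchanged
+same_tree = flexi_tree_map(lambda t: t * 0.5, tree, inplace=True)  # tree's leaves halved in-place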
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
func |
+
+ Callable
+ |
+
+
+
+ A pure function that takes a tensor as its first argument and a returns +a modified version of said tensor. + |
+ + required + | +
tree |
+
+ pytree
+ |
+
+
+
+ A pytree to be mapped over, with each leaf providing the first
+positional argument to function func. |
+ + required + | +
rests |
+
+ tuple of pytree
+ |
+
+
+
+ A tuple of pytrees, each of which has the same
+structure as tree. |
+
+ ()
+ |
+
inplace |
+
+ bool
+ |
+
+
+
+ Whether to modify the tree in-place or not. + |
+
+ False
+ |
+
is_leaf |
+
+ callable
+ |
+
+
+
+ An optionally specified function that will be
+called at each flattening step. It should return a boolean, with True stopping +the traversal and treating the whole subtree as a leaf. |
+
+ None
+ |
+
none_is_leaf |
+
+ bool
+ |
+
+
+
+ Whether to treat None as a leaf. |
+
+ False
+ |
+
namespace |
+
+ str
+ |
+
+
+
+ The registry namespace used for custom pytree node
+types. (default: '', i.e. the global namespace) |
+
+ ''
+ |
+
Returns:
+Type | +Description | +
---|---|
+ TensorTree
+ |
+
+
+
+ Either the original tree modified in-place or a new tree depending on the
+inplace argument. |
+
posteriors/tree_utils.py
229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 |
|
posteriors.types.TransformState
+
+
+𝞡
+ Bases: NamedTuple
A posteriors
transform state is a NamedTuple
containing the required
+information for the posteriors iterative algorithm defined by the init
and
+update
functions.
Inherit the NamedTuple
class when defining a new algorithm state; just make sure
+to include the params
and aux
fields.
class AlgorithmState(NamedTuple):
+ params: TensorTree
+ algorithm_info: Any
+ aux: Any
+
Attributes:
+Name | +Type | +Description | +
---|---|---|
params |
+
+ TensorTree
+ |
+
+
+
+ PyTree containing the current value of parameters. + |
+
aux |
+
+ Any
+ |
+
+
+
+ Auxiliary information from the model call. + |
+
posteriors/types.py
15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 |
|
posteriors.types.InitFn
+
+
+𝞡
+ Bases: Protocol
posteriors/types.py
39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
__call__(params)
+
+
+ staticmethod
+
+
+𝞡Initiate a posteriors state with unified API:
+state = init(params)
+
where params is a PyTree of parameters. The produced state
is a
+TransformState
(NamedTuple
) containing the required information for the
+posteriors iterative algorithm defined by the init
and update
functions.
Note that this represents the init
function as stored in a Transform
+returned by an algorithm's build
function, the internal init
function in
+the algorithm module can and likely will have additional arguments.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
params |
+
+ TensorTree
+ |
+
+
+
+ PyTree containing initial value of parameters. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ TransformState
+ |
+
+
+
+ The initial state (TransformState). |
+
posteriors/types.py
40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
posteriors.types.UpdateFn
+
+
+𝞡
+ Bases: Protocol
posteriors/types.py
66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 |
|
__call__(state, batch, inplace=False)
+
+
+ staticmethod
+
+
+𝞡Transform a posteriors state with unified API:
+state = update(state, batch, inplace=False)
+
where state is a NamedTuple
containing the required information for the
+posteriors iterative algorithm defined by the init
and update
functions.
Note that this represents the update
function as stored in a Transform
+returned by an algorithm's build
function, the internal update
function in
+the algorithm module can and likely will have additional arguments.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
state |
+
+ TransformState
+ |
+
+
+
+ The state of the iterative algorithm. + |
+ + required + | +
batch |
+
+ Any
+ |
+
+
+
+ The data batch. + |
+ + required + | +
inplace |
+
+ bool
+ |
+
+
+
+ Whether to modify state using in-place operations. Defaults to False. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ TransformState
+ |
+
+
+
+ The transformed state (TransformState). |
+
posteriors/types.py
67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 |
|
posteriors.types.Transform
+
+
+𝞡
+ Bases: NamedTuple
A transform contains init
and update
functions defining an iterative
+ algorithm.
Within the Transform
all algorithm specific arguments are predefined, so that the
+init
and update
functions have a unified API:
+
state = transform.init(params)
+state = transform.update(state, batch, inplace=False)
+
Note that this represents the Transform
returned by an algorithm's
+build
function; the internal init
and update
functions in the
+algorithm module can and likely will have additional arguments.
Attributes:
+Name | +Type | +Description | +
---|---|---|
init |
+
+ InitFn
+ |
+
+
+
+ The init function. + |
+
update |
+
+ UpdateFn
+ |
+
+
+
+ The update function. + |
+
posteriors/types.py
96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 |
|
posteriors.utils.CatchAuxError
+
+
+𝞡
+ Bases: AbstractContextManager
Context manager to catch errors when auxiliary output is not found.
+ +posteriors/utils.py
19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 |
|
posteriors.utils.model_to_function(model)
+
+𝞡Converts a model into a function that maps parameters and inputs to outputs.
+Convenience wrapper around torch.func.functional_call.
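+A hedged sketch, assuming model is an existing torch.nn.Module and inputs is a compatible input batch:
+from posteriors.utils import model_to_function
+
+model_func = model_to_function(model)
+params = dict(model.named_parameters())
+outputs = model_func(params, inputs)  # roughly torch.func.functional_call(model, params, inputs)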
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
model |
+
+ Module
+ |
+
+
+
+ torch.nn.Module with parameters stored in .named_parameters(). + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Callable[[TensorTree, Any], Any]
+ |
+
+
+
+ Function that takes a PyTree of parameters as well as any input +arg or kwargs and returns the output of the model. + |
+
posteriors/utils.py
42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 |
|
posteriors.utils.linearized_forward_diag(forward_func, params, batch, sd_diag)
+
+𝞡Compute the linearized forward mean and its square root covariance, assuming +posterior covariance over parameters is diagonal.
+$$ +f(x | θ) \sim N(f(x | θₘ), J(x | θₘ) \Sigma J(x | θₘ)^T) +$$ +where \(θₘ\) is the MAP estimate, \(\Sigma\) is the diagonal covariance approximation +at the MAP and \(J(x | θₘ)\) is the Jacobian of the forward function \(f(x | θₘ)\) with +respect to \(θₘ\).
+For more info on linearized models see Foong et al, 2019.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
forward_func |
+
+ ForwardFn
+ |
+
+
+
+ A function that takes params and batch and returns the forward +values and any auxiliary information. Forward values must be a dim=2 Tensor +with batch dimension in its first axis. + |
+ + required + | +
params |
+
+ TensorTree
+ |
+
+
+
+ PyTree of tensors. + |
+ + required + | +
batch |
+
+ TensorTree
+ |
+
+
+
+ PyTree of tensors. + |
+ + required + | +
sd_diag |
+
+ TensorTree
+ |
+
+
+
+ PyTree of tensors of same shape as params. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tuple[TensorTree, Tensor, TensorTree]
+ |
+
+
+
+ A tuple of (forward_vals, chol, aux) where forward_vals is the output of the +forward function (mean), chol is the tensor square root of the covariance +matrix (non-diagonal) and aux is auxiliary info from the forward function. + |
+
posteriors/utils.py
61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 |
|
posteriors.utils.hvp(f, primals, tangents, has_aux=False)
+
+𝞡Hessian vector product.
+H(primals) @ tangents
+where H(primals) is the Hessian of f evaluated at primals.
+Taken from jacobians_hessians.html.
+Follows API from torch.func.jvp
.
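+A sketch in the same style as the fvp example below, assuming model and batch exist and log_posterior_scalar(params, batch) returns a scalar only (no aux):
+from functools import partial
+from optree import tree_map
+import torch
+from posteriors.utils import hvp
+
+params = dict(model.named_parameters())
+v = tree_map(lambda x: torch.randn_like(x), params)
+grad_out, hvp_out = hvp(partial(log_posterior_scalar, batch=batch), (params,), (v,))
+# hvp_out has the same structure as v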
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
f |
+
+ Callable
+ |
+
+
+
+ A function with scalar output. + |
+ + required + | +
primals |
+
+ tuple
+ |
+
+
+
+ Tuple of e.g. tensor or dict with tensor values to evalute f at. + |
+ + required + | +
tangents |
+
+ tuple
+ |
+
+
+
+ Tuple matching structure of primals. + |
+ + required + | +
has_aux |
+
+ bool
+ |
+
+
+
+ Whether f returns auxiliary information. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Tuple[float, TensorTree] | Tuple[float, TensorTree, Any]
+ |
+
+
+
+ Returns a (gradient, hvp_out) tuple containing the gradient of func evaluated at +primals and the Hessian-vector product. If has_aux is True, then instead +returns a (gradient, hvp_out, aux) tuple. + |
+
posteriors/utils.py
108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 |
|
posteriors.utils.fvp(f, primals, tangents, has_aux=False, normalize=False)
+
+𝞡Empirical Fisher vector product.
+F(primals) @ tangents
+where F(primals) is the empirical Fisher of f evaluated at primals.
+The empirical Fisher is defined as:
+$$
+F(θ) = J_f(θ) J_f(θ)^T
+$$
+where typically \(f_θ\) is the per-sample log likelihood (with elements
+\(\log p(y_i | x_i, θ)\) for a model with primals
\(θ\) given inputs \(x_i\) and
+labels \(y_i\)).
If normalize=True
, then \(F(θ)\) is divided by the number of outputs from f
+(i.e. batchsize).
Follows API from torch.func.jvp
.
More info on empirical Fisher matrices can be found in +Martens, 2020.
+ + +Examples:
+from functools import partial
+from optree import tree_map
+import torch
+from posteriors import fvp
+
+# Load model that outputs logits
+# Load batch = {'inputs': ..., 'labels': ...}
+
+def log_likelihood_per_sample(params, batch):
+ output = torch.func.functional_call(model, params, batch["inputs"])
+ return -torch.nn.functional.cross_entropy(
+ output, batch["labels"], reduction="none"
+ )
+
+params = dict(model.named_parameters())
+v = tree_map(lambda x: torch.randn_like(x), params)
+fvp_result = fvp(
+ partial(log_likelihood_per_sample, batch=batch),
+ (params,),
+ (v,)
+)
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
f |
+
+ Callable
+ |
+
+
+
+ A function with tensor output. +Typically this is the per-sample log likelihood of a model. + |
+ + required + | +
primals |
+
+ tuple
+ |
+
+
+
+ Tuple of e.g. tensor or dict with tensor values to evaluate f at. + |
+ + required + | +
tangents |
+
+ tuple
+ |
+
+
+
+ Tuple matching structure of primals. + |
+ + required + | +
has_aux |
+
+ bool
+ |
+
+
+
+ Whether f returns auxiliary information. + |
+
+ False
+ |
+
normalize |
+
+ bool
+ |
+
+
+
+ Whether to normalize, divide by the dimension of the output from f. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Tuple[float, TensorTree] | Tuple[float, TensorTree, Any]
+ |
+
+
+
+ Returns a (output, fvp_out) tuple containing the output of func evaluated at +primals and the empirical Fisher-vector product. If has_aux is True, then +instead returns a (output, fvp_out, aux) tuple. + |
+
posteriors/utils.py
134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 |
|
posteriors.utils.empirical_fisher(f, argnums=0, has_aux=False, normalize=False)
+
+𝞡Constructs function to compute the empirical Fisher information matrix of a function
+f with respect to its parameters, defined as (unnormalized):
+$$
+F(θ) = J_f(θ) J_f(θ)^T
+$$
+where typically \(f_θ\) is the per-sample log likelihood (with elements
+\(\log p(y_i | x_i, θ)\) for a model with primals
\(θ\) given inputs \(x_i\) and
+labels \(y_i\)).
If normalize=True
, then \(F(θ)\) is divided by the number of outputs from f
+(i.e. batchsize).
The empirical Fisher will be provided as a square tensor with respect to the
+ravelled parameters.
+flat_params, params_unravel = optree.tree_ravel(params)
.
Follows API from torch.func.jacrev
.
More info on empirical Fisher matrices can be found in +Martens, 2020.
+ + +Examples:
+import torch
+from posteriors import empirical_fisher, per_samplify
+
+# Load model that outputs logits
+# Load batch = {'inputs': ..., 'labels': ...}
+
+def log_likelihood(params, batch):
+ output = torch.func.functional_call(model, params, batch['inputs'])
+ return -torch.nn.functional.cross_entropy(output, batch['labels'])
+
+log_likelihood_per_sample = per_samplify(log_likelihood)
+params = dict(model.named_parameters())
+ef_result = empirical_fisher(log_likelihood_per_sample)(params, batch)
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
f |
+
+ Callable
+ |
+
+
+
+ A Python function that takes one or more arguments, one of which must be a +Tensor, and returns one or more Tensors. +Typically this is the per-sample log likelihood of a model. + |
+ + required + | +
argnums |
+
+ int | Sequence[int]
+ |
+
+
+
+ Optional, integer or sequence of integers. Specifies which +positional argument(s) to differentiate with respect to. + |
+
+ 0
+ |
+
has_aux |
+
+ bool
+ |
+
+
+
+ Whether f returns auxiliary information. + |
+
+ False
+ |
+
normalize |
+
+ bool
+ |
+
+
+
+ Whether to normalize, divide by the dimension of the output from f. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Callable
+ |
+
+
+
+ A function with the same arguments as f that returns the empirical Fisher, F. +If has_aux is True, then the function instead returns a tuple of (F, aux). + |
+
posteriors/utils.py
213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 |
|
posteriors.utils.ggnvp(forward, loss, primals, tangents, forward_has_aux=False, loss_has_aux=False, normalize=False)
+
+𝞡Generalised Gauss-Newton vector product.
+Equivalent to the (non-empirical) Fisher vector product when loss
is the negative
+log likelihood of an exponential family distribution as a function of its natural
+parameter.
Defined as +$$ +G(θ) = J_f(θ) H_l(z) J_f(θ)^T +$$ +where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) +is a loss function with scalar output.
+Thus \(J_f(θ)\) is the Jacobian of the forward function \(f\) evaluated
+at primals
\(θ\), with dimensions (dz, dθ)
.
+And \(H_l(z)\) is the Hessian of the loss function \(l\) evaluated at z = f(θ)
, with
+dimensions (dz, dz)
.
Follows API from torch.func.jvp
.
More info on Fisher and GGN matrices can be found in +Martens, 2020.
+ + +Examples:
+from functools import partial
+from optree import tree_map
+import torch
+from posteriors import ggnvp
+
+# Load model that outputs logits
+# Load batch = {'inputs': ..., 'labels': ...}
+
+def forward(params, inputs):
+ return torch.func.functional_call(model, params, inputs)
+
+def loss(logits, labels):
+ return torch.nn.functional.cross_entropy(logits, labels)
+
+params = dict(model.named_parameters())
+v = tree_map(lambda x: torch.randn_like(x), params)
+ggnvp_result = ggnvp(
+ partial(forward, inputs=batch['inputs']),
+ partial(loss, labels=batch['labels']),
+ (params,),
+ (v,),
+)
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
forward |
+
+ Callable
+ |
+
+
+
+ A function with tensor output. + |
+ + required + | +
loss |
+
+ Callable
+ |
+
+
+
+ A function that maps the output of forward to a scalar output. + |
+ + required + | +
primals |
+
+ tuple
+ |
+
+
+
+ Tuple of e.g. tensor or dict with tensor values to evaluate f at. + |
+ + required + | +
tangents |
+
+ tuple
+ |
+
+
+
+ Tuple matching structure of primals. + |
+ + required + | +
forward_has_aux |
+
+ bool
+ |
+
+
+
+ Whether forward returns auxiliary information. + |
+
+ False
+ |
+
loss_has_aux |
+
+ bool
+ |
+
+
+
+ Whether loss returns auxiliary information. + |
+
+ False
+ |
+
normalize |
+
+ bool
+ |
+
+
+
+ Whether to normalize, divide by the first dimension of the output +from f. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Tuple[float, TensorTree] | Tuple[float, TensorTree, Any] | Tuple[float, TensorTree, Any, Any]
+ |
+
+
+
+ Returns a (output, ggnvp_out) tuple, where output is a tuple of
+the forward and loss values evaluated at primals. If forward_has_aux or loss_has_aux +is True, the corresponding auxiliary information is appended to the returned tuple. |
+
posteriors/utils.py
297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 |
|
posteriors.utils.ggn(forward, loss, argnums=0, forward_has_aux=False, loss_has_aux=False, normalize=False)
+
+𝞡Constructs function to compute the Generalised Gauss-Newton matrix.
+Equivalent to the (non-empirical) Fisher when loss
is the negative
+log likelihood of an exponential family distribution as a function of its natural
+parameter.
Defined as +$$ +G(θ) = J_f(θ) H_l(z) J_f(θ)^T +$$ +where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) +is a loss function with scalar output.
+Thus \(J_f(θ)\) is the Jacobian of the forward function \(f\) evaluated
+at primals
\(θ\). And \(H_l(z)\) is the Hessian of the loss function \(l\) evaluated
+at z = f(θ)
.
Requires output from forward
to be a tensor and therefore loss
takes a tensor as
+input. Although both support aux
output.
If normalize=True
, then \(G(θ)\) is divided by the size of the leading dimension of
+outputs from forward
(i.e. batchsize).
The GGN will be provided as a square tensor with respect to the
+ravelled parameters.
+flat_params, params_unravel = optree.tree_ravel(params)
.
Follows API from torch.func.jacrev
.
More info on Fisher and GGN matrices can be found in +Martens, 2020.
+ + +Examples:
+from functools import partial
+import torch
+from posteriors import ggn
+
+# Load model that outputs logits
+# Load batch = {'inputs': ..., 'labels': ...}
+
+def forward(params, inputs):
+ return torch.func.functional_call(model, params, inputs)
+
+def loss(logits, labels):
+ return torch.nn.functional.cross_entropy(logits, labels)
+
+params = dict(model.named_parameters())
+ggn_result = ggn(
+ partial(forward, inputs=batch['inputs']),
+ partial(loss, labels=batch['labels']),
+)(params)
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
forward |
+
+ Callable
+ |
+
+
+
+ A function with tensor output. + |
+ + required + | +
loss |
+
+ Callable
+ |
+
+
+
+ A function that maps the output of forward to a scalar output. +Takes a single input and returns a scalar (and possibly aux). + |
+ + required + | +
argnums |
+
+ int | Sequence[int]
+ |
+
+
+
+ Optional, integer or sequence of integers. Specifies which
+positional argument(s) to differentiate |
+
+ 0
+ |
+
forward_has_aux |
+
+ bool
+ |
+
+
+
+ Whether forward returns auxiliary information. + |
+
+ False
+ |
+
loss_has_aux |
+
+ bool
+ |
+
+
+
+ Whether loss returns auxiliary information. + |
+
+ False
+ |
+
normalize |
+
+ bool
+ |
+
+
+
+ Whether to normalize, divide by the first dimension of the output +from f. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Callable
+ |
+
+
+
+ A function with the same arguments as f that returns the tensor GGN. +If has_aux is True, then the function instead returns a tuple of (F, aux). + |
+
posteriors/utils.py
437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 +468 +469 +470 +471 +472 +473 +474 +475 +476 +477 +478 +479 +480 +481 +482 +483 +484 +485 +486 +487 +488 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 +500 +501 +502 +503 +504 +505 +506 +507 +508 +509 +510 +511 +512 +513 +514 +515 +516 +517 +518 +519 +520 +521 +522 +523 +524 +525 +526 +527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 +538 |
|
posteriors.utils.diag_ggn(forward, loss, argnums=0, forward_has_aux=False, loss_has_aux=False, normalize=False)
+
+𝞡Constructs function to compute the diagonal of the Generalised Gauss-Newton matrix.
+Equivalent to the (non-empirical) diagonal Fisher when loss
is the negative
+log likelihood of an exponential family distribution as a function of its natural
+parameter.
The GGN is defined as +$$ +G(θ) = J_f(θ) H_l(z) J_f(θ)^T +$$ +where \(z = f(θ)\) is the output of the forward function \(f\) and \(l(z)\) +is a loss function with scalar output.
+Thus \(J_f(θ)\) is the Jacobian of the forward function \(f\) evaluated
+at primals
\(θ\). And \(H_l(z)\) is the Hessian of the loss function \(l\) evaluated
+at z = f(θ)
.
Requires output from forward
to be a tensor and therefore loss
takes a tensor as
+input. Although both support aux
output.
If normalize=True
, then \(G(θ)\) is divided by the size of the leading dimension of
+outputs from forward
(i.e. batchsize).
Unlike posteriors.ggn
, the output will be in PyTree form matching the input.
Follows API from torch.func.jacrev
.
More info on Fisher and GGN matrices can be found in +Martens, 2020.
+ + +Examples:
+from functools import partial
+import torch
+from posteriors import diag_ggn
+
+# Load model that outputs logits
+# Load batch = {'inputs': ..., 'labels': ...}
+
+def forward(params, inputs):
+ return torch.func.functional_call(model, params, inputs)
+
+def loss(logits, labels):
+ return torch.nn.functional.cross_entropy(logits, labels)
+
+params = dict(model.named_parameters())
+ggndiag_result = diag_ggn(
+ partial(forward, inputs=batch['inputs']),
+ partial(loss, labels=batch['labels']),
+)(params)
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
forward |
+
+ Callable
+ |
+
+
+
+ A function with tensor output. + |
+ + required + | +
loss |
+
+ Callable
+ |
+
+
+
+ A function that maps the output of forward to a scalar output. +Takes a single input and returns a scalar (and possibly aux). + |
+ + required + | +
argnums |
+
+ int | Sequence[int]
+ |
+
+
+
+ Optional, integer or sequence of integers. Specifies which
+positional argument(s) to differentiate |
+
+ 0
+ |
+
forward_has_aux |
+
+ bool
+ |
+
+
+
+ Whether forward returns auxiliary information. + |
+
+ False
+ |
+
loss_has_aux |
+
+ bool
+ |
+
+
+
+ Whether loss returns auxiliary information. + |
+
+ False
+ |
+
normalize |
+
+ bool
+ |
+
+
+
+ Whether to normalize, divide by the first dimension of the output +from f. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Callable
+ |
+
+
+
+ A function with the same arguments as f that returns the diagonal GGN. +If has_aux is True, then the function instead returns a tuple of (F, aux). + |
+
posteriors/utils.py
541 +542 +543 +544 +545 +546 +547 +548 +549 +550 +551 +552 +553 +554 +555 +556 +557 +558 +559 +560 +561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +583 +584 +585 +586 +587 +588 +589 +590 +591 +592 +593 +594 +595 +596 +597 +598 +599 +600 +601 +602 +603 +604 +605 +606 +607 +608 +609 +610 +611 +612 +613 +614 +615 +616 +617 +618 +619 +620 +621 +622 +623 +624 +625 +626 +627 +628 +629 +630 +631 +632 +633 +634 +635 +636 +637 +638 +639 +640 +641 +642 +643 |
|
posteriors.utils.cg(A, b, x0=None, *, maxiter=None, damping=0.0, tol=1e-05, atol=0.0, M=_identity)
+
+𝞡Use Conjugate Gradient iteration to solve Ax = b
.
+A
is supplied as a function instead of a matrix.
Adapted from jax.scipy.sparse.linalg.cg
.
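+A small self-contained sketch with a dense symmetric positive definite matrix (purely illustrative):
+import torch
+from posteriors.utils import cg
+
+A_mat = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
+b = torch.tensor([1.0, 2.0])
+x, info = cg(lambda v: A_mat @ v, b, tol=1e-8)
+# x approximately solves A_mat @ x = b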
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
A |
+
+ Callable
+ |
+
+
+
+ Callable that calculates the linear map (matrix-vector
+product) Ax when called like A(x). |
+ + required + | +
b |
+
+ TensorTree
+ |
+
+
+
+ Right hand side of the linear system representing a single vector. + |
+ + required + | +
x0 |
+
+ TensorTree
+ |
+
+
+
+ Starting guess for the solution. Must have the same structure as b. |
+
+ None
+ |
+
maxiter |
+
+ int
+ |
+
+
+
+ Maximum number of iterations. Iteration will stop after maxiter +steps even if the specified tolerance has not been achieved. + |
+
+ None
+ |
+
damping |
+
+ float
+ |
+
+
+
+ damping term for the mvp function. Acts as regularization. + |
+
+ 0.0
+ |
+
tol |
+
+ float
+ |
+
+
+
+ Tolerance for convergence. + |
+
+ 1e-05
+ |
+
atol |
+
+ float
+ |
+
+
+
+ Tolerance for convergence. |
+
+ 0.0
+ |
+
M |
+
+ Callable
+ |
+
+
+
+ Preconditioner for A. +See the preconditioned CG method. + |
+
+ _identity
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
x |
+ TensorTree
+ |
+
+
+
+ The converged solution. Has the same structure as b. |
+
info |
+ Any
+ |
+
+
+
+ Placeholder for convergence information. + |
+
posteriors/utils.py
685 +686 +687 +688 +689 +690 +691 +692 +693 +694 +695 +696 +697 +698 +699 +700 +701 +702 +703 +704 +705 +706 +707 +708 +709 +710 +711 +712 +713 +714 +715 +716 +717 +718 +719 +720 +721 +722 +723 +724 +725 +726 +727 +728 +729 +730 +731 +732 +733 +734 +735 +736 +737 +738 +739 +740 +741 +742 +743 +744 +745 +746 +747 +748 +749 +750 +751 +752 +753 +754 +755 +756 +757 +758 +759 +760 +761 +762 +763 +764 +765 +766 +767 +768 +769 +770 +771 +772 +773 |
|
posteriors.utils.diag_normal_log_prob(x, mean=0.0, sd_diag=1.0, normalize=True)
+
+𝞡Evaluate multivariate normal log probability for a diagonal covariance matrix.
+If either mean or sd_diag are scalars, they will be broadcast to the same shape as x +(in a memory efficient manner).
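+For example (the dict below is just an illustrative TensorTree):
+import torch
+from posteriors.utils import diag_normal_log_prob
+
+x = {"a": torch.zeros(3), "b": torch.zeros(2)}
+log_prob = diag_normal_log_prob(x, mean=0.0, sd_diag=1.0)  # standard Normal log density at zero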
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
x |
+
+ TensorTree
+ |
+
+
+
+ Value to evaluate log probability at. + |
+ + required + | +
mean |
+
+ float | TensorTree
+ |
+
+
+
+ Mean of the distribution. + |
+
+ 0.0
+ |
+
sd_diag |
+
+ float | TensorTree
+ |
+
+
+
+ Square-root diagonal of the covariance matrix. + |
+
+ 1.0
+ |
+
normalize |
+
+ bool
+ |
+
+
+
+ Whether to compute normalized log probability. +If False the elementwise log prob is -0.5 * ((x - mean) / sd_diag)**2. + |
+
+ True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ float
+ |
+
+
+
+ Scalar log probability. + |
+
posteriors/utils.py
776 +777 +778 +779 +780 +781 +782 +783 +784 +785 +786 +787 +788 +789 +790 +791 +792 +793 +794 +795 +796 +797 +798 +799 +800 +801 +802 +803 +804 +805 +806 +807 +808 +809 +810 +811 +812 +813 +814 +815 +816 +817 +818 |
|
posteriors.utils.diag_normal_sample(mean, sd_diag, sample_shape=torch.Size([]))
+
+𝞡Sample from multivariate normal with diagonal covariance matrix.
+If sd_diag is scalar, it will be broadcast to the same shape as mean +(in a memory efficient manner).
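+For example (illustrative only):
+import torch
+from posteriors.utils import diag_normal_sample
+
+mean = {"a": torch.zeros(3), "b": torch.zeros(2)}
+samples = diag_normal_sample(mean, 0.1, sample_shape=torch.Size([10]))
+# each leaf gains a leading dimension of 10 from sample_shape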
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
mean |
+
+ TensorTree
+ |
+
+
+
+ Mean of the distribution. + |
+ + required + | +
sd_diag |
+
+ float | TensorTree
+ |
+
+
+
+ Square-root diagonal of the covariance matrix. + |
+ + required + | +
sample_shape |
+
+ Size
+ |
+
+
+
+ Shape of the sample. + |
+
+ Size([])
+ |
+
Returns:
+Type | +Description | +
---|---|
+ dict
+ |
+
+
+
+ Sample(s) from normal distribution with the same structure as mean and sd_diag. + |
+
posteriors/utils.py
821 +822 +823 +824 +825 +826 +827 +828 +829 +830 +831 +832 +833 +834 +835 +836 +837 +838 +839 +840 +841 +842 +843 +844 +845 +846 |
|
posteriors.utils.per_samplify(f)
+
+𝞡Converts a function that takes params and batch into one that provides an output +for each batch sample.
+output = f(params, batch)
+per_sample_output = per_samplify(f)(params, batch)
+
For more info see per_sample_grads.html
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
f |
+
+ Callable[[TensorTree, TensorTree], Any]
+ |
+
+
+
+ A function that takes params and batch provides an output with size +independent of batchsize (i.e. averaged). + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Callable[[TensorTree, TensorTree], Any]
+ |
+
+
+
+ A new function that provides an output for each batch sample.
+per_sample_output = per_samplify(f)(params, batch) |
+
posteriors/utils.py
849 +850 +851 +852 +853 +854 +855 +856 +857 +858 +859 +860 +861 +862 +863 +864 +865 +866 +867 +868 +869 +870 +871 +872 +873 +874 +875 +876 +877 +878 +879 +880 |
|
posteriors.utils.is_scalar(x)
+
+𝞡Returns True if x is a scalar (int, float, bool) or a tensor with a single element.
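+For example:
+import torch
+from posteriors.utils import is_scalar
+
+is_scalar(1.0)                  # True
+is_scalar(torch.tensor([2.0]))  # True, single element tensor
+is_scalar(torch.ones(3))        # False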
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
x |
+
+ Any
+ |
+
+
+
+ Any object. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ bool
+ |
+
+
+
+ True if x is a scalar. + |
+
posteriors/utils.py
883 +884 +885 +886 +887 +888 +889 +890 +891 +892 |
|
posteriors.vi.diag.build(log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, init_log_sds=0.0)
+
+𝞡Builds a transform for variational inference with a diagonal Normal +distribution over parameters.
+Find \(\mu\) and diagonal \(\Sigma\) that minimize \(\text{KL}(N(θ| \mu, \Sigma) || p_T(θ))\) +where \(p_T(θ) \propto \exp( \log p(θ) / T)\) with temperature \(T\).
+The log posterior and temperature are recommended to be constructed in tandem +to ensure robust scaling for a large amount of data.
+For more information on variational inference see Blei et al, 2017.
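+A hedged sketch of the intended workflow, assuming log_posterior, params, num_data and dataloader already exist (temperature=1/num_data is just one illustrative choice):
+import torch
+import torchopt
+from posteriors import vi
+
+transform = vi.diag.build(
+    log_posterior, torchopt.adam(lr=1e-3), temperature=1 / num_data
+)
+state = transform.init(params)
+for batch in dataloader:
+    state = transform.update(state, batch, inplace=True)
+samples = vi.diag.sample(state, torch.Size([100]))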
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
log_posterior |
+
+ Callable[[TensorTree, Any], float]
+ |
+
+
+
+ Function that takes parameters and input batch and +returns the log posterior (which can be unnormalised). + |
+ + required + | +
optimizer |
+
+ GradientTransformation
+ |
+
+
+
+ TorchOpt functional optimizer for updating the variational +parameters. Make sure to use lower case like torchopt.adam() + |
+ + required + | +
temperature |
+
+ float
+ |
+
+
+
+ Temperature to rescale (divide) log_posterior. + |
+
+ 1.0
+ |
+
n_samples |
+
+ int
+ |
+
+
+
+ Number of samples to use for Monte Carlo estimate. + |
+
+ 1
+ |
+
stl |
+
+ bool
+ |
+
+
+
+ Whether to use the stick-the-landing estimator +from [Roeder et al](https://arxiv.org/abs/1703.09194). + |
+
+ True
+ |
+
init_log_sds |
+
+ TensorTree | float
+ |
+
+
+
+ Initial log of the square-root diagonal of the covariance matrix +of the variational distribution. Can be a tree matching params or scalar. + |
+
+ 0.0
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Transform
+ |
+
+
+
+ Diagonal VI transform instance. + |
+
posteriors/vi/diag.py
18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 |
|
posteriors.vi.diag.VIDiagState
+
+
+𝞡
+ Bases: NamedTuple
State encoding a diagonal Normal variational distribution over parameters.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
params |
+
+ TensorTree
+ |
+
+
+
+ Mean of the variational distribution. + |
+
log_sd_diag |
+
+ TensorTree
+ |
+
+
+
+ Log of the square-root diagonal of the covariance matrix of the +variational distribution. + |
+
opt_state |
+
+ OptState
+ |
+
+
+
+ TorchOpt state storing optimizer data for updating the +variational parameters. + |
+
nelbo |
+
+ tensor
+ |
+
+
+
+ Negative evidence lower bound (lower is better). + |
+
aux |
+
+ Any
+ |
+
+
+
+ Auxiliary information from the log_posterior call. + |
+
posteriors/vi/diag.py
64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 |
|
posteriors.vi.diag.init(params, optimizer, init_log_sds=0.0)
+
+𝞡Initialise diagonal Normal variational distribution over parameters.
+optimizer.init will be called on flattened variational parameters so hyperparameters +such as learning rate need to be pre-specified through TorchOpt's functional API:
+import torchopt
+
+optimizer = torchopt.adam(lr=1e-2)
+vi_state = init(init_mean, optimizer)
+
It's assumed maximize=False for the optimizer, so that we minimize the NELBO.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
params |
+
+ TensorTree
+ |
+
+
+
+ Initial mean of the variational distribution. + |
+ + required + | +
optimizer |
+
+ GradientTransformation
+ |
+
+
+
+ TorchOpt functional optimizer for updating the variational +parameters. Make sure to use lower case like torchopt.adam() + |
+ + required + | +
init_log_sds |
+
+ TensorTree | float
+ |
+
+
+
+ Initial log of the square-root diagonal of the covariance matrix +of the variational distribution. Can be a tree matching params or scalar. + |
+
+ 0.0
+ |
+
Returns:
+Type | +Description | +
---|---|
+ VIDiagState
+ |
+
+
+
+ Initial DiagVIState. + |
+
posteriors/vi/diag.py
84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 |
|
posteriors.vi.diag.update(state, batch, log_posterior, optimizer, temperature=1.0, n_samples=1, stl=True, inplace=False)
+
+𝞡Updates the variational parameters to minimize the NELBO.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
state |
+
+ VIDiagState
+ |
+
+
+
+ Current state. + |
+ + required + | +
batch |
+
+ Any
+ |
+
+
+
+ Input data to log_posterior. + |
+ + required + | +
log_posterior |
+
+ LogProbFn
+ |
+
+
+
+ Function that takes parameters and input batch and +returns the log posterior (which can be unnormalised). + |
+ + required + | +
optimizer |
+
+ GradientTransformation
+ |
+
+
+
+ TorchOpt functional optimizer for updating the variational +parameters. Make sure to use lower case like torchopt.adam() + |
+ + required + | +
temperature |
+
+ float
+ |
+
+
+
+ Temperature to rescale (divide) log_posterior. + |
+
+ 1.0
+ |
+
n_samples |
+
+ int
+ |
+
+
+
+ Number of samples to use for Monte Carlo estimate. + |
+
+ 1
+ |
+
stl |
+
+ bool
+ |
+
+
+
+ Whether to use the stick-the-landing estimator +from [Roeder et al](https://arxiv.org/abs/1703.09194). + |
+
+ True
+ |
+
inplace |
+
+ bool
+ |
+
+
+
+ Whether to modify state in place. + |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ VIDiagState
+ |
+
+
+
+ Updated DiagVIState. + |
+
posteriors/vi/diag.py
123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 |
|
posteriors.vi.diag.nelbo(mean, sd_diag, batch, log_posterior, temperature=1.0, n_samples=1, stl=True)
+
+𝞡Returns the negative evidence lower bound (NELBO) for a diagonal Normal +variational distribution over the parameters of a model.
+Monte Carlo estimate with n_samples
from q.
+$$
+\text{NELBO} = - 𝔼_{q(θ)}[\log p(y|x, θ) + \log p(θ) - T \log q(θ)]
+$$
+for temperature \(T\).
log_posterior
expects to take parameters and input batch and return a scalar
+as well as a TensorTree of any auxiliary information:
log_posterior_eval, aux = log_posterior(params, batch)
+
The log posterior and temperature are recommended to be constructed in tandem +to ensure robust scaling for a large amount of data and variable batch size.
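+A hedged sketch of a direct call, assuming mean, sd_diag, batch and log_posterior already exist with mean and sd_diag matching the parameter tree structure:
+from posteriors import vi
+
+nelbo_value, aux = vi.diag.nelbo(mean, sd_diag, batch, log_posterior, n_samples=5)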
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
mean |
+
+ dict
+ |
+
+
+
+ Mean of the variational distribution. + |
+ + required + | +
sd_diag |
+
+ dict
+ |
+
+
+
+ Square-root diagonal of the covariance matrix of the +variational distribution. + |
+ + required + | +
batch |
+
+ Any
+ |
+
+
+
+ Input data to log_posterior. + |
+ + required + | +
log_posterior |
+
+ LogProbFn
+ |
+
+
+
+ Function that takes parameters and input batch and +returns the log posterior (which can be unnormalised). + |
+ + required + | +
temperature |
+
+ float
+ |
+
+
+
+ Temperature to rescale (divide) log_posterior. + |
+
+ 1.0
+ |
+
n_samples |
+
+ int
+ |
+
+
+
+ Number of samples to use for Monte Carlo estimate. + |
+
+ 1
+ |
+
stl |
+
+ bool
+ |
+
+
+
+ Whether to use the stick-the-landing estimator +from [Roeder et al](https://arxiv.org/abs/1703.09194). + |
+
+ True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Tuple[float, Any]
+ |
+
+
+
+ The sampled approximate NELBO averaged over the batch, along with auxiliary +information from the log_posterior call. + |
+
posteriors/vi/diag.py
178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 |
|
posteriors.vi.diag.sample(state, sample_shape=torch.Size([]))
+
+𝞡Sample(s) from the diagonal Normal distribution over parameters.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
state |
+
+ VIDiagState
+ |
+
+
+
+ State encoding mean and log standard deviations. + |
+ + required + | +
sample_shape |
+
+ Size
+ |
+
+
+
+ Shape of the desired samples. + |
+
+ Size([])
+ |
+
Returns:
+Type | +Description | +
---|---|
+ TensorTree
+ |
+
+
+
+ Sample(s) from Normal distribution. + |
+
posteriors/vi/diag.py
231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 |
|