Remove Theano from Week 5 #348
Merged
Changes from 1 commit (3 commits in the pull request)
@@ -2,260 +2,80 @@

Unchanged imports at the top of the file:

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

Removed (the previous Bayes-by-Backprop implementation):

class BayesianModule(nn.Module):
    """
    creates base class for BNN, in order to enable specific behavior
    """
    def __init__(self):
        super().__init__()


class GaussianVariational(nn.Module):
    #Samples weights for variational inference as in Weights Uncertainity on Neural Networks (Bayes by backprop paper)
    #Calculates the variational posterior part of the complexity part of the loss
    def __init__(self, mu, rho):
        super().__init__()

        self.mu = nn.Parameter(mu)
        self.rho = nn.Parameter(rho)
        self.w = None
        self.sigma = None
        self.pi = np.pi
        self.normal = torch.distributions.Normal(0, 1)

    def sample(self):
        """
        Samples weights by sampling form a Normal distribution, multiplying by a sigma, which is
        a function from a trainable parameter, and adding a mean
        sets those weights as the current ones
        returns:
            torch.tensor with same shape as self.mu and self.rho
        """
        device = self.mu.device
        epsilon = self.normal.sample(self.mu.size()).to(device)
        self.sigma = torch.log(1 + torch.exp(self.rho)).to(device)
        self.w = self.mu + self.sigma * epsilon
        return self.w

    def log_posterior(self):
        """
        Calculates the log_likelihood for each of the weights sampled as a part of the complexity cost
        returns:
            torch.tensor with shape []
        """
        assert (self.w is not None), "You can only have a log posterior for W if you've already sampled it"

        log_sqrt2pi = np.log(np.sqrt(2*self.pi))
        log_posteriors = -log_sqrt2pi - self.sigma - (((self.w - self.mu) ** 2)/(2 * self.sigma ** 2))
        return log_posteriors.mean()


class ScaleMixturePrior(nn.Module):
    #Calculates a Scale Mixture Prior distribution for the prior part of the complexity cost on Bayes by Backprop paper
    def __init__(self, pi, sigma1, sigma2):
        super().__init__()

        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.normal1 = torch.distributions.Normal(0, sigma1)
        self.normal2 = torch.distributions.Normal(0, sigma2)

    def log_prior(self, w):
        """
        Calculates the log_likelihood for each of the weights sampled relative to a prior distribution as a part of the complexity cost
        returns:
            torch.tensor with shape []
        """
        prob_n1 = torch.exp(self.normal1.log_prob(w))
        prob_n2 = torch.exp(self.normal2.log_prob(w))
        prior_pdf = (self.pi * prob_n1 + (1 - self.pi) * prob_n2)

        return (torch.log(prior_pdf)).mean()


class BayesianLinear(BayesianModule):
    """
    Bayesian Linear layer, implements the linear layer proposed on Weight Uncertainity on Neural Networks
    (Bayes by Backprop paper).
    Its objective is be interactable with torch nn.Module API, being able even to be chained in nn.Sequential models with other non-this-lib layers

    parameters:
        in_fetaures: int -> incoming features for the layer
        out_features: int -> output features for the layer
        bias: bool -> whether the bias will exist (True) or set to zero (False)
        prior_sigma_1: float -> prior sigma on the mixture prior distribution 1
        prior_sigma_2: float -> prior sigma on the mixture prior distribution 2
        prior_pi: float -> pi on the scaled mixture prior
        freeze: bool -> wheter the model will start with frozen(deterministic) weights, or not
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 prior_sigma_1 = 1,
                 prior_sigma_2 = 0.002,
                 prior_pi = 0.5,
                 freeze = False):
        super().__init__()

        #our main parameters
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.freeze = freeze

        #parameters for the scale mixture prior
        self.prior_sigma_1 = prior_sigma_1
        self.prior_sigma_2 = prior_sigma_2
        self.prior_pi = prior_pi

        # Variational weight parameters and sample
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-5, -4))
        self.weight_sampler = GaussianVariational(self.weight_mu, self.weight_rho)

        # Variational bias parameters and sample
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-5, -4))
        self.bias_sampler = GaussianVariational(self.bias_mu, self.bias_rho)

        # Priors (as BBP paper)
        self.weight_prior_dist = ScaleMixturePrior(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2)
        self.bias_prior_dist = ScaleMixturePrior(self.prior_pi, self.prior_sigma_1, self.prior_sigma_2)
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        # Sample the weights and forward it

        #if the model is frozen, return frozen
        if self.freeze:
            return self.forward_frozen(x)

        w = self.weight_sampler.sample()

        if self.bias:
            b = self.bias_sampler.sample()
            b_log_posterior = self.bias_sampler.log_posterior()
            b_log_prior = self.bias_prior_dist.log_prior(b)
        else:
            b = torch.zeros((self.out_features))
            b_log_posterior = 0
            b_log_prior = 0

        # Get the complexity cost
        self.log_variational_posterior = self.weight_sampler.log_posterior() + b_log_posterior
        self.log_prior = self.weight_prior_dist.log_prior(w) + b_log_prior

        return F.linear(x, w, b)

    def forward_frozen(self, x):
        """
        Computes the feedforward operation with the expected value for weight and biases
        """
        if self.bias:
            return F.linear(x, self.weight_mu, self.bias_mu)
        else:
            return F.linear(x, self.weight_mu, torch.zeros(self.out_features))


def kl_divergence_from_nn(model):
    """
    Gathers the KL Divergence from a nn.Module object
    Works by gathering each Bayesian layer kl divergence and summing it, doing nothing with the non Bayesian ones
    """
    kl_divergence = 0
    for module in model.modules():
        if isinstance(module, (BayesianModule)):
            kl_divergence += module.log_variational_posterior - module.log_prior
    return kl_divergence


def variational_estimator(nn_class):
    """
    This decorator adds some util methods to a nn.Module, in order to facilitate the handling of Bayesian Deep Learning features
    Parameters:
        nn_class: torch.nn.Module -> Torch neural network module
    Returns a nn.Module with methods for:
        (1) Gathering the KL Divergence along its BayesianModules;
        (2) Sample the Elbo Loss along its variational inferences (helps training)
        (3) Freeze the model, in order to predict using only their weight distribution means
    """

    def nn_kl_divergence(self):
        """Returns the sum of the KL divergence of each of the BayesianModules of the model, which are from
        their posterior current distribution of weights relative to a scale-mixtured prior (and simpler) distribution of weights
        Parameters:
            N/a
        Returns torch.tensor with 0 dim.
        """
        return kl_divergence_from_nn(self)

    setattr(nn_class, "nn_kl_divergence", nn_kl_divergence)

    def sample_elbo(self,
                    inputs,
                    labels,
                    criterion,
                    sample_nbr,
                    complexity_cost_weight=1):
        """ Samples the ELBO Loss for a batch of data, consisting of inputs and corresponding-by-index labels
            The ELBO Loss consists of the sum of the KL Divergence of the model
            (explained above, interpreted as a "complexity part" of the loss)
            with the actual criterion - (loss function) of optimization of our model
            (the performance part of the loss).
            As we are using variational inference, it takes several (quantified by the parameter sample_nbr) Monte-Carlo
            samples of the weights in order to gather a better approximation for the loss.
            Parameters:
                inputs: torch.tensor -> the input data to the model
                labels: torch.tensor -> label data for the performance-part of the loss calculation
                    The shape of the labels must match the label-parameter shape of the criterion (one hot encoded or as index, if needed)
                criterion: torch.nn.Module, custom criterion (loss) function, torch.nn.functional function -> criterion to gather
                    the performance cost for the model
                sample_nbr: int -> The number of times of the weight-sampling and predictions done in our Monte-Carlo approach to
                    gather the loss to be .backwarded in the optimization of the model.
        """
        loss = 0
        for _ in range(sample_nbr):
            outputs = self(inputs)
            loss = criterion(outputs, labels)
            loss += self.nn_kl_divergence() * complexity_cost_weight
        return loss / sample_nbr

    setattr(nn_class, "sample_elbo", sample_elbo)

    def freeze_model(self):
        """
        Freezes the model by making it predict using only the expected value to their BayesianModules' weights distributions
        """
        for module in self.modules():
            if isinstance(module, (BayesianModule)):
                module.freeze = True

    setattr(nn_class, "freeze", freeze_model)

    def unfreeze_model(self):
        """
        Unfreezes the model by letting it draw its weights with uncertanity from their correspondent distributions
        """
        for module in self.modules():
            if isinstance(module, (BayesianModule)):
                module.freeze = False

    setattr(nn_class, "unfreeze", unfreeze_model)
    return nn_class

Added (the new implementation, using the local reparameterization trick):

from torch.nn import Parameter
import math


def calculate_kl(log_alpha):
    return 0.5 * torch.sum(torch.log1p(torch.exp(-log_alpha)))


class ModuleWrapper(nn.Module):
    """Wrapper for nn.Module with support for arbitrary flags and a universal forward pass"""

    def __init__(self):
        super(ModuleWrapper, self).__init__()

    def set_flag(self, flag_name, value):
        setattr(self, flag_name, value)
        for m in self.children():
            if hasattr(m, 'set_flag'):
                m.set_flag(flag_name, value)

    def forward(self, x):
        for module in self.children():
            x = module(x)

        kl = 0.0
        for module in self.modules():
            if hasattr(module, 'kl_loss'):
                kl = kl + module.kl_loss()

        return x, kl


class BBBLinear(ModuleWrapper):

    def __init__(self, in_features, out_features, alpha_shape=(1, 1), bias=True, name='BBBLinear'):
        super(BBBLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha_shape = alpha_shape
        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.log_alpha = Parameter(torch.Tensor(*alpha_shape))
        if bias:
            self.bias = Parameter(torch.Tensor(1, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
        self.kl_value = calculate_kl
        self.name = name

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(1))
        self.W.data.uniform_(-stdv, stdv)
        self.log_alpha.data.fill_(-5.0)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x):
        mean = F.linear(x, self.W)
        if self.bias is not None:
            mean = mean + self.bias

        sigma = torch.exp(self.log_alpha) * self.W * self.W

        std = torch.sqrt(1e-16 + F.linear(x * x, sigma))
        if self.training:
            epsilon = std.data.new(std.size()).normal_()
        else:
            epsilon = 0.0
        # Local reparameterization trick
        out = mean + std * epsilon

        return out

    def kl_loss(self):
        return self.W.nelement() * self.kl_value(self.log_alpha) / self.log_alpha.nelement()

Inline review comment (left twice on the new code): redundant new line
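For orientation, here is a minimal sketch (not part of this PR) of how the new classes fit together. The module name bayes_layers, the layer sizes, and the dummy batch are illustrative assumptions. ModuleWrapper.forward applies the registered children in order and returns the output together with the summed kl_loss() of every layer that defines one; per entry of log_alpha, calculate_kl charges 0.5 * log(1 + exp(-log_alpha)) = 0.5 * log(1 + 1/alpha), and kl_loss rescales that to the number of weights in the layer.

# Sketch only: assumes the file above is importable as `bayes_layers` (illustrative name).
import torch
import torch.nn as nn

from bayes_layers import ModuleWrapper, BBBLinear


class BayesMLP(ModuleWrapper):
    """Two-layer MLP built from BBBLinear layers; children are applied in registration order."""

    def __init__(self, in_features=784, hidden=256, n_classes=10):
        super(BayesMLP, self).__init__()
        self.fc1 = BBBLinear(in_features, hidden)
        self.act = nn.ReLU()
        self.fc2 = BBBLinear(hidden, n_classes)


model = BayesMLP()
x = torch.randn(32, 784)        # dummy batch, shapes are illustrative
logits, kl = model(x)           # ModuleWrapper.forward returns (output, summed KL)
print(logits.shape, float(kl))  # torch.Size([32, 10]) and a scalar KL value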
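The PR also drops the old variational_estimator decorator and its sample_elbo helper. A rough equivalent on top of the new (output, kl) interface could look like the sketch below; the helper name, sample count, and KL weight are assumptions for illustration (note that the removed sample_elbo appears to overwrite, rather than accumulate, the criterion term inside its loop, which this sketch avoids). It continues the BayesMLP example above.

# Sketch only: a multi-sample ELBO step in the spirit of the removed sample_elbo,
# expressed against the new (output, kl) API.
import torch
import torch.nn.functional as F


def sample_elbo(model, inputs, labels, sample_nbr=3, kl_weight=1.0):
    """Average criterion + weighted KL over several stochastic forward passes."""
    loss = 0.0
    for _ in range(sample_nbr):
        outputs, kl = model(inputs)  # activation noise is re-drawn on each pass
        loss = loss + F.cross_entropy(outputs, labels) + kl_weight * kl
    return loss / sample_nbr


optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
labels = torch.randint(0, 10, (32,))  # dummy targets matching the batch above

optimizer.zero_grad()
loss = sample_elbo(model, x, labels, sample_nbr=3, kl_weight=1e-2)  # kl_weight is illustrative
loss.backward()
optimizer.step()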
Review comment: Could you write a detailed docstring here?