Skip to content

Commit

Permalink
Gaussian wrapper added
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanTomilov1 committed Jan 8, 2024
1 parent 9147a61 commit 0462fdb
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 23 deletions.
2 changes: 1 addition & 1 deletion eXNN/bayes/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .api import DropoutBayesianWrapper
from .api import DropoutBayesianWrapper, DropoutGaussianWrapper
32 changes: 32 additions & 0 deletions eXNN/bayes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,35 @@ def predict(self, data, n_iter) -> Dict[str, torch.Tensor]:
"""
res = self.model.mean_forward(data, n_iter)
return {"mean": res[0], "std": res[1]}


class DropoutGaussianWrapper:
def __init__(
self,
model: torch.nn.Module,
sigma: float
):
"""Class representing bayesian equivalent of a neural network.
Args:
model (torch.nn.Module): neural network
sigma (float): std of parameters gaussian noise
"""
self.model = create_dropout_bayesian_wrapper(model, "gaussian", sigma = sigma)

def predict(self, data, n_iter) -> Dict[str, torch.Tensor]:
"""Function computes mean and standard deviation of bayesian equivalent
of a neural network.
Args:
data (_type_): input data of shape NxC1x...xCk,
where N is the number of data points,
C1,...,Ck are dimensions of each data point
n_iter (_type_): number of samplings form the bayesian equivalent
of a neural network
Returns:
Dict[str, torch.Tensor]: dictionary with `mean` and `std` of prediction
"""
res = self.model.mean_forward(data, n_iter)
return {"mean": res[0], "std": res[1]}
92 changes: 71 additions & 21 deletions eXNN/bayes/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,51 @@ def __init__(
layer: nn.Module,
p: Optional[float] = None,
a: Optional[float] = None,
b: Optional[float] = None
b: Optional[float] = None,
sigma: Optional[float] = None
):
super(ModuleBayesianWrapper, self).__init__()

pab_check = "You can either specify p (simple dropout), or specify a and b (beta dropout)"
assert (p is not None) != ((a is not None) and (b is not None)), pab_check
# Variables correctness checks
pab_check = "You can either specify the following options (exclusively):\n - p (simple dropout)\n - a and b (beta dropout)\n - sigma (gaussian dropout)"
assert (p is not None and a is None and b is None and sigma is None) or \
(p is None and a is not None and b is not None and sigma is None) or \
(p is None and a is None and b is None and sigma is not None), pab_check

if (p is None) and (sigma is None):
ab_check = "If you choose to specify a and b, you must to specify both"
assert (self.a is not None) and (self.b is not None), ab_check

if not type(layer) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]:
# At the moment we are only modifying linear and convolutional layers
self.layer = layer
else:
self.layer = layer
# At the moment we are only modifying linear and convolutional layers, so check this
self.layer = layer

self.p = p
self.a, self.b = a, b
if type(layer) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]:
self.p, self.a, self.b, self.sigma = p, a, b, sigma

if self.p is None:
ab_check = "If you choose to specify a and b, you must to specify both"
assert (self.a is not None) and (self.b is not None), ab_check

def dropout_weights(self, weights, bias):
if self.p is not None:
p = self.p
else:
p = Beta(torch.tensor(self.a), torch.tensor(self.b)).sample()
def augment_weights(self, weights, bias):

weights = F.dropout(weights, p, training=True)
bias = F.dropout(bias, p, training=True)
# Check if dropout is chosen
if (self.p is not None) or (self.a is not None and self.b is not None):
# Select correct option and apply dropout
if self.p is not None:
p = self.p
else:
p = Beta(torch.tensor(self.a), torch.tensor(self.b)).sample()

weights = F.dropout(weights, p, training=True)
bias = F.dropout(bias, p, training=True)

else:
# If gauss is chosen, then apply it
weights = weights + (torch.randn(*weights.shape)*self.sigma).to(weights.device())
bias = bias + (torch.randn(*bias.shape)*self.sigma).to(bias.device())

return weights, bias

def forward(self, x):

weight, bias = self.dropout_weights(self.layer.weight, self.layer.bias)
weight, bias = self.augment_weights(self.layer.weight, self.layer.bias)

if isinstance(self.layer, nn.Linear):
return F.linear(x, weight, bias)
Expand Down Expand Up @@ -141,6 +152,41 @@ def mean_forward(
dim=0,
)
return results


class NetworkBayesGauss(nn.Module):
def __init__(
self,
model: torch.nn.Module,
sigma: float
):

super(NetworkBayesGauss, self).__init__()
self.model = copy.deepcopy(model)
self.model = replace_modules_with_wrapper(self.model,
ModuleBayesianWrapper,
{"sigma": sigma})

def mean_forward(
self,
data: torch.Tensor,
n_iter: int,
):

results = []
for _ in range(n_iter):
results.append(self.model.forward(data))

results = torch.stack(results, dim=1)

results = torch.stack(
[
torch.mean(results, dim=1),
torch.std(results, dim=1),
],
dim=0,
)
return results


def create_dropout_bayesian_wrapper(
Expand All @@ -149,11 +195,15 @@ def create_dropout_bayesian_wrapper(
p: Optional[float] = None,
a: Optional[float] = None,
b: Optional[float] = None,
sigma: Optional[float] = None
) -> torch.nn.Module:
if mode == "basic":
net = NetworkBayes(model, p)

elif mode == "beta":
net = NetworkBayesBeta(model, a, b)

elif mode == 'gauss':
net = NetworkBayesGauss(model, sigma)

return net
4 changes: 3 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_visualization():


def _test_bayes_prediction(mode: str):
params = {"basic": dict(mode="basic", p=0.5), "beta": dict(mode="beta", a=0.9, b=0.2)}
params = {"basic": dict(mode="basic", p=0.5), "beta": dict(mode="beta", a=0.9, b=0.2), "gauss": dict(mode="gauss", sigma = 1e-2)}

N, dim, data = utils.create_testing_data()
model = utils.create_testing_model()
Expand All @@ -73,6 +73,8 @@ def test_basic_bayes_wrapper():
def test_beta_bayes_wrapper():
_test_bayes_prediction("beta")

def test_gauss_bayes_wrapper():
_test_bayes_prediction("gauss")

def test_data_barcode():
N, dim, data = utils.create_testing_data()
Expand Down

0 comments on commit 0462fdb

Please sign in to comment.