Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tanyapole committed Jan 8, 2024
1 parent e4a16e2 commit be4983b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 22 deletions.
4 changes: 2 additions & 2 deletions eXNN/bayes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ class DropoutGaussianWrapper:
def __init__(
self,
model: torch.nn.Module,
sigma: float
sigma: float,
):
"""Class representing bayesian equivalent of a neural network.
Args:
model (torch.nn.Module): neural network
sigma (float): std of parameters gaussian noise
"""
self.model = create_dropout_bayesian_wrapper(model, "gaussian", sigma=sigma)
self.model = create_dropout_bayesian_wrapper(model, "gauss", sigma=sigma)

def predict(self, data, n_iter) -> Dict[str, torch.Tensor]:
"""Function computes mean and standard deviation of bayesian equivalent
Expand Down
38 changes: 22 additions & 16 deletions eXNN/bayes/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
p: Optional[float] = None,
a: Optional[float] = None,
b: Optional[float] = None,
sigma: Optional[float] = None
sigma: Optional[float] = None,
):
super(ModuleBayesianWrapper, self).__init__()

Expand All @@ -26,8 +26,8 @@ def __init__(
(p is None and a is None and b is None and sigma is not None), pab_check

if (p is None) and (sigma is None):
ab_check = "If you choose to specify a and b, you must to specify both"
assert (self.a is not None) and (self.b is not None), ab_check
ab_check = "If you choose to specify a and b, you must specify both"
assert (a is not None) and (b is not None), ab_check

# At the moment we are only modifying linear and convolutional layers, so check this
self.layer = layer
Expand All @@ -50,8 +50,8 @@ def augment_weights(self, weights, bias):

else:
# If gauss is chosen, then apply it
weights = weights + (torch.randn(*weights.shape) * self.sigma).to(weights.device())
bias = bias + (torch.randn(*bias.shape) * self.sigma).to(bias.device())
weights = weights + (torch.randn(*weights.shape) * self.sigma).to(weights.device)
bias = bias + (torch.randn(*bias.shape) * self.sigma).to(bias.device)

return weights, bias

Expand Down Expand Up @@ -92,9 +92,11 @@ def __init__(

super(NetworkBayes, self).__init__()
self.model = copy.deepcopy(model)
self.model = replace_modules_with_wrapper(self.model,
ModuleBayesianWrapper,
{"p": dropout_p})
self.model = replace_modules_with_wrapper(
self.model,
ModuleBayesianWrapper,
{"p": dropout_p},
)

def mean_forward(
self,
Expand Down Expand Up @@ -128,9 +130,11 @@ def __init__(

super(NetworkBayesBeta, self).__init__()
self.model = copy.deepcopy(model)
self.model = replace_modules_with_wrapper(self.model,
ModuleBayesianWrapper,
{"a": alpha, "b": beta})
self.model = replace_modules_with_wrapper(
self.model,
ModuleBayesianWrapper,
{"a": alpha, "b": beta},
)

def mean_forward(
self,
Expand Down Expand Up @@ -158,14 +162,16 @@ class NetworkBayesGauss(nn.Module):
def __init__(
self,
model: torch.nn.Module,
sigma: float
sigma: float,
):

super(NetworkBayesGauss, self).__init__()
self.model = copy.deepcopy(model)
self.model = replace_modules_with_wrapper(self.model,
ModuleBayesianWrapper,
{"sigma": sigma})
self.model = replace_modules_with_wrapper(
self.model,
ModuleBayesianWrapper,
{"sigma": sigma},
)

def mean_forward(
self,
Expand Down Expand Up @@ -195,7 +201,7 @@ def create_dropout_bayesian_wrapper(
p: Optional[float] = None,
a: Optional[float] = None,
b: Optional[float] = None,
sigma: Optional[float] = None
sigma: Optional[float] = None,
) -> torch.nn.Module:
if mode == "basic":
net = NetworkBayes(model, p)
Expand Down
3 changes: 2 additions & 1 deletion eXNN/visualization/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def visualize_recurrent_layer_manifolds(
emb_out.update_layout(
autosize=False,
width=1000,
height=1000)
height=1000,
)
emb_out.show(renderer="colab")


Expand Down
8 changes: 5 additions & 3 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ def test_visualization():


def _test_bayes_prediction(mode: str):
params = {"basic": dict(mode="basic", p=0.5),
"beta": dict(mode="beta", a=0.9, b=0.2),
"gauss": dict(mode="gauss", sigma=1e-2)}
params = {
"basic": dict(mode="basic", p=0.5),
"beta": dict(mode="beta", a=0.9, b=0.2),
"gauss": dict(sigma=1e-2),
}

N, dim, data = utils.create_testing_data()
model = utils.create_testing_model()
Expand Down

0 comments on commit be4983b

Please sign in to comment.