Skip to content

Commit

Permalink
fix dimensions in bayes tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tanyapole committed Jan 8, 2024
1 parent be4983b commit cae03d1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ def create_testing_data():
return N, dim, data


def create_testing_model():
def create_testing_model(num_classes=10):
return nn.Sequential(
OrderedDict(
[
("first_layer", nn.Linear(256, 128)),
("second_layer", nn.Linear(128, 64)),
("third_layer", nn.Linear(64, 10)),
("third_layer", nn.Linear(64, num_classes)),
],
),
)
9 changes: 5 additions & 4 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def _test_bayes_prediction(mode: str):
}

N, dim, data = utils.create_testing_data()
model = utils.create_testing_model()
n_iter = 10
num_classes = 17
model = utils.create_testing_model(num_classes=num_classes)
n_iter = 7
if mode != 'gauss':
res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
else:
Expand All @@ -69,8 +70,8 @@ def _test_bayes_prediction(mode: str):
utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(2, len(res), "Wrong dictionary length")
utils.compare_values(set(["mean", "std"]), set(res.keys()), "Wrong dictionary keys")
utils.compare_values(torch.Size([N, n_iter]), res["mean"].shape, "Wrong mean shape")
utils.compare_values(torch.Size([N, n_iter]), res["std"].shape, "Wrong mean std")
utils.compare_values(torch.Size([N, num_classes]), res["mean"].shape, "Wrong mean shape")
utils.compare_values(torch.Size([N, num_classes]), res["std"].shape, "Wrong mean std")


def test_basic_bayes_wrapper():
Expand Down

0 comments on commit cae03d1

Please sign in to comment.