diff --git a/tests/test_utils.py b/tests/test_utils.py index a24431a..3635bdf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -21,13 +21,13 @@ def create_testing_data(): return N, dim, data -def create_testing_model(): +def create_testing_model(num_classes=10): return nn.Sequential( OrderedDict( [ ("first_layer", nn.Linear(256, 128)), ("second_layer", nn.Linear(128, 64)), - ("third_layer", nn.Linear(64, 10)), + ("third_layer", nn.Linear(64, num_classes)), ], ), ) diff --git a/tests/tests.py b/tests/tests.py index 5cd5a70..af4ffe0 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -59,8 +59,9 @@ def _test_bayes_prediction(mode: str): } N, dim, data = utils.create_testing_data() - model = utils.create_testing_model() - n_iter = 10 + num_classes = 17 + model = utils.create_testing_model(num_classes=num_classes) + n_iter = 7 if mode != 'gauss': res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) else: @@ -69,8 +70,8 @@ def _test_bayes_prediction(mode: str): utils.compare_values(dict, type(res), "Wrong result type") utils.compare_values(2, len(res), "Wrong dictionary length") utils.compare_values(set(["mean", "std"]), set(res.keys()), "Wrong dictionary keys") - utils.compare_values(torch.Size([N, n_iter]), res["mean"].shape, "Wrong mean shape") - utils.compare_values(torch.Size([N, n_iter]), res["std"].shape, "Wrong mean std") + utils.compare_values(torch.Size([N, num_classes]), res["mean"].shape, "Wrong mean shape") + utils.compare_values(torch.Size([N, num_classes]), res["std"].shape, "Wrong mean std") def test_basic_bayes_wrapper():