Skip to content

Commit

Permalink
gaussian wrapper test part fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanTomilov1 committed Jan 8, 2024
1 parent 30e97e7 commit e4a16e2
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ def _test_bayes_prediction(mode: str):
N, dim, data = utils.create_testing_data()
model = utils.create_testing_model()
n_iter = 10
res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
if mode != 'gauss':
res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
else:
res = bayes_api.DropoutGaussianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)

utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(2, len(res), "Wrong dictionary length")
Expand Down

0 comments on commit e4a16e2

Please sign in to comment.