From e4a16e2db98df39068fd85e84c395b34fc28475e Mon Sep 17 00:00:00 2001 From: IvanTomilov1 Date: Mon, 8 Jan 2024 03:35:04 +0300 Subject: [PATCH] gaussian wrapper test part fix --- tests/tests.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/tests.py b/tests/tests.py index 4c2bf92..94783da 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -59,7 +59,10 @@ def _test_bayes_prediction(mode: str): N, dim, data = utils.create_testing_data() model = utils.create_testing_model() n_iter = 10 - res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) + if mode != 'gauss': + res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) + else: + res = bayes_api.DropoutGaussianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter) utils.compare_values(dict, type(res), "Wrong result type") utils.compare_values(2, len(res), "Wrong dictionary length")