diff --git a/tests/tests.py b/tests/tests.py index f560d55..ac7f199 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -52,10 +52,10 @@ def test_visualization(): def test_embed_visualization(): - N, dim, data = utils.create_testing_data() + N, dim, data, labels = utils.create_testing_data() model = utils.create_testing_model() layers = ["second_layer", "third_layer"] - res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers) + res = viz_api.visualize_recurrent_layer_manifolds(model, "umap", data, layers=layers, labels=labels) utils.compare_values(dict, type(res), "Wrong result type") utils.compare_values(3, len(res), "Wrong dictionary length")