diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 31e3de91..166cafad 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Literal, Mapping, Optional +from typing import Any, Callable, Literal, Mapping, Optional import pytest @@ -197,7 +197,10 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer": + assert isinstance(getattr(solver, val), Callable) + else: + assert getattr(solver, val) == args_to_check[arg] sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index c5ec8e08..5a81ae57 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -1,4 +1,4 @@ -from typing import Any, List, Mapping, Optional +from typing import Any, Callable, List, Mapping, Optional import pytest @@ -442,7 +442,10 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A for arg, val in args.items(): assert hasattr(solver, val) el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val) - assert el == args_to_check[arg] + if arg == "initializer": + assert isinstance(el, Callable) + else: + assert el == args_to_check[arg] lin_prob = problem[key]._solver._problem for arg, val in lin_prob_args.items():