Skip to content

Commit

Permalink
fix test for also fgw tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Oct 23, 2024
1 parent 895085d commit 7e472dc
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/problems/generic/test_fgw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ott.solvers.linear import acceleration

from anndata import AnnData
from typing import Callable

from moscot._types import CostKwargs_t
from moscot.backends.ott._utils import alpha_to_fused_penalty
Expand Down Expand Up @@ -112,7 +113,10 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
solver = problem[key].solver.solver
args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args
for arg, val in args.items():
assert getattr(solver, val, object()) == args_to_check[arg], arg
if args_to_check["rank"] == -1 and arg == "initializer":
assert isinstance(getattr(solver,val), Callable)
else:
assert getattr(solver, val, object()) == args_to_check[arg], arg

sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
Expand Down

0 comments on commit 7e472dc

Please sign in to comment.