Skip to content

Commit

Permalink
changes to gw_solver_args
Browse files Browse the repository at this point in the history
  • Loading branch information
Arina Danilina committed Oct 8, 2023
1 parent cf952b8 commit 23d34a7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def __init__(
):
super().__init__(jit=jit)
if rank > -1:
kwargs = {**linear_solver_kwargs, **kwargs}
kwargs.setdefault("gamma", 10)
kwargs.setdefault("gamma_rescale", True)
initializer = "random" if initializer is None else initializer
Expand All @@ -260,6 +261,7 @@ def __init__(

else:
initializer = None
kwargs = {**linear_solver_kwargs, **kwargs}
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
quad_initializer=initializer,
Expand Down
10 changes: 9 additions & 1 deletion tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"gw_unbalanced_correction": False,
"ranks": 3,
"tolerances": 3e-2,
"warm_start": True,
#"warm_start": True,
"linear_solver_kwargs": linear_solver_kwargs2,
}

Expand All @@ -175,6 +175,14 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"initializer": "quad_initializer",
}

gw_lr_solver_args = {
"epsilon": "epsilon",
"rank": "rank",
"threshold": "threshold",
"initializer_kwargs": "kwargs_init",
"initializer": "initializer",
}

gw_linear_solver_args = {
"lse_mode": "lse_mode",
"inner_iterations": "inner_iterations",
Expand Down
5 changes: 3 additions & 2 deletions tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
gw_linear_solver_args,
gw_lr_linear_solver_args,
gw_solver_args,
gw_lr_solver_args,
quad_prob_args,
)

Expand Down Expand Up @@ -113,11 +114,11 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin
problem = problem.solve(**args_to_check)
key = ("0", "1")
solver = problem[key].solver.solver
for arg, val in gw_solver_args.items():
for arg, val in gw_solver_args.items() if args_to_check["rank"] == -1 else gw_lr_solver_args.items():
assert hasattr(solver, val)
assert getattr(solver, val) == args_to_check[arg]

sinkhorn_solver = solver.linear_ot_solver
sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver
lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args
for arg, val in lin_solver_args.items():
assert hasattr(sinkhorn_solver, val)
Expand Down

0 comments on commit 23d34a7

Please sign in to comment.