Skip to content

Commit

Permalink
Merge branch 'main' into add/sparse-geodesic
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen authored May 7, 2024
2 parents afcc83f + 6cc2215 commit 40c7966
Show file tree
Hide file tree
Showing 12 changed files with 32 additions and 24 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ default_stages:
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.9.0
rev: v1.10.0
hooks:
- id: mypy
additional_dependencies: [numpy>=1.25.0]
files: ^src
- repo: https://github.com/psf/black
rev: 24.3.0
rev: 24.4.2
hooks:
- id: black
additional_dependencies: [toml]
Expand All @@ -29,7 +29,7 @@ repos:
additional_dependencies: [toml]
args: [--order-by-type]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-merge-conflict
- id: check-ast
Expand Down Expand Up @@ -63,7 +63,7 @@ repos:
- id: doc8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.5
rev: v0.4.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
5 changes: 5 additions & 0 deletions src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ def solve(
problems = self._problem_manager.get_problems(stage=stage)

logger.info(f"Solving `{len(problems)}` problems")
# expose min/max iterations to the user but remove them if they are None
if "min_iterations" in kwargs and kwargs["min_iterations"] is None:
kwargs.pop("min_iterations")
if "max_iterations" in kwargs and kwargs["max_iterations"] is None:
kwargs.pop("max_iterations")
for problem in problems.values():
logger.info(f"Solving problem {problem}.")
_ = problem.solve(**kwargs)
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/cross_modality/_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def solve( # type: ignore[override]
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
12 changes: 6 additions & 6 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def solve(
threshold: float = 1e-3,
lse_mode: bool = True,
inner_iterations: int = 10,
min_iterations: int = 0,
max_iterations: int = 2000,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "SinkhornProblem[K,B]":
Expand Down Expand Up @@ -373,8 +373,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down Expand Up @@ -604,8 +604,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/space/_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/space/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/spatiotemporal/_spatio_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
8 changes: 4 additions & 4 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def solve(
threshold: float = 1e-3,
lse_mode: bool = True,
inner_iterations: int = 10,
min_iterations: int = 0,
max_iterations: int = 2000,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "TemporalProblem":
Expand Down Expand Up @@ -406,8 +406,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,8 @@ def _get_data(
tag = self.problems[src, tgt].xy.tag # type: ignore[union-attr]
if tag != Tag.POINT_CLOUD:
raise ValueError(
f"Expected `tag={Tag.POINT_CLOUD}`, " # type: ignore[union-attr]
f"found `tag={self.problems[src, tgt].xy.tag}`."
f"Expected `tag={Tag.POINT_CLOUD}`, "
f"found `tag={self.problems[src, tgt].xy.tag}`." # type: ignore[union-attr]
)
if src == source:
source_data = self.problems[src, tgt].xy.data_src # type: ignore[union-attr]
Expand Down
1 change: 1 addition & 0 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str):
assert ref == reference
assert isinstance(ap[prob_key], ap._base_problem_type)

@pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678")
@pytest.mark.parametrize(
("epsilon", "alpha", "rank", "initializer"),
[(1, 0.9, -1, None), (1, 0.5, 10, "random"), (1, 0.5, 10, "rank2"), (0.1, 0.1, -1, None)],
Expand Down
1 change: 1 addition & 0 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def test_prepare_varnames(self, adata_mapping: AnnData, var_names: Optional[List
assert prob.x.data_src.shape == (n_obs, x_n_var)
assert prob.y.data_src.shape == (n_obs, y_n_var)

@pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678")
@pytest.mark.parametrize(
("epsilon", "alpha", "rank", "initializer"),
[(1e-2, 0.9, -1, None), (2, 0.5, 10, "random"), (2, 0.5, 10, "rank2"), (2, 0.1, -1, None)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData):
assert isinstance(subsol, BaseSolverOutput)
assert key in expected_keys

@pytest.mark.skip(reason="unbalanced does not work yet")
def test_solve_unbalanced(self, adata_spatio_temporal: AnnData):
taus = [9e-1, 1e-2]
problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal)
Expand Down

0 comments on commit 40c7966

Please sign in to comment.