Skip to content

Commit

Permalink
Merge pull request #730 from theislab/pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
[pre-commit.ci] pre-commit autoupdate
  • Loading branch information
selmanozleyen authored Sep 12, 2024
2 parents 2025d82 + 22aaccf commit 0ccc3b1
Show file tree
Hide file tree
Showing 45 changed files with 135 additions and 147 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ default_stages:
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies: [numpy>=1.25.0]
files: ^src
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
additional_dependencies: [toml]
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: check-yaml
- id: check-toml
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
Expand All @@ -52,18 +52,18 @@ repos:
- id: blacken-docs
additional_dependencies: [black==23.1.0]
- repo: https://github.com/rstcheck/rstcheck
rev: v6.2.0
rev: v6.2.4
hooks:
- id: rstcheck
additional_dependencies: [tomli]
args: [--config=pyproject.toml]
- repo: https://github.com/PyCQA/doc8
rev: v1.1.1
rev: v1.1.2
hooks:
- id: doc8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.0
rev: v0.6.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
4 changes: 4 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@
r"https://doi.org/10.1126/science.aax1971",
r"https://doi.org/10.1093/nar/gkac235",
r"https://www.science.org/doi/abs/10.1126/science.aax1971",
r"https://doi.org/10.1101/2022.01.10.475692",
r"https://www.biorxiv.org/content/10.1101/2023.04.14.536867v1",
r"https://www.biorxiv.org/content/10.1101/2023.05.11.540374v2",
r"https://www.biorxiv.org/content/early/2022/01/11/2022.01.10.475692",
]

exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"]
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

ArrayLike = NDArray[np.float64]
except (ImportError, TypeError):
ArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc]
ArrayLike = np.ndarray
DTypeLike = np.dtype

ProblemKind_t = Literal["linear", "quadratic", "unknown"]
Numeric_t = Union[int, float] # type of `time_key` arguments
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def densify(arr: ArrayLike) -> jax.Array:
dense :mod:`jax` array.
"""
if sp.issparse(arr):
arr = arr.toarray() # type: ignore[attr-defined]
arr = arr.toarray()
elif isinstance(arr, jesp.BCOO):
arr = arr.todense()
return jnp.asarray(arr)
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def project_to_transport_matrix( # type:ignore[override]
The projected transport matrix.
"""
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
push = self.push if condition is None else lambda x: self.push(x, condition)
pull = self.pull if condition is None else lambda x: self.pull(x, condition)
push: Callable[[Any], Any] = self.push if condition is None else lambda x: self.push(x, condition) # type: ignore
pull: Callable[[Any], Any] = self.pull if condition is None else lambda x: self.pull(x, condition) # type: ignore
func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells)
return self._project_transport_matrix(
src_dist=src_dist,
Expand Down
10 changes: 0 additions & 10 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
from ott.neural.datasets import OTData, OTDataset
from ott.neural.methods.flows import dynamics, genot
Expand Down Expand Up @@ -652,15 +651,6 @@ def _prepare( # type: ignore[override]
MultiLoader(datasets=validate_loaders, seed=seed),
)

@staticmethod
def _assert2d(arr: ArrayLike, *, allow_reshape: bool = True) -> jnp.ndarray:
arr: jnp.ndarray = jnp.asarray(arr.A if sp.issparse(arr) else arr) # type: ignore[no-redef, attr-defined] # noqa:E501
if allow_reshape and arr.ndim == 1:
return jnp.reshape(arr, (-1, 1))
if arr.ndim != 2:
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
return arr

def _split_data( # TODO: adapt for Gromov terms
self,
x: ArrayLike,
Expand Down
3 changes: 1 addition & 2 deletions src/moscot/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def register_solver(
return _REGISTRY.register(backend) # type: ignore[return-value]


# TODO(@MUCDK) fix mypy error
@register_solver("ott") # type: ignore[arg-type]
@register_solver("ott")
def _(
problem_kind: Literal["linear", "quadratic"],
solver_name: Optional[Literal["GENOTLinSolver"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/base/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayLike:
"""
cost = self._compute(*args, **kwargs)
if np.any(np.isnan(cost)):
maxx = np.nanmax(cost)
maxx = np.nanmax(cost) # type: ignore[var-annotated]
logger.warning(
f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, "
f"setting them to the maximum value `{maxx}`."
)
cost = np.nan_to_num(cost, nan=maxx) # type: ignore[call-overload]
cost = np.nan_to_num(cost, nan=maxx)
if np.any(cost < 0):
raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.")
return cost
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102

@property
def shape(self) -> tuple[int, int]: # noqa: D102
return self.transport_matrix.shape # type: ignore[return-value]
return self.transport_matrix.shape

def to( # noqa: D102
self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def _sample_from_tmap(
account_for_unbalancedness: bool = False,
interpolation_parameter: Optional[Numeric_t] = None,
seed: Optional[int] = None,
) -> tuple[list[Any], list[ArrayLike]]:
) -> tuple[Any, list[str]]:
rng = np.random.RandomState(seed)
if account_for_unbalancedness and interpolation_parameter is None:
raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.")
Expand Down Expand Up @@ -453,7 +453,7 @@ def _sample_from_tmap(
for i in range(len(rows_batch))
]
all_cols_sampled.extend(cols_sampled)
return rows, all_cols_sampled # type: ignore[return-value]
return rows, all_cols_sampled

def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr
corr_bs = np.concatenate(corr_bs, axis=0)
corr_ci_low, corr_ci_high = np.quantile(corr_bs, q=ql, axis=0), np.quantile(corr_bs, q=qh, axis=0)

return pvals, corr_ci_low, corr_ci_high # type:ignore[return-value]
return pvals, corr_ci_low, corr_ci_high

if not (0 <= confidence_level <= 1):
raise ValueError(f"Expected `confidence_level` to be in interval `[0, 1]`, found `{confidence_level}`.")
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _split_mass(arr: ArrayLike) -> ArrayLike:
if start >= adata.n_obs:
raise IndexError(f"Expected starting index to be smaller than `{adata.n_obs}`, found `{start}`.")
data = np.zeros((adata.n_obs,), dtype=float)
data[range(start, min(start + offset, adata.n_obs))] = 1.0
data[range(start, min(start + offset, adata.n_obs))] = 1.0 # type: ignore[index]
else:
raise TypeError(f"Unable to interpret subset of type `{type(subset)}`.")
elif not hasattr(data, "shape"):
Expand Down
5 changes: 0 additions & 5 deletions src/moscot/base/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
Union,
)

import numpy as np

from moscot._logging import logger
from moscot._types import ArrayLike, Device_t, ProblemKind_t
from moscot.base.output import BaseDiscreteSolverOutput
Expand Down Expand Up @@ -55,9 +53,6 @@ def to_tuple(
loss_x = {k[2:]: v for k, v in kwargs.items() if k.startswith("x_")}
loss_y = {k[2:]: v for k, v in kwargs.items() if k.startswith("y_")}

if isinstance(xy, dict) and np.all([isinstance(v, tuple) for v in xy.values()]): # handling joint learning
return xy

# fmt: off
xy = xy if isinstance(xy, TaggedArray) else self._convert(*to_tuple(xy), tag=tags.get("xy", None), **loss_xy)
x = x if isinstance(x, TaggedArray) else self._convert(*to_tuple(x), tag=tags.get("x", None), **loss_x)
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/costs/_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _scaled_hamming_dist(x: ArrayLike, y: ArrayLike) -> float:
raise ValueError("No shared indices.")
b2 = y[shared_indices]

differences = b1 != b2
double_scars = differences & (b1 != 0) & (b2 != 0)
differences: ArrayLike = b1 != b2
double_scars: ArrayLike = differences & (b1 != 0) & (b2 != 0)

return (np.sum(differences) + np.sum(double_scars)) / len(b1)
return float(float(np.sum(differences)) + np.sum(double_scars)) / len(b1)
2 changes: 1 addition & 1 deletion src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _plot_scatter(
_ = kwargs.pop("palette", None)
if (time_points[i] == source and push) or (time_points[i] == target and not push):
st = f"not in {time_points[i]}"
vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask])
vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask]) # type: ignore[var-annotated]
column = pd.Series(tmp).fillna(st).astype("category")
# TODO(michalk8): check
if len(np.unique(column[mask.values].values)) > 2:
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, s
raise TypeError("`x_attr` and `y_attr` must be of type `str` or `dict` if no callback is provided.")


class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):
"""Class for solving a :term:`linear problem`.
Parameters
Expand Down Expand Up @@ -264,7 +264,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):
"""Class for solving the :term:`GW <Gromov-Wasserstein>` or :term:`FGW <fused Gromov-Wasserstein>` problems.
Parameters
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/space/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def _create_problem(
adata_tgt=self.adata_sc,
src_obs_mask=src_mask,
tgt_obs_mask=None,
src_var_mask=self.filtered_vars, # type: ignore[arg-type]
tgt_var_mask=self.filtered_vars, # type: ignore[arg-type]
src_var_mask=self.filtered_vars,
tgt_var_mask=self.filtered_vars,
src_key=src,
tgt_key=tgt,
**kwargs,
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SpatialAlignmentMixinProtocol(AnalysisMixinProtocol[K, B]):
_spatial_key: Optional[str]
batch_key: Optional[str]

def _subset_spatial( # type:ignore[empty-body]
def _subset_spatial(
self: "SpatialAlignmentMixinProtocol[K, B]",
k: K,
spatial_key: str,
Expand Down Expand Up @@ -780,13 +780,13 @@ def _compute_correspondence(

def pdist(row_idx: ArrayLike, col_idx: float, feat: ArrayLike) -> Any:
if len(row_idx) > 0:
return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean() # type: ignore[index]
return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean()
return np.nan

# TODO(michalk8): vectorize using jax, this is just a for loop
vpdist = np.vectorize(pdist, excluded=["feat"])
if sp.issparse(features):
features = features.toarray() # type: ignore[attr-defined]
features = features.toarray()

feat_arr, index_arr, support_arr = [], [], []
for ind, i in enumerate(support):
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
__all__ = ["TemporalProblem", "LineageProblem"]


class TemporalProblem( # type: ignore[misc]
class TemporalProblem(
TemporalMixin[Numeric_t, BirthDeathProblem], BirthDeathMixin, CompoundProblem[Numeric_t, BirthDeathProblem]
):
"""Class for analyzing time-series single cell data based on :cite:`schiebinger:19`.
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram
# TODO(michalk8): `[1]` will fail if potentials is None
df_list = [
pd.DataFrame(
np.asarray(problem.solution.potentials[0]), # type: ignore[union-attr,index]
np.asarray(problem.solution.potentials[0]),
index=problem.adata_src.obs_names,
columns=cols,
)
Expand All @@ -612,7 +612,7 @@ def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram
# TODO(michalk8): `[1]` will fail if potentials is None
df_list = [
pd.DataFrame(
np.array(problem.solution.potentials[1]), # type: ignore[union-attr,index]
np.array(problem.solution.potentials[1]),
index=problem.adata_tgt.obs_names,
columns=cols,
)
Expand Down Expand Up @@ -664,7 +664,7 @@ def _get_data(
else:
raise ValueError(f"No data found for `{target}` time point.")

return ( # type:ignore[return-value]
return (
source_data,
growth_rates_source,
intermediate_data,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/utils/tagged_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def shape(self) -> Tuple[int, int]:
x, y = self.data_src, (self.data_src if self.data_tgt is None else self.data_tgt)
return x.shape[0], y.shape[0]

return self.data_src.shape # type: ignore[return-value]
return self.data_src.shape

@property
def is_cost_matrix(self) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


class TestSinkhorn:
@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("jit", [False, True])
@pytest.mark.parametrize("eps", [None, 1e-2, 1e-1])
def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool):
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, xy: Geom_t, eps: Optional[float
assert isinstance(solver.xy, PointCloud)
np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL)

@pytest.mark.fast()
@pytest.mark.fast
@pytest.mark.parametrize("alpha", [0.1, 0.9])
def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None:
thresh, eps = 5e-2, 1e-1
Expand Down
Loading

0 comments on commit 0ccc3b1

Please sign in to comment.