Skip to content

Commit

Permalink
Merge branch 'main' into feature/new_data
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen authored Sep 12, 2024
2 parents 77193dc + f0a2bfe commit bf2951d
Show file tree
Hide file tree
Showing 19 changed files with 37 additions and 36 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ["3.9", "3.10"]
python: ["3.10", "3.11"]
include:
- os: macos-latest
python: "3.9"
python: "3.10"

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Installation
============
:mod:`moscot` requires Python version >= 3.9 to run.
:mod:`moscot` requires Python version >= 3.10 to run.

PyPI
----
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "moscot"
dynamic = ["version"]
description = "Multi-omic single-cell optimal transport tools"
readme = "README.rst"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -233,7 +233,7 @@ ignore_roles = [

[tool.mypy]
mypy_path = "$MYPY_CONFIG_FILE_DIR/src"
python_version = "3.9"
python_version = "3.10"
plugins = "numpy.typing.mypy_plugin"

ignore_errors = false
Expand Down Expand Up @@ -270,7 +270,7 @@ max_line_length = 120
legacy_tox_ini = """
[tox]
min_version = 4.0
env_list = lint-code,py{3.9,3.10,3.11}
env_list = lint-code,py{3.10,3.11,3.12}
skip_missing_interpreters = true
[testenv]
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

try:
md = metadata.metadata(__name__)
__version__ = md.get("version", "")
__author__ = md.get("Author", "")
__maintainer__ = md.get("Maintainer-email", "")
__version__ = md.get("version", "") # type: ignore[attr-defined]
__author__ = md.get("Author", "") # type: ignore[attr-defined]
__maintainer__ = md.get("Maintainer-email", "") # type: ignore[attr-defined]
except ImportError:
md = None

Expand Down
4 changes: 2 additions & 2 deletions src/moscot/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

ArrayLike = NDArray[np.float64]
except (ImportError, TypeError):
ArrayLike = np.ndarray
DTypeLike = np.dtype
ArrayLike = np.ndarray # type: ignore[misc]
DTypeLike = np.dtype # type: ignore[misc]

ProblemKind_t = Literal["linear", "quadratic", "unknown"]
Numeric_t = Union[int, float] # type of `time_key` arguments
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def densify(arr: ArrayLike) -> jax.Array:
dense :mod:`jax` array.
"""
if sp.issparse(arr):
arr = arr.toarray()
arr = arr.toarray() # type: ignore[attr-defined]
elif isinstance(arr, jesp.BCOO):
arr = arr.todense()
return jnp.asarray(arr)
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def project_to_transport_matrix( # type:ignore[override]
The projected transport matrix.
"""
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
push: Callable[[Any], Any] = self.push if condition is None else lambda x: self.push(x, condition) # type: ignore
pull: Callable[[Any], Any] = self.pull if condition is None else lambda x: self.pull(x, condition) # type: ignore
push = self.push if condition is None else lambda x: self.push(x, condition)
pull = self.pull if condition is None else lambda x: self.pull(x, condition)
func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells)
return self._project_transport_matrix(
src_dist=src_dist,
Expand Down
3 changes: 2 additions & 1 deletion src/moscot/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def register_solver(
return _REGISTRY.register(backend) # type: ignore[return-value]


@register_solver("ott")
# TODO(@MUCDK) fix mypy error
@register_solver("ott") # type: ignore[arg-type]
def _(
problem_kind: Literal["linear", "quadratic"],
solver_name: Optional[Literal["GENOTLinSolver"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/base/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayLike:
"""
cost = self._compute(*args, **kwargs)
if np.any(np.isnan(cost)):
maxx = np.nanmax(cost) # type: ignore[var-annotated]
maxx = np.nanmax(cost)
logger.warning(
f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, "
f"setting them to the maximum value `{maxx}`."
)
cost = np.nan_to_num(cost, nan=maxx)
cost = np.nan_to_num(cost, nan=maxx) # type: ignore[call-overload]
if np.any(cost < 0):
raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.")
return cost
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def _sample_from_tmap(
account_for_unbalancedness: bool = False,
interpolation_parameter: Optional[Numeric_t] = None,
seed: Optional[int] = None,
) -> tuple[Any, list[str]]:
) -> tuple[list[Any], list[ArrayLike]]:
rng = np.random.RandomState(seed)
if account_for_unbalancedness and interpolation_parameter is None:
raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.")
Expand Down Expand Up @@ -453,7 +453,7 @@ def _sample_from_tmap(
for i in range(len(rows_batch))
]
all_cols_sampled.extend(cols_sampled)
return rows, all_cols_sampled
return rows, all_cols_sampled # type: ignore[return-value]

def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr
corr_bs = np.concatenate(corr_bs, axis=0)
corr_ci_low, corr_ci_high = np.quantile(corr_bs, q=ql, axis=0), np.quantile(corr_bs, q=qh, axis=0)

return pvals, corr_ci_low, corr_ci_high
return pvals, corr_ci_low, corr_ci_high # type:ignore[return-value]

if not (0 <= confidence_level <= 1):
raise ValueError(f"Expected `confidence_level` to be in interval `[0, 1]`, found `{confidence_level}`.")
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _split_mass(arr: ArrayLike) -> ArrayLike:
if start >= adata.n_obs:
raise IndexError(f"Expected starting index to be smaller than `{adata.n_obs}`, found `{start}`.")
data = np.zeros((adata.n_obs,), dtype=float)
data[range(start, min(start + offset, adata.n_obs))] = 1.0 # type: ignore[index]
data[range(start, min(start + offset, adata.n_obs))] = 1.0
else:
raise TypeError(f"Unable to interpret subset of type `{type(subset)}`.")
elif not hasattr(data, "shape"):
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/costs/_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _scaled_hamming_dist(x: ArrayLike, y: ArrayLike) -> float:
raise ValueError("No shared indices.")
b2 = y[shared_indices]

differences: ArrayLike = b1 != b2
double_scars: ArrayLike = differences & (b1 != 0) & (b2 != 0)
differences = b1 != b2
double_scars = differences & (b1 != 0) & (b2 != 0)

return float(float(np.sum(differences)) + np.sum(double_scars)) / len(b1)
return (np.sum(differences) + np.sum(double_scars)) / len(b1)
2 changes: 1 addition & 1 deletion src/moscot/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def _plot_scatter(
_ = kwargs.pop("palette", None)
if (time_points[i] == source and push) or (time_points[i] == target and not push):
st = f"not in {time_points[i]}"
vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask]) # type: ignore[var-annotated]
vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask])
column = pd.Series(tmp).fillna(st).astype("category")
# TODO(michalk8): check
if len(np.unique(column[mask.values].values)) > 2:
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, s
raise TypeError("`x_attr` and `y_attr` must be of type `str` or `dict` if no callback is provided.")


class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):
class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
"""Class for solving a :term:`linear problem`.
Parameters
Expand Down Expand Up @@ -264,7 +264,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]):
class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
"""Class for solving the :term:`GW <Gromov-Wasserstein>` or :term:`FGW <fused Gromov-Wasserstein>` problems.
Parameters
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/space/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def _create_problem(
adata_tgt=self.adata_sc,
src_obs_mask=src_mask,
tgt_obs_mask=None,
src_var_mask=self.filtered_vars,
tgt_var_mask=self.filtered_vars,
src_var_mask=self.filtered_vars, # type: ignore[arg-type]
tgt_var_mask=self.filtered_vars, # type: ignore[arg-type]
src_key=src,
tgt_key=tgt,
**kwargs,
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SpatialAlignmentMixinProtocol(AnalysisMixinProtocol[K, B]):
_spatial_key: Optional[str]
batch_key: Optional[str]

def _subset_spatial(
def _subset_spatial( # type:ignore[empty-body]
self: "SpatialAlignmentMixinProtocol[K, B]",
k: K,
spatial_key: str,
Expand Down Expand Up @@ -780,13 +780,13 @@ def _compute_correspondence(

def pdist(row_idx: ArrayLike, col_idx: float, feat: ArrayLike) -> Any:
if len(row_idx) > 0:
return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean()
return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean() # type: ignore[index]
return np.nan

# TODO(michalk8): vectorize using jax, this is just a for loop
vpdist = np.vectorize(pdist, excluded=["feat"])
if sp.issparse(features):
features = features.toarray()
features = features.toarray() # type: ignore[attr-defined]

feat_arr, index_arr, support_arr = [], [], []
for ind, i in enumerate(support):
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
__all__ = ["TemporalProblem", "LineageProblem"]


class TemporalProblem(
class TemporalProblem( # type: ignore[misc]
TemporalMixin[Numeric_t, BirthDeathProblem], BirthDeathMixin, CompoundProblem[Numeric_t, BirthDeathProblem]
):
"""Class for analyzing time-series single cell data based on :cite:`schiebinger:19`.
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram
# TODO(michalk8): `[1]` will fail if potentials is None
df_list = [
pd.DataFrame(
np.asarray(problem.solution.potentials[0]),
np.asarray(problem.solution.potentials[0]), # type: ignore[union-attr,index]
index=problem.adata_src.obs_names,
columns=cols,
)
Expand All @@ -612,7 +612,7 @@ def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram
# TODO(michalk8): `[1]` will fail if potentials is None
df_list = [
pd.DataFrame(
np.array(problem.solution.potentials[1]),
np.array(problem.solution.potentials[1]), # type: ignore[union-attr,index]
index=problem.adata_tgt.obs_names,
columns=cols,
)
Expand Down Expand Up @@ -664,7 +664,7 @@ def _get_data(
else:
raise ValueError(f"No data found for `{target}` time point.")

return (
return ( # type:ignore[return-value]
source_data,
growth_rates_source,
intermediate_data,
Expand Down

0 comments on commit bf2951d

Please sign in to comment.