Skip to content

Commit

Permalink
Merge branch 'main' into fix/cat_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ArinaDanilina authored Oct 4, 2024
2 parents 14adcd3 + b516bd9 commit d9f65e6
Show file tree
Hide file tree
Showing 35 changed files with 208 additions and 167 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: ["3.9", "3.10"]
python: ["3.10", "3.11"]
include:
- os: macos-latest
python: "3.9"
python: "3.10"

steps:
- uses: actions/checkout@v3
Expand Down
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ default_stages:
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies: [numpy>=1.25.0]
files: ^src
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
additional_dependencies: [toml]
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: check-yaml
- id: check-toml
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
Expand All @@ -52,18 +52,18 @@ repos:
- id: blacken-docs
additional_dependencies: [black==23.1.0]
- repo: https://github.com/rstcheck/rstcheck
rev: v6.2.0
rev: v6.2.4
hooks:
- id: rstcheck
additional_dependencies: [tomli]
args: [--config=pyproject.toml]
- repo: https://github.com/PyCQA/doc8
rev: v1.1.1
rev: v1.1.2
hooks:
- id: doc8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.0
rev: v0.6.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
5 changes: 5 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"anndata": ("https://anndata.readthedocs.io/en/latest/", None),
"scanpy": ("https://scanpy.readthedocs.io/en/latest/", None),
"squidpy": ("https://squidpy.readthedocs.io/en/latest/", None),
"mudata": ("https://mudata.readthedocs.io/en/latest/", None),
}
master_doc = "index"
pygments_style = "tango"
Expand Down Expand Up @@ -139,6 +140,10 @@
r"https://doi.org/10.1126/science.aax1971",
r"https://doi.org/10.1093/nar/gkac235",
r"https://www.science.org/doi/abs/10.1126/science.aax1971",
r"https://doi.org/10.1101/2022.01.10.475692",
r"https://www.biorxiv.org/content/10.1101/2023.04.14.536867v1",
r"https://www.biorxiv.org/content/10.1101/2023.05.11.540374v2",
r"https://www.biorxiv.org/content/early/2022/01/11/2022.01.10.475692",
]

exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"]
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Installation
============
:mod:`moscot` requires Python version >= 3.9 to run.
:mod:`moscot` requires Python version >= 3.10 to run.

PyPI
----
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "moscot"
dynamic = ["version"]
description = "Multi-omic single-cell optimal transport tools"
readme = "README.rst"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -57,7 +57,8 @@ dependencies = [
"ott-jax[neural]>=0.4.6",
"cloudpickle>=2.2.0",
"rich>=13.5",
"docstring_inheritance>=2.0.0"
"docstring_inheritance>=2.0.0",
"mudata>=0.2.2"
]

[project.optional-dependencies]
Expand Down Expand Up @@ -232,7 +233,7 @@ ignore_roles = [

[tool.mypy]
mypy_path = "$MYPY_CONFIG_FILE_DIR/src"
python_version = "3.9"
python_version = "3.10"
plugins = "numpy.typing.mypy_plugin"

ignore_errors = false
Expand Down Expand Up @@ -269,7 +270,7 @@ max_line_length = 120
legacy_tox_ini = """
[tox]
min_version = 4.0
env_list = lint-code,py{3.9,3.10,3.11}
env_list = lint-code,py{3.10,3.11,3.12}
skip_missing_interpreters = true
[testenv]
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

try:
md = metadata.metadata(__name__)
__version__ = md.get("version", "")
__author__ = md.get("Author", "")
__maintainer__ = md.get("Maintainer-email", "")
__version__ = md.get("version", "") # type: ignore[attr-defined]
__author__ = md.get("Author", "") # type: ignore[attr-defined]
__maintainer__ = md.get("Maintainer-email", "") # type: ignore[attr-defined]
except ImportError:
md = None

Expand Down
10 changes: 0 additions & 10 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
from ott.neural.datasets import OTData, OTDataset
from ott.neural.methods.flows import dynamics, genot
Expand Down Expand Up @@ -652,15 +651,6 @@ def _prepare( # type: ignore[override]
MultiLoader(datasets=validate_loaders, seed=seed),
)

@staticmethod
def _assert2d(arr: ArrayLike, *, allow_reshape: bool = True) -> jnp.ndarray:
arr: jnp.ndarray = jnp.asarray(arr.A if sp.issparse(arr) else arr) # type: ignore[no-redef, attr-defined] # noqa:E501
if allow_reshape and arr.ndim == 1:
return jnp.reshape(arr, (-1, 1))
if arr.ndim != 2:
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
return arr

def _split_data( # TODO: adapt for Gromov terms
self,
x: ArrayLike,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102

@property
def shape(self) -> tuple[int, int]: # noqa: D102
return self.transport_matrix.shape # type: ignore[return-value]
return self.transport_matrix.shape

def to( # noqa: D102
self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None
Expand Down
5 changes: 0 additions & 5 deletions src/moscot/base/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
Union,
)

import numpy as np

from moscot._logging import logger
from moscot._types import ArrayLike, Device_t, ProblemKind_t
from moscot.base.output import BaseDiscreteSolverOutput
Expand Down Expand Up @@ -55,9 +53,6 @@ def to_tuple(
loss_x = {k[2:]: v for k, v in kwargs.items() if k.startswith("x_")}
loss_y = {k[2:]: v for k, v in kwargs.items() if k.startswith("y_")}

if isinstance(xy, dict) and np.all([isinstance(v, tuple) for v in xy.values()]): # handling joint learning
return xy

# fmt: off
xy = xy if isinstance(xy, TaggedArray) else self._convert(*to_tuple(xy), tag=tags.get("xy", None), **loss_xy)
x = x if isinstance(x, TaggedArray) else self._convert(*to_tuple(x), tag=tags.get("x", None), **loss_x)
Expand Down
Loading

0 comments on commit d9f65e6

Please sign in to comment.