Skip to content

Commit

Permalink
Support converting bool matrices (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Jul 14, 2023
1 parent 7cf91ed commit 83be262
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 76 deletions.
7 changes: 1 addition & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@
import sys
from abc import ABC
from datetime import datetime, timezone
from importlib.metadata import metadata
from pathlib import Path
from unittest.mock import MagicMock, patch


try:
from importlib.metadata import metadata
except ImportError:
from importlib_metadata import metadata


def mock_rpy2() -> None:
"""Can’t use autodoc_mock_imports as we import anndata2ri."""
patch('rpy2.situation.get_r_home', lambda: None).start()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ doc = [
'scanpydoc',
'sphinx-rtd-theme>=0.5',
'lxml', # For scraping the R link info
'importlib_metadata; python_version < "3.8"',
'importlib_resources; python_version < "3.9"',
]

[tool.hatch.version]
Expand Down
Empty file added src/anndata2ri/py.typed
Empty file.
65 changes: 14 additions & 51 deletions src/anndata2ri/scipy2ri/_py2r.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from __future__ import annotations

from functools import wraps
from functools import lru_cache, wraps
from typing import TYPE_CHECKING


try:
from importlib.resources import files
except ImportError: # Python < 3.9
from importlib_resources import files

import numpy as np
from rpy2.robjects import default_converter, numpy2ri
from rpy2.robjects.conversion import localconverter
Expand All @@ -24,6 +30,11 @@
base: Package | None = None


@lru_cache
def get_r_code() -> str:
    """Return the source code of the R helper functions bundled with the package.

    The text is read from ``scipy2ri/_py2r_helpers.r`` inside the installed
    ``anndata2ri`` package. The result is cached, so the resource is read at
    most once per process.
    """
    # Join one segment at a time: passing several segments to a single
    # ``joinpath`` call is not supported by every ``Traversable``
    # implementation on older Python versions (e.g. zip-backed resources).
    helpers = files('anndata2ri').joinpath('scipy2ri').joinpath('_py2r_helpers.r')
    # Decode explicitly so the result does not depend on the locale encoding.
    return helpers.read_text(encoding='utf-8')


def get_type_conv(dtype: np.dtype) -> Callable[[np.ndarray], Sexp]:
global base # noqa: PLW0603
if base is None:
Expand All @@ -44,56 +55,8 @@ def wrapper(obj: sparse.spmatrix) -> Sexp:
global matrix # noqa: PLW0603
if matrix is None:
importr('Matrix') # make class available
matrix = SignatureTranslatedAnonymousPackage(
"""
sparse_matrix <- function(x, conv_data, dims, ...) {
Matrix::sparseMatrix(
...,
x=conv_data(x),
dims=as.integer(dims),
index1=FALSE
)
}
from_csc <- function(i, p, x, dims, conv_data) {
sparse_matrix(
i=as.integer(i),
p=as.integer(p),
x=x,
conv_data=conv_data,
dims=dims,
repr="C"
)
}
from_csr <- function(j, p, x, dims, conv_data) {
sparse_matrix(
j=as.integer(j),
p=as.integer(p),
x=x,
conv_data=conv_data,
dims=dims,
repr="R"
)
}
from_coo <- function(i, j, x, dims, conv_data) {
sparse_matrix(
i=as.integer(i),
j=as.integer(j),
x=x,
conv_data=conv_data,
dims=dims,
repr="T"
)
}
from_dia <- function(n, x, conv_data) {
Matrix::Diagonal(n=as.integer(n), x=conv_data(x))
}
""",
'matrix',
)
r_code = get_r_code()
matrix = SignatureTranslatedAnonymousPackage(r_code, 'matrix')

return f(obj)

Expand Down
45 changes: 45 additions & 0 deletions src/anndata2ri/scipy2ri/_py2r_helpers.r
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Shared constructor used by the from_* helpers below.
#   x          raw values; converted with `conv_data` before use
#   conv_data  function applied to `x`
#   dims       matrix dimensions; coerced to integer
#   ...        forwarded unchanged to Matrix::sparseMatrix (e.g. i/j/p, repr)
# Indices are declared zero-based (index1 = FALSE).
sparse_matrix <- function(x, conv_data, dims, ...) {
    values <- conv_data(x)
    shape <- as.integer(dims)
    Matrix::sparseMatrix(..., x = values, dims = shape, index1 = FALSE)
}

# Build a matrix with repr "C" from the index vector `i` and pointer vector `p`.
# Both vectors are coerced to integer; `x`, `dims` and `conv_data` are passed
# on to sparse_matrix() untouched.
from_csc <- function(i, p, x, dims, conv_data) {
    idx <- as.integer(i)
    ptr <- as.integer(p)
    sparse_matrix(x = x, conv_data = conv_data, dims = dims, i = idx, p = ptr, repr = "C")
}

# Build a matrix with repr "R" from the index vector `j` and pointer vector `p`.
# Both vectors are coerced to integer; `x`, `dims` and `conv_data` are passed
# on to sparse_matrix() untouched.
from_csr <- function(j, p, x, dims, conv_data) {
    idx <- as.integer(j)
    ptr <- as.integer(p)
    sparse_matrix(x = x, conv_data = conv_data, dims = dims, j = idx, p = ptr, repr = "R")
}

# Build a matrix with repr "T" from the paired index vectors `i` and `j`.
# Both vectors are coerced to integer; `x`, `dims` and `conv_data` are passed
# on to sparse_matrix() untouched.
from_coo <- function(i, j, x, dims, conv_data) {
    first_idx <- as.integer(i)
    second_idx <- as.integer(j)
    sparse_matrix(x = x, conv_data = conv_data, dims = dims, i = first_idx, j = second_idx, repr = "T")
}

# Build a diagonal matrix of size `n` (coerced to integer) whose entries are
# `x` after conversion with `conv_data`.
from_dia <- function(n, x, conv_data) {
    size <- as.integer(n)
    entries <- conv_data(x)
    Matrix::Diagonal(n = size, x = entries)
}
36 changes: 27 additions & 9 deletions src/anndata2ri/scipy2ri/_r2py.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
Expand All @@ -10,7 +11,28 @@
from scipy import sparse

from ._conv import converter
from ._support import supported_r_matrix_classes
from ._support import SupportedMatStor, supported_r_matrix_classes


if TYPE_CHECKING:
from collections.abc import Callable, Iterable

CoordSpec = tuple[np.ndarray, ...] | tuple[tuple[np.ndarray, ...]] | tuple[list[int]]


# Dispatch table with one row per supported R sparse-matrix storage kind.
# Each row holds, in order:
#   1. the storage code handed to ``supported_r_matrix_classes(storage=...)``,
#   2. the scipy sparse class that is instantiated for that storage,
#   3. a function taking the R object's slots and returning the coordinate
#      arguments that follow ``data`` in the scipy constructor tuple,
#   4. a function taking those coordinate arguments and returning how many
#      values a pattern matrix (one without data) needs, or ``None``.
OPTIONS: Iterable[
tuple[
SupportedMatStor,
type[sparse.spmatrix],
Callable[[RSlots], CoordSpec],
Callable[[CoordSpec], int] | None,
]
] = [
# 'C' and 'R': an index slot ('i' or 'j') plus the pointer slot 'p';
# the count is the length of the index slot.
('C', sparse.csc_matrix, lambda slots: (slots['i'], slots['p']), lambda c: len(c[0])),
('R', sparse.csr_matrix, lambda slots: (slots['j'], slots['p']), lambda c: len(c[0])),
# 'T': both index slots wrapped in one tuple, hence the extra level of indexing.
('T', sparse.coo_matrix, lambda slots: ((slots['i'], slots['j']),), lambda c: len(c[0][0])),
# 'di': slots are ignored and a single offset 0 is used (presumably the main diagonal).
# NOTE(review): no count function is given, so a pattern-type diagonal matrix
# would fail when the count is requested — presumably none is supported; confirm.
('di', sparse.dia_matrix, lambda _: ([0],), None),
]


@converter.rpy2py.register(SexpS4)
Expand All @@ -27,22 +49,18 @@ def rmat_to_spmat(rmat: SexpS4) -> sparse.spmatrix:
# https://github.com/theislab/anndata2ri/issues/111
warn(f'Encountered Matrix class that is not supported: {r_classes}', stacklevel=2)
return rmat
for storage, mat_cls, idx, nnz in [
('C', sparse.csc_matrix, lambda: [slots['i'], slots['p']], lambda c: len(c[0])),
('R', sparse.csr_matrix, lambda: [slots['j'], slots['p']], lambda c: len(c[0])),
('T', sparse.coo_matrix, lambda: [(slots['i'], slots['j'])], lambda c: len(c[0][0])),
('di', sparse.dia_matrix, lambda: [[0]], None),
]:
for storage, mat_cls, idx, nnz in OPTIONS:
if not supported_r_matrix_classes(storage=storage) & r_classes:
continue
coord_spec = idx()
coord_spec = idx(slots)
data = (
np.repeat(a=True, repeats=nnz(coord_spec))
# we have pattern matrix without data (but always i and j!)
if supported_r_matrix_classes(types='n') & r_classes
else slots['x']
)
return mat_cls((data, *coord_spec), shape=shape)
dtype = np.bool_ if supported_r_matrix_classes(types=('n', 'l')) & r_classes else np.floating
return mat_cls((data, *coord_spec), shape=shape, dtype=dtype)

msg = 'Should have hit one of the branches'
raise AssertionError(msg)
12 changes: 7 additions & 5 deletions src/anndata2ri/scipy2ri/_support.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal, get_args


if TYPE_CHECKING:
from collections.abc import Iterable

SupportedMatType = Literal['d', 'l', 'n']
SupportedMatStor = Literal['C', 'R', 'T', 'di']

# these are documented in __init__.py because of sphinx limitations
supported_r_matrix_types = frozenset({'d', 'l', 'n'})
supported_r_matrix_storage = frozenset({'C', 'R', 'T', 'di'})
supported_r_matrix_types: frozenset[SupportedMatType] = frozenset(get_args(SupportedMatType))
supported_r_matrix_storage: frozenset[SupportedMatStor] = frozenset(get_args(SupportedMatStor))


@lru_cache(maxsize=None)
def supported_r_matrix_classes(
types: Iterable[str] | str = supported_r_matrix_types,
storage: Iterable[str] | str = supported_r_matrix_storage,
types: Iterable[SupportedMatType] | SupportedMatType = supported_r_matrix_types,
storage: Iterable[SupportedMatStor] | SupportedMatStor = supported_r_matrix_storage,
) -> frozenset[str]:
"""Get supported classes, possibly limiting data types or storage types.
Expand Down
5 changes: 1 addition & 4 deletions tests/test_scipy_rpy2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ def test_py2rpy(
sm = r2py(scipy2ri, dataset)
assert isinstance(sm, cls)
assert sm.shape == shape
# TODO(flying-sheep): check dtype
# https://github.com/theislab/anndata2ri/issues/113
if dtype != np.bool_:
assert sm.dtype == dtype
assert sm.dtype == dtype
assert np.allclose(sm.toarray(), np.array(arr))

dm = numpy2ri.converter.rpy2py(baseenv['as.matrix'](dataset()))
Expand Down

0 comments on commit 83be262

Please sign in to comment.