diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 9b6911d..4a1e9f1 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -1,10 +1,69 @@ -name: python-tiledbsoma-ml +name: python-tiledbsoma-ml CI on: + pull_request: + branches: ["**"] + paths-ignore: ['scripts/**'] + push: + branches: [main] + paths-ignore: ['scripts/**'] workflow_dispatch: jobs: - job: + lint: runs-on: ubuntu-latest steps: - # Empty job; placeholder GHA + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Restore pre-commit cache + uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} + + - name: Install pre-commit + run: pip -v install pre-commit + + - name: Run pre-commit hooks on all files + run: pre-commit run -v -a + + tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install prereqs + run: | + pip install --upgrade pip wheel pytest pytest-cov setuptools + pip install . + + - name: Run tests + run: pytest -v --cov=src --cov-report=xml tests + + build: + # for now, just do a test build to ensure that it works + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Do build + run: | + pip install --upgrade build pip wheel setuptools setuptools-scm + python -m build . diff --git a/.github/workflows/python-tilledbsoma-ml-compat.yml b/.github/workflows/python-tilledbsoma-ml-compat.yml index a83ed31..742cc05 100644 --- a/.github/workflows/python-tilledbsoma-ml-compat.yml +++ b/.github/workflows/python-tilledbsoma-ml-compat.yml @@ -1,9 +1,47 @@ name: python-tiledbsoma-ml past tiledbsoma compat # Latest tiledbsoma version covered by another workflow + on: + pull_request: + branches: ["**"] + paths-ignore: + - "scripts/**" + - "notebooks/**" + push: + branches: [main] + paths-ignore: + - "scripts/**" + - "notebooks/**" workflow_dispatch: jobs: - job: - runs-on: ubuntu-latest + unit_tests: + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] # could add 'macos-latest', but the matrix is already huge... + python-version: ["3.9", "3.10", "3.11"] # TODO: add 3.12 when tiledbsoma releases wheels for it. + pkg-version: + - "tiledbsoma~=1.9.0 'numpy<2.0.0'" + - "tiledbsoma~=1.10.0 'numpy<2.0.0'" + - "tiledbsoma~=1.11.0" + - "tiledbsoma~=1.12.0" + - "tiledbsoma~=1.13.0" + - "tiledbsoma~=1.14.0" + + runs-on: ${{ matrix.os }} + steps: - # Empty job; placeholder GHA + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install prereqs + run: | + pip install --upgrade pip pytest setuptools + pip install ${{ matrix.pkg-version }} . + + - name: Run tests + run: pytest -v tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0f39b5..6385ca9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,6 @@ repos: rev: "24.8.0" hooks: - id: black - exclude: 'apis/' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.6.5 @@ -11,7 +10,6 @@ repos: - id: ruff name: "ruff for tiledbsoma_ml" args: ["--config=pyproject.toml"] - exclude: 'apis/' - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.11.2 @@ -23,4 +21,3 @@ repos: - attrs - numpy - pandas-stubs>=2 - exclude: 'apis/' diff --git a/pyproject.toml b/pyproject.toml index 7901253..416e46f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ show_error_codes = true ignore_missing_imports = true warn_unreachable = true strict = true -python_version = 3.9 +python_version = '3.11' plugins = "numpy.typing.mypy_plugin" [tool.ruff] @@ -72,5 +72,5 @@ lint.select = ["E", "F", "B", "I"] lint.ignore = ["E501"] # line too long lint.extend-select = ["I001"] # unsorted-imports fix = true -target-version = "py39" +target-version = "py311" line-length = 120 diff --git a/src/tiledbsoma_ml/__init__.py b/src/tiledbsoma_ml/__init__.py index 7adb437..49e850e 100644 --- a/src/tiledbsoma_ml/__init__.py +++ b/src/tiledbsoma_ml/__init__.py @@ -5,4 +5,14 @@ """An API to support machine learning applications built on SOMA.""" +from .pytorch import ( + ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, +) + __version__ = "0.1.0-dev" + +__all__ = [ + "ExperimentAxisQueryIterDataPipe", + "ExperimentAxisQueryIterableDataset", +] diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py new file mode 100644 index 0000000..6f629b0 --- /dev/null +++ b/src/tiledbsoma_ml/pytorch.py @@ -0,0 +1,786 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + +import contextlib +import gc +import itertools +import logging +import os +import sys +import time +from contextlib import contextmanager +from itertools import islice +from math import ceil +from typing import ( + Any, + ContextManager, + Dict, + Generator, + Iterable, + Iterator, + Sequence, + Tuple, + TypeVar, + Union, +) + +import attrs +import numpy as np +import numpy.typing as npt +import pandas as pd +import scipy.sparse as sparse +import tiledbsoma as soma +import torch +import torchdata +from somacore.query._eager_iter import EagerIterator as _EagerIterator + +logger = logging.getLogger("tiledbsoma_ml.pytorch") + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +NDArrayNumber = npt.NDArray[np.number[Any]] +XDatum = Union[NDArrayNumber, sparse.csr_matrix] +XObsDatum = Tuple[XDatum, pd.DataFrame] +"""Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, +which pairs a slice of ``X`` rows with a corresponding slice of ``obs``. In the default case, +the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs`` +respectively). If the object is created with ``return_sparse_X`` as True, the ``X`` slice is +returned as a :class:`scipy.sparse.csr_matrix`. If the ``batch_size`` is 1, the :class:`numpy.ndarray` +will be returned with rank 1; in all other cases, objects are returned with rank 2.""" + + +@attrs.define(frozen=True, kw_only=True) +class _ExperimentLocator: + """State required to open the Experiment. + + Serializable across multiple processes. + + Private implementation class. + """ + + uri: str + tiledb_timestamp_ms: int + tiledb_config: Dict[str, Union[str, float]] + + @classmethod + def create(cls, experiment: soma.Experiment) -> "_ExperimentLocator": + return _ExperimentLocator( + uri=experiment.uri, + tiledb_timestamp_ms=experiment.tiledb_timestamp_ms, + tiledb_config=experiment.context.tiledb_config, + ) + + @contextmanager + def open_experiment(self) -> Generator[soma.Experiment, None, None]: + context = soma.SOMATileDBContext(tiledb_config=self.tiledb_config) + yield soma.Experiment.open( + self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context + ) + + +class ExperimentAxisQueryIterable(Iterable[XObsDatum]): + """An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as + selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator + produces a batch containing equal-sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and + :class:`pandas.DataFrame`, respectively. + + Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and + :class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset` + and :class:`ExperimentAxisQueryIterDataPipe` for more details on usage. + + Lifecycle: + experimental + """ + + def __init__( + self, + query: soma.ExperimentAxisQuery, + X_name: str, + obs_column_names: Sequence[str] = ("soma_joinid",), + batch_size: int = 1, + io_batch_size: int = 2**16, + return_sparse_X: bool = False, + use_eager_fetch: bool = True, + ): + """ + Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. + + The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas + :class:`pandas.DataFrame`, respectively. + + Args: + query: + A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data to iterate over. + X_name: + The name of the X layer to read. + obs_column_names: + The names of the ``obs`` columns to return. At least one column name must be specified. + Default is ``('soma_joinid',)``. + batch_size: + The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of + ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will + result in :class:`torch.Tensor`s of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows + this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` batching, but higher + performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s + ``batch_size`` parameter to ``None``. + io_batch_size: + The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts + maximum memory utilization, larger values provide better read performance, but require more memory. + return_sparse_X: + If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the + default), will return ``X`` data as a :class:`numpy.ndarray`. + use_eager_fetch: + Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is + made available for processing via the iterator. This allows network (or filesystem) requests to be made + in parallel with client-side processing of the SOMA data, potentially improving overall performance at + the cost of doubling memory utilization. Defaults to ``True``. + + Raises: + ``ValueError`` on various unsupported or malformed parameter values. + + Lifecycle: + experimental + + """ + + super().__init__() + + # Anything set in the instance needs to be pickle-able for multi-process DataLoaders + self.experiment_locator = _ExperimentLocator.create(query.experiment) + self.layer_name = X_name + self.measurement_name = query.measurement_name + self.obs_query = query._matrix_axis_query.obs + self.var_query = query._matrix_axis_query.var + self.obs_column_names = list(obs_column_names) + self.batch_size = batch_size + self.io_batch_size = io_batch_size + self.return_sparse_X = return_sparse_X + self.use_eager_fetch = use_eager_fetch + self._obs_joinids: npt.NDArray[np.int64] | None = None + self._var_joinids: npt.NDArray[np.int64] | None = None + self._initialized = False + + if not self.obs_column_names: + raise ValueError("Must specify at least one value in `obs_column_names`") + + def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: + """Create iterator over obs id chunks with split size of (roughly) io_batch_size. + + As appropriate, will partition per worker. + + IMPORTANT: in any scenario using torch.distributed, where WORLD_SIZE > 1, this will + always partition such that each process has the same number of samples. Where + the number of obs_joinids is not evenly divisible by the number of processes, + the number of joinids will be dropped (dropped ids can never exceed WORLD_SIZE-1). + + Abstractly, the steps taken: + 1. Split the joinids into WORLD_SIZE sections (aka number of GPUS in DDP) + 2. Trim the splits to be of equal length + 3. Partition by number of data loader workers (to not generate redundant batches + in cases where the DataLoader is running with `n_workers>1`). + + Private method. + """ + assert self._obs_joinids is not None + obs_joinids: npt.NDArray[np.int64] = self._obs_joinids + + # 1. Get the split for the model replica/GPU + world_size, rank = _get_distributed_world_rank() + _gpu_splits = _splits(len(obs_joinids), world_size) + _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]] + + # 2. Trim to be all of equal length - equivalent to a "drop_last" + # TODO: may need to add an option to do padding as well. + min_len = np.diff(_gpu_splits).min() + assert 0 <= (np.diff(_gpu_splits).min() - min_len) <= 1 + _gpu_split = _gpu_split[:min_len] + + obs_joinids_chunked = np.array_split( + _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size)) + ) + + # 3. Partition by DataLoader worker + n_workers, worker_id = _get_worker_world_rank() + obs_splits = _splits(len(obs_joinids_chunked), n_workers) + obs_partition_joinids = obs_joinids_chunked[ + obs_splits[worker_id] : obs_splits[worker_id + 1] + ].copy() + + if logger.isEnabledFor(logging.DEBUG): + partition_size = sum([len(chunk) for chunk in obs_partition_joinids]) + logger.debug( + f"Process {os.getpid()} {rank=}, {world_size=}, {worker_id=}, n_workers={n_workers}, {partition_size=}" + ) + + return iter(obs_partition_joinids) + + def _init_once(self, exp: soma.Experiment | None = None) -> None: + """One-time per worker initialization. + + All operations should be idempotent in order to support pipe reset(). + + Private method. + """ + if self._initialized: + return + + logger.debug("Initializing ExperimentAxisQueryIterable") + + if exp is None: + # If no user-provided Experiment, open/close it ourselves + exp_cm: ContextManager[soma.Experiment] = ( + self.experiment_locator.open_experiment() + ) + else: + # else, it is caller responsibility to open/close the experiment + exp_cm = contextlib.nullcontext(exp) + + with exp_cm as exp: + with exp.axis_query( + measurement_name=self.measurement_name, + obs_query=self.obs_query, + var_query=self.var_query, + ) as query: + self._obs_joinids = query.obs_joinids().to_numpy() + self._var_joinids = query.var_joinids().to_numpy() + + self._initialized = True + + def __iter__(self) -> Iterator[XObsDatum]: + """Create iterator over query. + + Returns: + ``iterator`` + + Lifecycle: + experimental + """ + + if ( + self.return_sparse_X + and torch.utils.data.get_worker_info() + and torch.utils.data.get_worker_info().num_workers > 0 + ): + raise NotImplementedError( + "torch does not work with sparse tensors in multi-processing mode " + "(see https://github.com/pytorch/pytorch/issues/20248)" + ) + + world_size, rank = _get_distributed_world_rank() + n_workers, worker_id = _get_worker_world_rank() + logger.debug( + f"Iterator created {rank=}, {world_size=}, {worker_id=}, {n_workers=}" + ) + + with self.experiment_locator.open_experiment() as exp: + self._init_once(exp) + X = exp.ms[self.measurement_name].X[self.layer_name] + if not isinstance(X, soma.SparseNDArray): + raise NotImplementedError( + "ExperimentAxisQueryIterable only supports X layers which are of type SparseNDArray" + ) + + obs_joinid_iter = self._create_obs_joinids_partition() + _mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter) + if self.use_eager_fetch: + _mini_batch_iter = _EagerIterator( + _mini_batch_iter, pool=exp.context.threadpool + ) + + yield from _mini_batch_iter + + def __len__(self) -> int: + """Return the number of batches this iterable will produce. If run in the context of :class:`torch.distributed` + or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the + batch count will reflect the size of the data partition assigned to the active process. + + See important caveats in the PyTorch + [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) + documentation regarding ``len(dataloader)``, which also apply to this class. + + Returns: + ``int`` (Number of batches). + + Lifecycle: + experimental + """ + return self.shape[0] + + @property + def shape(self) -> Tuple[int, int]: + """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. + + If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), + the number of batches will reflect the size of the data partition assigned to the active process. + + Returns: + A tuple of two ``int`` values: number of batches, number of vars. + + Lifecycle: + experimental + """ + self._init_once() + assert self._obs_joinids is not None + assert self._var_joinids is not None + world_size, rank = _get_distributed_world_rank() + n_workers, worker_id = _get_worker_world_rank() + # Every "distributed" process must receive the same number of "obs" rows; the last ≤world_size may be dropped + # (see _create_obs_joinids_partition). + obs_per_proc = len(self._obs_joinids) // world_size + obs_per_worker, obs_rem = divmod(obs_per_proc, n_workers) + # obs rows assigned to this worker process + n_worker_obs = obs_per_worker + bool(worker_id < obs_rem) + n_batches, rem = divmod(n_worker_obs, self.batch_size) + # (num batches this worker will produce, num features) + return n_batches + bool(rem), len(self._var_joinids) + + def __getitem__(self, index: int) -> XObsDatum: + raise NotImplementedError( + "``ExperimentAxisQueryIterable can only be iterated - does not support mapping" + ) + + def _io_batch_iter( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_joinid_iter: Iterator[npt.NDArray[np.int64]], + ) -> Iterator[Tuple[sparse.csr_matrix, pd.DataFrame]]: + """Iterate over IO batches, i.e., SOMA query reads, producing tuples of ``(X: csr_array, obs: DataFrame)``. + + ``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed. + + Private method. + """ + assert self._var_joinids is not None + + obs_column_names = ( + list(self.obs_column_names) + if "soma_joinid" in self.obs_column_names + else ["soma_joinid", *self.obs_column_names] + ) + var_indexer = soma.IntIndexer(self._var_joinids, context=X.context) + + for obs_coords in obs_joinid_iter: + st_time = time.perf_counter() + obs_indexer = soma.IntIndexer(obs_coords, context=X.context) + logger.debug( + f"Retrieving next SOMA IO batch of length {len(obs_coords)}..." + ) + + X_tbl = X.read(coords=(obs_coords, self._var_joinids)).tables().concat() + X_io_batch = sparse.csr_matrix( + ( + X_tbl["soma_data"].to_numpy(), + ( + obs_indexer.get_indexer(X_tbl["soma_dim_0"]), + var_indexer.get_indexer(X_tbl["soma_dim_1"]), + ), + ), + shape=(len(obs_coords), len(self._var_joinids)), + ) + + # Now that X read is potentially in progress (in eager mode), go fetch obs data + # fmt: off + obs_io_batch = ( + obs.read(coords=(obs_coords,), column_names=obs_column_names) + .concat() + .to_pandas() + .set_index("soma_joinid") + .reindex(obs_coords, copy=False) + .reset_index() # demote "soma_joinid" to a column + [self.obs_column_names] + ) # fmt: on + + del obs_indexer, obs_coords, X_tbl + gc.collect() + + tm = time.perf_counter() - st_time + logger.debug( + f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f} samples/sec" + ) + yield X_io_batch, obs_io_batch + + def _mini_batch_iter( + self, + obs: soma.DataFrame, + X: soma.SparseNDArray, + obs_joinid_iter: Iterator[npt.NDArray[np.int64]], + ) -> Iterator[XObsDatum]: + """Break IO batches into mini-batch-sized chunks. + + Private method. + """ + assert self._obs_joinids is not None + assert self._var_joinids is not None + + io_batch_iter = self._io_batch_iter(obs, X, obs_joinid_iter) + if self.use_eager_fetch: + io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool) + + mini_batch_size = self.batch_size + result: Tuple[NDArrayNumber, pd.DataFrame] | None = None + for X_io_batch, obs_io_batch in io_batch_iter: + assert X_io_batch.shape[0] == obs_io_batch.shape[0] + assert X_io_batch.shape[1] == len(self._var_joinids) + iob_idx = 0 # current offset into io batch + iob_len = X_io_batch.shape[0] + + while iob_idx < iob_len: + if result is None: + X_datum = ( + X_io_batch[iob_idx : iob_idx + mini_batch_size] + if self.return_sparse_X + else X_io_batch[iob_idx : iob_idx + mini_batch_size].toarray() + ) + result = ( + X_datum, + obs_io_batch.iloc[ + iob_idx : iob_idx + mini_batch_size + ].reset_index(drop=True), + ) + iob_idx += len(result[1]) + else: + # Use any remnant from previous IO batch + to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx) + X_datum = ( + sparse.vstack([result[0], X_io_batch[0:to_take]]) + if self.return_sparse_X + else np.concatenate( + [result[0], X_io_batch[0:to_take].toarray()] + ) + ) + result = ( + X_datum, + pd.concat( + [result[1], obs_io_batch.iloc[0:to_take]], + # Index `obs_batch` from 0 to N-1, instead of disjoint, concatenated pieces of IO batches' + # indices + ignore_index=True, + ), + ) + iob_idx += to_take + + assert result[0].shape[0] == result[1].shape[0] + if result[0].shape[0] == mini_batch_size: + yield result + result = None + + else: + # yield the remnant, if any + if result is not None: + yield result + + +class ExperimentAxisQueryIterDataPipe( + torchdata.datapipes.iter.IterDataPipe[ # type:ignore[misc] + torch.utils.data.dataset.Dataset[XObsDatum] + ], +): + """A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. + + This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for + legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the + TorchData [README](https://github.com/pytorch/data/blob/v0.8.0/README.md) for more information. + + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + + def __init__( + self, + query: soma.ExperimentAxisQuery, + X_name: str = "raw", + obs_column_names: Sequence[str] = ("soma_joinid",), + batch_size: int = 1, + io_batch_size: int = 2**16, + return_sparse_X: bool = False, + use_eager_fetch: bool = True, + ): + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + super().__init__() + self._exp_iter = ExperimentAxisQueryIterable( + query=query, + X_name=X_name, + obs_column_names=obs_column_names, + batch_size=batch_size, + io_batch_size=io_batch_size, + return_sparse_X=return_sparse_X, + use_eager_fetch=use_eager_fetch, + ) + + def __iter__(self) -> Iterator[XObsDatum]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + batch_size = self._exp_iter.batch_size + for X, obs in self._exp_iter: + if batch_size == 1: + X = X[0] # This is a no-op for `csr_matrix`s + yield X, obs + + def __len__(self) -> int: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + return len(self._exp_iter) + + @property + def shape(self) -> Tuple[int, int]: + """ + See :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset` for more information on using this class. + + Lifecycle: + deprecated + """ + return self._exp_iter.shape + + +class ExperimentAxisQueryIterableDataset( + torch.utils.data.IterableDataset[XObsDatum] # type:ignore[misc] +): + """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. + + This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as + specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of + ``obs`` and ``X`` data. Each iteration will yield a tuple containing an :class:`numpy.ndarray` + and a :class:`pandas.DataFrame`. + + For example: + + >>> import torch + >>> import tiledbsoma + >>> import tiledbsoma_ml + >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp: + ... with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: + ... ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query) + ... dataloader = torch.utils.data.DataLoader(ds) + >>> data = next(iter(dataloader)) + >>> data + (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), + soma_joinid + 0 57905025) + >>> data[0] + array([0., 0., 0., ..., 0., 0., 0.], dtype=float32) + >>> data[1] + soma_joinid + 0 57905025 + + The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each + iteration. If the ``batch_size`` is 1, then each result will have rank 1, else it will have rank 2. A ``batch_size`` + of 1 is compatible with :class:`torch.utils.data.DataLoader`-implemented batching, but it will usually be more + performant to create mini-batches using this class, and set the ``DataLoader`` batch size to `None`. + + The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the + default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension). + + The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A + larger value will increase total memory usage and may reduce average read time per row. + + This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader` + and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition + data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any + data tail will be dropped. + + Lifecycle: + experimental + """ + + def __init__( + self, + query: soma.ExperimentAxisQuery, + X_name: str = "raw", + obs_column_names: Sequence[str] = ("soma_joinid",), + batch_size: int = 1, + io_batch_size: int = 2**16, + return_sparse_X: bool = False, + use_eager_fetch: bool = True, + ): + """ + Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. + + The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as + a NumPy ``ndarray`` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas ``DataFrame`` respectively. + + Args: + query: + A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over. + X_name: + The name of the ``X`` layer to read. + obs_column_names: + The names of the ``obs`` columns to return. At least one column name must be specified. + Default is ``('soma_joinid',)``. + batch_size: + The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of + ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will + result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). + + Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` + batching, but you will achieve higher performance by performing batching in this class, and setting the ``DataLoader`` + batch_size parameter to ``None``. + io_batch_size: + The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. + return_sparse_X: + If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will + return ``X`` data as a :class:`numpy.ndarray`. + use_eager_fetch: + Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made + available for processing via the iterator. This allows network (or filesystem) requests to be made in + parallel with client-side processing of the SOMA data, potentially improving overall performance at the + cost of doubling memory utilization. Defaults to ``True``. + + Raises: + ``ValueError`` on various unsupported or malformed parameter values. + + Lifecycle: + experimental + + """ + super().__init__() + self._exp_iter = ExperimentAxisQueryIterable( + query=query, + X_name=X_name, + obs_column_names=obs_column_names, + batch_size=batch_size, + io_batch_size=io_batch_size, + return_sparse_X=return_sparse_X, + use_eager_fetch=use_eager_fetch, + ) + + def __iter__(self) -> Iterator[XObsDatum]: + """Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.csr_matrix`) and + :class:`pandas.DataFrame`. + + Returns: + ``iterator`` + + Lifecycle: + experimental + """ + batch_size = self._exp_iter.batch_size + for X, obs in self._exp_iter: + if batch_size == 1: + X = X[0] # This is a no-op for `csr_matrix`s + yield X, obs + + def __len__(self) -> int: + """Return number of batches this iterable will produce. + + See important caveats in the PyTorch + [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) + documentation regarding ``len(dataloader)``, which also apply to this class. + + Returns: + ``int`` (number of batches). + + Lifecycle: + experimental + """ + return len(self._exp_iter) + + @property + def shape(self) -> Tuple[int, int]: + """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. + + If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), + the number of batches will reflect the size of the data partition assigned to the active process. + + Returns: + A tuple of two ``int`` values: number of batches, number of vars. + + Lifecycle: + experimental + """ + return self._exp_iter.shape + + +def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: + """For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes. + + A total_length of L, split into N sections, will return L%N sections of size L//N+1, + and the remainder as size L//N. This results in the same split as numpy.array_split, + for an array of length L and sections N. + + Private. + + Examples + -------- + >>> _splits(10, 3) + array([0, 4, 7, 10]) + >>> _splits(4, 2) + array([0, 2, 4]) + """ + if sections <= 0: + raise ValueError("number of sections must greater than 0.") from None + each_section, extras = divmod(total_length, sections) + per_section_sizes = ( + [0] + extras * [each_section + 1] + (sections - extras) * [each_section] + ) + splits = np.array(per_section_sizes, dtype=np.intp).cumsum() + return splits + + +if sys.version_info >= (3, 12): + _batched = itertools.batched +else: + + def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: + """Same as the Python 3.12+ ``itertools.batched`` -- polyfill for old Python versions.""" + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + while batch := tuple(islice(it, n)): + yield batch + + +def _get_distributed_world_rank() -> Tuple[int, int]: + """Return tuple containing equivalent of ``torch.distributed`` world size and rank.""" + world_size, rank = 1, 0 + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ: + # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There + # is a NODE_RANK for the node's rank, but no way to tell the local node's + # world. So computing a global rank is impossible(?). Using LOCAL_RANK as a + # proxy, which works fine on a single-CPU box. TODO: could throw/error + # if NODE_RANK != 0. + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["LOCAL_RANK"]) + elif torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + return world_size, rank + + +def _get_worker_world_rank() -> Tuple[int, int]: + """Return number of DataLoader workers and our worker rank/id""" + num_workers, worker = 1, 0 + if "WORKER" in os.environ and "NUM_WORKERS" in os.environ: + num_workers = int(os.environ["NUM_WORKERS"]) + worker = int(os.environ["WORKER"]) + else: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + num_workers = worker_info.num_workers + worker = worker_info.id + return num_workers, worker diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py new file mode 100644 index 0000000..8cbae43 --- /dev/null +++ b/tests/test_pytorch.py @@ -0,0 +1,580 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. + +from __future__ import annotations + +from functools import partial +from pathlib import Path +from typing import Callable, Optional, Sequence, Union +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +import tiledbsoma as soma +from pandas._testing import assert_frame_equal +from scipy import sparse +from scipy.sparse import coo_matrix, spmatrix +from tiledbsoma import Experiment, _factory +from tiledbsoma._collection import CollectionBase +from torch.utils.data._utils.worker import WorkerInfo + +from tiledbsoma_ml.pytorch import ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, +) + +assert_array_equal = partial(np.testing.assert_array_equal, strict=True) + +# These control which classes are tested (for most, but not all tests). +# Centralized to allow easy add/delete of specific test parameters. +PipeClassType = Union[ + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, +] +PipeClasses = ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterDataPipe, + ExperimentAxisQueryIterableDataset, +) +XValueGen = Callable[[range, range], spmatrix] + + +def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix: + occupied_shape = ( + obs_range.stop - obs_range.start, + var_range.stop - var_range.start, + ) + checkerboard_of_ones = coo_matrix(np.indices(occupied_shape).sum(axis=0) % 2) + checkerboard_of_ones.row += obs_range.start + checkerboard_of_ones.col += var_range.start + return checkerboard_of_ones + + +def pytorch_seq_x_value_gen(obs_range: range, var_range: range) -> spmatrix: + """A sparse matrix where the values of each col are the obs_range values. Useful for checking the + X values are being returned in the correct order.""" + data = np.vstack([list(obs_range)] * len(var_range)).flatten() + rows = np.vstack([list(obs_range)] * len(var_range)).flatten() + cols = np.column_stack([list(var_range)] * len(obs_range)).flatten() + return coo_matrix((data, (rows, cols))) + + +@pytest.fixture +def X_layer_names() -> list[str]: + return ["raw"] + + +@pytest.fixture +def obsp_layer_names() -> Optional[list[str]]: + return None + + +@pytest.fixture +def varp_layer_names() -> Optional[list[str]]: + return None + + +def add_dataframe(coll: CollectionBase, key: str, value_range: range) -> None: + df = coll.add_new_dataframe( + key, + schema=pa.schema( + [ + ("soma_joinid", pa.int64()), + ("label", pa.large_string()), + ("label2", pa.large_string()), + ] + ), + index_column_names=["soma_joinid"], + ) + df.write( + pa.Table.from_pydict( + { + "soma_joinid": list(value_range), + "label": [str(i) for i in value_range], + "label2": ["c" for i in value_range], + } + ) + ) + + +def add_sparse_array( + coll: CollectionBase, + key: str, + obs_range: range, + var_range: range, + value_gen: XValueGen, +) -> None: + a = coll.add_new_sparse_ndarray( + key, type=pa.float32(), shape=(obs_range.stop, var_range.stop) + ) + tensor = pa.SparseCOOTensor.from_scipy(value_gen(obs_range, var_range)) + a.write(tensor) + + +@pytest.fixture(scope="function") +def soma_experiment( + tmp_path: Path, + obs_range: Union[int, range], + var_range: Union[int, range], + X_value_gen: XValueGen, + obsp_layer_names: Sequence[str], + varp_layer_names: Sequence[str], +) -> soma.Experiment: + with soma.Experiment.create((tmp_path / "exp").as_posix()) as exp: + if isinstance(obs_range, int): + obs_range = range(obs_range) + if isinstance(var_range, int): + var_range = range(var_range) + + add_dataframe(exp, "obs", obs_range) + ms = exp.add_new_collection("ms") + rna = ms.add_new_collection("RNA", soma.Measurement) + add_dataframe(rna, "var", var_range) + rna_x = rna.add_new_collection("X", soma.Collection) + add_sparse_array(rna_x, "raw", obs_range, var_range, X_value_gen) + + if obsp_layer_names: + obsp = rna.add_new_collection("obsp") + for obsp_layer_name in obsp_layer_names: + add_sparse_array( + obsp, obsp_layer_name, obs_range, var_range, X_value_gen + ) + + if varp_layer_names: + varp = rna.add_new_collection("varp") + for varp_layer_name in varp_layer_names: + add_sparse_array( + varp, varp_layer_name, obs_range, var_range, X_value_gen + ) + return _factory.open((tmp_path / "exp").as_posix()) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_non_batched( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + """Check batches of size 1 (the default)""" + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + assert exp_data_pipe.shape == (6, 3) + batch_iter = iter(exp_data_pipe) + for idx, (X_batch, obs_batch) in enumerate(batch_iter): + expected_X = [0, 1, 0] if idx % 2 == 0 else [1, 0, 1] + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + # Sparse slices are always 2D + assert X_batch.shape == (1, 3) + assert X_batch.todense().tolist() == [expected_X] + else: + assert isinstance(X_batch, np.ndarray) + if PipeClass is ExperimentAxisQueryIterable: + assert X_batch.shape == (1, 3) + assert X_batch.tolist() == [expected_X] + else: + # ExperimentAxisQueryIterData{Pipe,set} "squeeze" dense single-row batches + assert X_batch.shape == (3,) + assert X_batch.tolist() == expected_X + + assert_frame_equal(obs_batch, pd.DataFrame({"label": [str(idx)]})) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_uneven_soma_and_result_batches( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + """Check that batches are correctly created when they require fetching multiple chunks.""" + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + io_batch_size=2, + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + assert exp_data_pipe.shape == (2, 3) + batch_iter = iter(exp_data_pipe) + + X_batch, obs_batch = next(batch_iter) + assert X_batch.shape == (3, 3) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + else: + assert isinstance(X_batch, np.ndarray) + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) + + X_batch, obs_batch = next(batch_iter) + assert X_batch.shape == (3, 3) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + else: + assert isinstance(X_batch, np.ndarray) + assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4", "5"]})) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_batching__all_batches_full_size( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + batch_iter = iter(exp_data_pipe) + assert exp_data_pipe.shape == (2, 3) + + X_batch, obs_batch = next(batch_iter) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) + + X_batch, obs_batch = next(batch_iter) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4", "5"]})) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(range(100_000_000, 100_000_003), 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_soma_joinids( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid", "label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + assert exp_data_pipe.shape == (1, 3) + + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in exp_data_pipe] + ) + assert_array_equal(soma_joinids, np.arange(100_000_000, 100_000_003)) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(5, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_batching__partial_final_batch_size( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + assert exp_data_pipe.shape == (2, 3) + batch_iter = iter(exp_data_pipe) + + next(batch_iter) + X_batch, obs_batch = next(batch_iter) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4"]})) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(3, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_batching__exactly_one_batch( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + assert exp_data_pipe.shape == (1, 3) + batch_iter = iter(exp_data_pipe) + X_batch, obs_batch = next(batch_iter) + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_batching__empty_query_result( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query( + measurement_name="RNA", obs_query=soma.AxisQuery(coords=([],)) + ) as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + assert exp_data_pipe.shape == (0, 3) + batch_iter = iter(exp_data_pipe) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(10, 1, pytorch_x_value_gen)], +) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_batching__partial_soma_batches_are_concatenated( + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + # set SOMA batch read size such that PyTorch batches will span the tail and head of two SOMA batches + io_batch_size=4, + use_eager_fetch=use_eager_fetch, + ) + + batches = list(exp_data_pipe) + + assert [len(batch[0]) for batch in batches] == [3, 3, 3, 1] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize( + "world_size,rank", + [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], +) +@pytest.mark.parametrize("PipeClass", PipeClasses) +def test_distributed__returns_data_partition_for_rank( + PipeClass: PipeClassType, + soma_experiment: Experiment, + obs_range: int, + world_size: int, + rank: int, +) -> None: + """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode, + using mocks to avoid having to do real PyTorch distributed setup.""" + + with ( + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + ) + batches = list(iter(dp)) + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in batches] + ) + + expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ + 0 : obs_range // world_size + ].tolist() + assert sorted(soma_joinids) == expected_joinids + + +# fmt: off +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,world_size,num_workers,splits", + [ + (12, 3, pytorch_x_value_gen, 3, 2, [[0, 2, 4], [4, 6, 8], [ 8, 10, 12]]), + (13, 3, pytorch_x_value_gen, 3, 2, [[0, 2, 4], [5, 7, 9], [ 9, 11, 13]]), + (15, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 5], [5, 9, 10], [10, 14, 15]]), + (16, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 5], [6, 10, 11], [11, 15, 16]]), + (18, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [6, 10, 12], [12, 16, 18]]), + (19, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [7, 11, 13], [13, 17, 19]]), + (20, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [7, 11, 13], [14, 18, 20]]), + (21, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 7], [7, 11, 14], [14, 18, 21]]), + (25, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 8], [9, 13, 17], [17, 21, 25]]), + (27, 3, pytorch_x_value_gen, 3, 2, [[0, 6, 9], [9, 15, 18], [18, 24, 27]]), + ], +) +# fmt: on +def test_distributed_and_multiprocessing__returns_data_partition_for_rank( + soma_experiment: Experiment, + world_size: int, + num_workers: int, + splits: list[list[int]], +) -> None: + """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode and + DataLoader multiprocessing mode, using mocks to avoid having to do distributed pytorch + setup or real DataLoader multiprocessing.""" + + for rank in range(world_size): + proc_splits = splits[rank] + for worker_id in range(num_workers): + expected_joinids = list( + range(proc_splits[worker_id], proc_splits[worker_id + 1]) + ) + with ( + patch("torch.utils.data.get_worker_info") as mock_get_worker_info, + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): + mock_get_worker_info.return_value = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=1234 + ) + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = ExperimentAxisQueryIterable( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + ) + + batches = list(iter(dp)) + + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in batches] + ).tolist() + + assert soma_joinids == expected_joinids + + +def test_batched() -> None: + from tiledbsoma_ml.pytorch import _batched + + assert list(_batched(range(6), 1)) == list((i,) for i in range(6)) + assert list(_batched(range(6), 2)) == [(0, 1), (2, 3), (4, 5)] + assert list(_batched(range(6), 3)) == [(0, 1, 2), (3, 4, 5)] + assert list(_batched(range(6), 4)) == [(0, 1, 2, 3), (4, 5)] + assert list(_batched(range(6), 5)) == [(0, 1, 2, 3, 4), (5,)] + assert list(_batched(range(6), 6)) == [(0, 1, 2, 3, 4, 5)] + assert list(_batched(range(6), 7)) == [(0, 1, 2, 3, 4, 5)] + + # bogus batch value + with pytest.raises(ValueError): + list(_batched([0, 1], 0)) + with pytest.raises(ValueError): + list(_batched([2, 3], -1)) + + +def test_splits() -> None: + from tiledbsoma_ml.pytorch import _splits + + assert _splits(10, 1).tolist() == [0, 10] + assert _splits(10, 2).tolist() == [0, 5, 10] + assert _splits(10, 3).tolist() == [0, 4, 7, 10] + assert _splits(10, 4).tolist() == [0, 3, 6, 8, 10] + assert _splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert _splits(10, 11).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10] + + # bad number of sections + with pytest.raises(ValueError): + _splits(10, 0) + with pytest.raises(ValueError): + _splits(10, -1)