From ffb7ff8501162209f834ad625591ff100d87c1ac Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Mon, 23 Sep 2024 13:04:07 -0700 Subject: [PATCH 01/25] initial cut at a torch.utils.data.IterDataset --- src/tiledbsoma_ml/__init__.py | 10 + src/tiledbsoma_ml/pytorch.py | 797 ++++++++++++++++++++++++++++++++++ tests/test_pytorch.py | 626 ++++++++++++++++++++++++++ 3 files changed, 1433 insertions(+) create mode 100644 src/tiledbsoma_ml/pytorch.py create mode 100644 tests/test_pytorch.py diff --git a/src/tiledbsoma_ml/__init__.py b/src/tiledbsoma_ml/__init__.py index 7adb437..49e850e 100644 --- a/src/tiledbsoma_ml/__init__.py +++ b/src/tiledbsoma_ml/__init__.py @@ -5,4 +5,14 @@ """An API to support machine learning applications built on SOMA.""" +from .pytorch import ( + ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, +) + __version__ = "0.1.0-dev" + +__all__ = [ + "ExperimentAxisQueryIterDataPipe", + "ExperimentAxisQueryIterableDataset", +] diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py new file mode 100644 index 0000000..32d90b3 --- /dev/null +++ b/src/tiledbsoma_ml/pytorch.py @@ -0,0 +1,797 @@ +# Copyright (c) 2021-2024 The Chan Zuckerberg Initiative Foundation +# Copyright (c) 2021-2024 TileDB, Inc. +# +# Licensed under the MIT License. 
+ +from __future__ import annotations + +import contextlib +import gc +import itertools +import logging +import os +import sys +import time +from contextlib import contextmanager +from itertools import islice +from math import ceil +from typing import ( + TYPE_CHECKING, + Any, + ContextManager, + Dict, + Iterable, + Iterator, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) + +import attrs +import numpy as np +import numpy.typing as npt +import pandas as pd +import scipy.sparse as sparse +import tiledbsoma as soma +import torch +import torchdata +from somacore.query._eager_iter import EagerIterator as _EagerIterator +from typing_extensions import TypeAlias + +logger = logging.getLogger("tiledbsoma_ml.pytorch") + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +if TYPE_CHECKING: + # Python 3.8 does not support subscripting types, so work-around by + # restricting this to when we are running a type checker. TODO: remove + # the conditional when Python 3.8 support is dropped. + NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] + XDatum: TypeAlias = Union[NDArrayNumber, sparse.csr_matrix] +else: + NDArrayNumber: TypeAlias = np.ndarray + XDatum: TypeAlias = Union[np.ndarray, sparse.csr_matrix] + +XObsDatum: TypeAlias = Tuple[XDatum, pd.DataFrame] +"""Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, +which pairs a slice of ``X`` rows with a cooresponding slice of ``obs``. In the default case, +the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs`` +respectively). If the object is created with ``return_sparse_X`` as True, the ``X`` slice is +returned as a :class:`scipy.sparse.csr_matrix`. If the ``batch_size`` is 1, the :class:`numpy.ndarray` +will be returned with rank 1; in all other cases, objects are returned with rank 2.""" + + +@attrs.define(frozen=True, kw_only=True) +class _ExperimentLocator: + """State required to open the Experiment. 
class ExperimentAxisQueryIterable(Iterable[XObsDatum]):
    """An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as
    selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator
    produces equal sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` (or, optionally,
    a :class:`scipy.sparse.csr_matrix`) and a :class:`pandas.DataFrame`, respectively.

    Private implementation class used by :class:`ExperimentAxisQueryIterableDataset` and
    :class:`ExperimentAxisQueryIterDataPipe`. Refer to those classes for more details on usage.

    Lifecycle:
        experimental
    """

    def __init__(
        self,
        query: soma.ExperimentAxisQuery,
        X_name: str,
        obs_column_names: Sequence[str] = ("soma_joinid",),
        batch_size: int = 1,
        io_batch_size: int = 2**16,
        return_sparse_X: bool = False,
        use_eager_fetch: bool = True,
    ):
        """
        Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`.

        The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as
        a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a
        Pandas :class:`pandas.DataFrame`, respectively.

        Args:
            query:
                A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data to iterate over.
            X_name:
                The name of the X layer to read.
            obs_column_names:
                The names of the ``obs`` columns to return. At least one column name must be specified.
                Default is ``('soma_joinid',)``.
            batch_size:
                The number of rows of ``X`` and ``obs`` data to return in each iteration. Must be at least 1.
                Defaults to ``1``. This class always yields ``X`` slices of rank 2; the final batch may hold
                fewer than ``batch_size`` rows.
            io_batch_size:
                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. Must be at least 1.
                This impacts maximum memory utilization, larger values provide better read performance, but
                require more memory.
            return_sparse_X:
                If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False``
                (the default), will return ``X`` data as a :class:`numpy.ndarray`.
            use_eager_fetch:
                Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA
                chunk is made available for processing via the iterator. This allows network (or filesystem)
                requests to be made in parallel with client-side processing of the SOMA data, potentially
                improving overall performance at the cost of doubling memory utilization. Defaults to ``True``.

        Raises:
            ValueError: if ``obs_column_names`` is empty, or if ``batch_size`` or ``io_batch_size``
                is less than 1.

        Lifecycle:
            experimental
        """
        super().__init__()

        # Anything set in the instance needs to be picklable for multi-process DataLoaders
        self.experiment_locator = _ExperimentLocator.create(query.experiment)
        self.layer_name = X_name
        self.measurement_name = query.measurement_name
        self.obs_query = query._matrix_axis_query.obs
        self.var_query = query._matrix_axis_query.var
        self.obs_column_names = list(obs_column_names)
        self.batch_size = batch_size
        self.io_batch_size = io_batch_size
        self.return_sparse_X = return_sparse_X
        self.use_eager_fetch = use_eager_fetch
        self._obs_joinids: npt.NDArray[np.int64] | None = None
        self._var_joinids: npt.NDArray[np.int64] | None = None
        self._initialized = False

        if not self.obs_column_names:
            raise ValueError("Must specify at least one value in `obs_column_names`")
        # A non-positive batch size would make the mini-batch loop spin forever and
        # `shape` divide by zero; a non-positive IO batch size divides by zero when
        # chunking the joinids. Reject both up front.
        if batch_size < 1:
            raise ValueError("`batch_size` must be at least 1")
        if io_batch_size < 1:
            raise ValueError("`io_batch_size` must be at least 1")

    def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]:
        """Create iterator over obs id chunks with split size of (roughly) io_batch_size.

        As appropriate, will partition per worker.

        IMPORTANT: in any scenario using torch.distributed, where WORLD_SIZE > 1, this will
        always partition such that each process has the same number of samples. Where
        the number of obs_joinids is not evenly divisible by the number of processes,
        the surplus joinids are dropped (dropped ids can never exceed WORLD_SIZE-1).

        Abstractly, the steps taken:
        1. Split the joinids into WORLD_SIZE sections (aka number of GPUS in DDP)
        2. Trim the splits to be of equal length
        3. Partition by number of data loader workers (to not generate redundant batches
           in cases where the DataLoader is running with `n_workers>1`).

        Private method.
        """
        assert self._obs_joinids is not None
        obs_joinids: npt.NDArray[np.int64] = self._obs_joinids

        # 1. Get the split for the model replica/GPU
        world_size, rank = _get_distributed_world_rank()
        _gpu_splits = _splits(len(obs_joinids), world_size)
        _gpu_split = obs_joinids[_gpu_splits[rank] : _gpu_splits[rank + 1]]

        # 2. Trim to be all of equal length - equivalent to a "drop_last"
        # TODO: may need to add an option to do padding as well.
        split_lengths = np.diff(_gpu_splits)
        min_len = split_lengths.min()
        # Sanity check: the longest and shortest split may differ by at most one row.
        assert 0 <= (split_lengths.max() - min_len) <= 1
        _gpu_split = _gpu_split[:min_len]

        obs_joinids_chunked = np.array_split(
            _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size))
        )

        # 3. Partition by DataLoader worker
        n_workers, worker_id = _get_worker_world_rank()
        obs_splits = _splits(len(obs_joinids_chunked), n_workers)
        obs_partition_joinids = obs_joinids_chunked[
            obs_splits[worker_id] : obs_splits[worker_id + 1]
        ]

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, "
                f"n_workers={n_workers}, "
                f"partition_size={sum(len(chunk) for chunk in obs_partition_joinids)}"
            )

        return iter(obs_partition_joinids)

    def _init_once(self, exp: soma.Experiment | None = None) -> None:
        """One-time per worker initialization.

        Resolves and caches the obs and var joinids selected by the query. All operations
        must be idempotent in order to support pipe reset().

        Args:
            exp:
                An already-open Experiment to use. If ``None``, the Experiment is opened
                (and closed) by this method.

        Private method.
        """
        if self._initialized:
            return

        logger.debug("Initializing ExperimentAxisQueryIterable")

        if exp is None:
            # If no user-provided Experiment, open/close it ourselves
            exp_cm: ContextManager[soma.Experiment] = (
                self.experiment_locator.open_experiment()
            )
        else:
            # else, it is caller responsibility to open/close the experiment
            exp_cm = contextlib.nullcontext(exp)

        with exp_cm as opened_exp:
            with opened_exp.axis_query(
                measurement_name=self.measurement_name,
                obs_query=self.obs_query,
                var_query=self.var_query,
            ) as query:
                self._obs_joinids = query.obs_joinids().to_numpy()
                self._var_joinids = query.var_joinids().to_numpy()

        self._initialized = True

    def __iter__(self) -> Iterator[XObsDatum]:
        """Create iterator over query.

        Returns:
            ``iterator``

        Raises:
            NotImplementedError: if ``return_sparse_X`` is set while running inside a
                DataLoader worker process, or if the X layer is not a SparseNDArray.

        Lifecycle:
            experimental
        """
        if self.return_sparse_X:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None and worker_info.num_workers > 0:
                raise NotImplementedError(
                    "torch does not work with sparse tensors in multi-processing mode "
                    "(see https://github.com/pytorch/pytorch/issues/20248)"
                )

        world_size, rank = _get_distributed_world_rank()
        n_workers, worker_id = _get_worker_world_rank()
        logger.debug(
            f"Iterator created rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}"
        )

        # The Experiment stays open for as long as the generator is being consumed.
        with self.experiment_locator.open_experiment() as exp:
            self._init_once(exp)
            X = exp.ms[self.measurement_name].X[self.layer_name]
            if not isinstance(X, soma.SparseNDArray):
                raise NotImplementedError(
                    "ExperimentAxisQueryIterable only supports X layers which are of type SparseNDArray"
                )

            obs_joinid_iter = self._create_obs_joinids_partition()
            _mini_batch_iter = self._mini_batch_iter(exp.obs, X, obs_joinid_iter)
            if self.use_eager_fetch:
                _mini_batch_iter = _EagerIterator(
                    _mini_batch_iter, pool=exp.context.threadpool
                )

            yield from _mini_batch_iter

    def __len__(self) -> int:
        """Return the approximate number of batches this iterable will produce. If run in the context of
        :class:`torch.distributed` or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader`
        instantiated with num_workers > 0), the count will reflect the size of the data partition
        assigned to the active process.

        See important caveats in the PyTorch
        [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
        documentation regarding ``len(dataloader)``, which also apply to this class.

        Returns:
            An ``int``.

        Lifecycle:
            experimental
        """
        return self.shape[0]

    @property
    def shape(self) -> Tuple[int, int]:
        """Get the approximate shape of the data that will be returned by this
        :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`.

        The first element is the approximate number of batches (the per-process obs count divided by
        ``batch_size``, rounded up); with the default ``batch_size`` of 1 this equals the obs (cell) count.
        The second element is the var (feature) count. If used in distributed or multiprocessing mode
        (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the first element
        will reflect the size of the data partition assigned to the active process.

        Returns:
            A tuple of two ``int`` values: number of batches, number of vars.

        Lifecycle:
            experimental
        """
        self._init_once()
        assert self._obs_joinids is not None
        assert self._var_joinids is not None
        world_size, _ = _get_distributed_world_rank()
        n_workers, _ = _get_worker_world_rank()
        partition_len = len(self._obs_joinids) // world_size // n_workers
        div, rem = divmod(partition_len, self.batch_size)
        return div + bool(rem), len(self._var_joinids)

    def __getitem__(self, index: int) -> XObsDatum:
        """Unsupported: this class is iterable only and always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "``ExperimentAxisQueryIterable can only be iterated - does not support mapping"
        )

    def _io_batch_iter(
        self,
        obs: soma.DataFrame,
        X: soma.SparseNDArray,
        obs_joinid_iter: Iterator[npt.NDArray[np.int64]],
    ) -> Iterator[Tuple[sparse.csr_matrix, pd.DataFrame]]:
        """Iterate over IO batches, i.e., SOMA query/read, producing a tuple of
        (X: csr_matrix, obs: DataFrame).

        obs joinids read are controlled by the obs_joinid_iter. Iterator results will
        be reindexed, so that row ``i`` of both outputs corresponds to the ``i``-th joinid
        of the chunk.

        Private method.
        """
        assert self._var_joinids is not None

        # soma_joinid is always read (it is needed to reorder the obs rows), but is only
        # returned to the caller if explicitly requested.
        obs_column_names = (
            list(self.obs_column_names)
            if "soma_joinid" in self.obs_column_names
            else ["soma_joinid", *self.obs_column_names]
        )
        var_indexer = soma.IntIndexer(self._var_joinids, context=X.context)

        for obs_coords in obs_joinid_iter:
            st_time = time.perf_counter()
            obs_indexer = soma.IntIndexer(obs_coords, context=X.context)
            logger.debug(
                f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
            )

            X_tbl = X.read(coords=(obs_coords, self._var_joinids)).tables().concat()
            # Map soma joinids onto positional (row, col) offsets within this IO batch.
            X_io_batch = sparse.csr_matrix(
                (
                    X_tbl["soma_data"].to_numpy(),
                    (
                        obs_indexer.get_indexer(X_tbl["soma_dim_0"]),
                        var_indexer.get_indexer(X_tbl["soma_dim_1"]),
                    ),
                ),
                shape=(len(obs_coords), len(self._var_joinids)),
            )

            # Fetch the matching obs rows, ordered identically to obs_coords.
            obs_io_batch = cast(
                pd.DataFrame,
                obs.read(coords=(obs_coords,), column_names=obs_column_names)
                .concat()
                .to_pandas()
                .set_index("soma_joinid")
                .reindex(obs_coords, copy=False)
                .reset_index(),
            )
            obs_io_batch = obs_io_batch[self.obs_column_names]

            # Release the (potentially large) intermediates before yielding.
            del obs_indexer, obs_coords, X_tbl
            gc.collect()

            tm = time.perf_counter() - st_time
            logger.debug(
                f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f} samples/sec"
            )
            yield X_io_batch, obs_io_batch

    def _mini_batch_iter(
        self,
        obs: soma.DataFrame,
        X: soma.SparseNDArray,
        obs_joinid_iter: Iterator[npt.NDArray[np.int64]],
    ) -> Iterator[XObsDatum]:
        """Break IO batches into mini-batch-sized chunks.

        A mini-batch that is still short when an IO batch is exhausted is carried over and
        topped up from the next IO batch; only the very last mini-batch may be short.

        Private method.
        """
        assert self._obs_joinids is not None
        assert self._var_joinids is not None

        io_batch_iter = self._io_batch_iter(obs, X, obs_joinid_iter)
        if self.use_eager_fetch:
            io_batch_iter = _EagerIterator(io_batch_iter, pool=X.context.threadpool)

        mini_batch_size = self.batch_size
        result: Tuple[XDatum, pd.DataFrame] | None = None
        for X_io_batch, obs_io_batch in io_batch_iter:
            assert X_io_batch.shape[0] == obs_io_batch.shape[0]
            assert X_io_batch.shape[1] == len(self._var_joinids)
            iob_idx = 0  # current offset into io batch
            iob_len = X_io_batch.shape[0]

            while iob_idx < iob_len:
                if result is None:
                    X_slice = X_io_batch[iob_idx : iob_idx + mini_batch_size]
                    result = (
                        X_slice if self.return_sparse_X else X_slice.toarray(),
                        obs_io_batch.iloc[iob_idx : iob_idx + mini_batch_size],
                    )
                    iob_idx += len(result[1])
                else:
                    # use any remnant from previous IO batch
                    to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx)
                    X_slice = X_io_batch[iob_idx : iob_idx + to_take]
                    X_datum = (
                        # pin the output format so the CSR contract holds
                        sparse.vstack([result[0], X_slice], format="csr")
                        if self.return_sparse_X
                        else np.concatenate([result[0], X_slice.toarray()])
                    )
                    result = (
                        X_datum,
                        pd.concat(
                            [result[1], obs_io_batch.iloc[iob_idx : iob_idx + to_take]]
                        ),
                    )
                    iob_idx += to_take

                assert result[0].shape[0] == result[1].shape[0]
                if result[0].shape[0] == mini_batch_size:
                    yield result
                    result = None

        # yield the remnant, if any
        if result is not None:
            yield result
class ExperimentAxisQueryIterableDataset(
    torch.utils.data.IterableDataset[XObsDatum]  # type:ignore[misc]
):
    """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`.

    This class works seamlessly with :class:`torch.utils.data.DataLoader` to load ``obs`` and ``X`` data as
    specified by a SOMA :class:`tiledbsoma.ExperimentAxisQuery`, providing an iterator over batches of
    ``obs`` and ``X`` data. Each iteration will yield a tuple containing an :class:`numpy.ndarray`
    and a :class:`pandas.DataFrame`.

    For example:

    >>> import torch
    >>> import tiledbsoma
    >>> import tiledbsoma_ml
    >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp:
    ...     with exp.axis_query(
    ...         measurement_name="RNA",
    ...         obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'"),
    ...     ) as query:
    ...         ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query)
    ...         dataloader = torch.utils.data.DataLoader(ds)
    >>> data = next(iter(dataloader))
    >>> data
    (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
    soma_joinid
    0     57905025)
    >>> data[0]
    array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)
    >>> data[1]
    soma_joinid
    0     57905025

    The ``batch_size`` parameter controls the number of rows of ``obs`` and ``X`` data that are returned in each
    iteration. If the ``batch_size`` is 1, then each dense ``X`` result will have rank 1, else it will have rank 2.
    A ``batch_size`` of 1 is compatible with :class:`torch.utils.data.DataLoader`-implemented batching, but it will
    usually be more performant to create mini-batches using this class, and set the ``DataLoader`` batch size
    to `None`.

    The ``obs_column_names`` parameter determines the data columns that are returned in the ``obs`` DataFrame (the
    default is a single column, containing the ``soma_joinid`` for the ``obs`` dimension).

    The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A
    larger value will increase total memory usage and may reduce average read time per row.

    This class will detect when run in a multiprocessing mode, including multi-worker
    :class:`torch.utils.data.DataLoader` and multi-process training such as
    :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition data appropriately.
    In the case of distributed training, sample partitions across all processes must be equal. Any
    data tail will be dropped.

    Lifecycle:
        experimental
    """

    def __init__(
        self,
        query: soma.ExperimentAxisQuery,
        X_name: str = "raw",
        obs_column_names: Sequence[str] = ("soma_joinid",),
        batch_size: int = 1,
        io_batch_size: int = 2**16,
        return_sparse_X: bool = False,
        use_eager_fetch: bool = True,
    ):
        """
        Construct a new ``ExperimentAxisQueryIterableDataset``, suitable for use with
        :class:`torch.utils.data.DataLoader`.

        The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as
        a NumPy ``ndarray`` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas ``DataFrame``
        respectively.

        Args:
            query:
                A :class:`tiledbsoma.ExperimentAxisQuery`, defining the data which will be iterated over.
            X_name:
                The name of the ``X`` layer to read. Defaults to ``"raw"``.
            obs_column_names:
                The names of the ``obs`` columns to return. At least one column name must be specified.
                Default is ``('soma_joinid',)``.
            batch_size:
                The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``.
                A value of ``1`` will result in the first (and only) row of each ``X`` batch being returned,
                which for dense data is an array of rank 1; larger values return ``X`` data of rank 2
                (multiple rows).

                Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with
                :class:`torch.utils.data.DataLoader` batching, but you will achieve higher performance by
                performing batching in this class, and setting the ``DataLoader`` batch_size parameter
                to ``None``.
            io_batch_size:
                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA.
            return_sparse_X:
                If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False``
                (the default), will return ``X`` data as a :class:`numpy.ndarray`.
            use_eager_fetch:
                Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA
                chunk is made available for processing via the iterator. This allows network (or filesystem)
                requests to be made in parallel with client-side processing of the SOMA data, potentially
                improving overall performance at the cost of doubling memory utilization. Defaults to ``True``.

        Raises:
            ``ValueError`` on various unsupported or malformed parameter values.

        Lifecycle:
            experimental

        """
        super().__init__()
        # All reading, partitioning and batching is delegated to the shared iterable.
        self._exp_iter = ExperimentAxisQueryIterable(
            query=query,
            X_name=X_name,
            obs_column_names=obs_column_names,
            batch_size=batch_size,
            io_batch_size=io_batch_size,
            return_sparse_X=return_sparse_X,
            use_eager_fetch=use_eager_fetch,
        )

    def __iter__(self) -> Iterator[XObsDatum]:
        """Create Iterator yielding tuples of :class:`numpy.ndarray` and :class:`pandas.DataFrame`.

        Returns:
            ``iterator``

        Lifecycle:
            experimental
        """
        batch_size = self._exp_iter.batch_size
        for X, obs in self._exp_iter:
            if batch_size == 1:
                # Single-row batches are unwrapped to their only row so that the
                # DataLoader can perform its own batching.
                X = X[0]
            yield X, obs

    def __len__(self) -> int:
        """Return approximate number of batches this iterable will produce.

        See important caveats in the PyTorch
        [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)
        documentation regarding ``len(dataloader)``, which also apply to this class.

        Returns:
            An ``int``.

        Lifecycle:
            experimental
        """
        return len(self._exp_iter)

    @property
    def shape(self) -> Tuple[int, int]:
        """Get the approximate shape of the data that will be returned by this
        :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`.

        The value is delegated unchanged to the underlying iterable. If used in multiprocessing mode
        (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the first element
        will reflect the size of the partition of the data assigned to the active process.

        Returns:
            A tuple of two ``int`` values.

        Lifecycle:
            experimental
        """
        return self._exp_iter.shape
if sys.version_info < (3, 12):

    def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]:
        """Polyfill of the Python 3.12+ ``itertools.batched`` for older interpreters.

        Yields successive tuples of up to ``n`` items drawn from ``iterable``; the last
        tuple may be shorter. Raises ``ValueError`` if ``n`` is less than one.
        """
        if n < 1:
            raise ValueError("n must be at least one")
        source = iter(iterable)
        while True:
            chunk = tuple(islice(source, n))
            if not chunk:
                return
            yield chunk

else:
    # The standard library provides the real thing from 3.12 onwards.
    _batched = itertools.batched
def _get_worker_world_rank() -> Tuple[int, int]:
    """Return the number of DataLoader workers and this worker's rank/id.

    The ``NUM_WORKERS`` and ``WORKER`` environment variables take precedence when both
    are set; otherwise the values come from ``torch.utils.data.get_worker_info()``.
    Outside of any worker, the result is ``(1, 0)``.
    """
    env = os.environ
    if "WORKER" in env and "NUM_WORKERS" in env:
        return int(env["NUM_WORKERS"]), int(env["WORKER"])

    info = torch.utils.data.get_worker_info()
    if info is None:
        # Not running inside a DataLoader worker process.
        return 1, 0
    return info.num_workers, info.id
def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix:
    """Build a sparse checkerboard of ones spanning the given obs/var ranges.

    A cell holds 1 where the sum of its zero-based row and column offsets is odd. The
    COO coordinates are then shifted by the start of each range.
    """
    n_obs = obs_range.stop - obs_range.start
    n_var = var_range.stop - var_range.start
    row_offsets, col_offsets = np.indices((n_obs, n_var))
    board = coo_matrix((row_offsets + col_offsets) % 2)
    # Move the occupied region to where the requested ranges begin.
    board.row += obs_range.start
    board.col += var_range.start
    return board
def add_dataframe(coll: CollectionBase, key: str, value_range: range) -> None:
    """Add a dataframe named ``key`` to ``coll`` and fill it with one row per value.

    Each row gets its ``soma_joinid`` from ``value_range``, a ``label`` holding the
    stringified joinid, and a constant ``label2`` of ``"c"``.
    """
    schema = pa.schema(
        [
            ("soma_joinid", pa.int64()),
            ("label", pa.large_string()),
            ("label2", pa.large_string()),
        ]
    )
    df = coll.add_new_dataframe(
        key, schema=schema, index_column_names=["soma_joinid"]
    )

    joinids = list(value_range)
    columns = {
        "soma_joinid": joinids,
        "label": list(map(str, joinids)),
        "label2": ["c"] * len(joinids),
    }
    df.write(pa.Table.from_pydict(columns))
soma.Measurement) + add_dataframe(rna, "var", var_range) + rna_x = rna.add_new_collection("X", soma.Collection) + add_sparse_array(rna_x, "raw", obs_range, var_range, X_value_gen) + + if obsp_layer_names: + obsp = rna.add_new_collection("obsp") + for obsp_layer_name in obsp_layer_names: + add_sparse_array( + obsp, obsp_layer_name, obs_range, var_range, X_value_gen + ) + + if varp_layer_names: + varp = rna.add_new_collection("varp") + for varp_layer_name in varp_layer_names: + add_sparse_array( + varp, varp_layer_name, obs_range, var_range, X_value_gen + ) + return _factory.open((tmp_path / "exp").as_posix()) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_non_batched( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + # batch_size should default to 1 + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + assert type(exp_data_pipe.shape) is tuple + assert len(exp_data_pipe.shape) == 2 + assert exp_data_pipe.shape == (6, 3) + + row_iter = iter(exp_data_pipe) + + row = next(row_iter) + + if return_sparse_X: + # sparse slices remain 2D, always + assert isinstance(row[0], sparse.csr_matrix) + assert row[0].shape == (1, 3) + assert row[0].todense().tolist() == [[0, 1, 0]] + + else: + assert isinstance(row[0], np.ndarray) + print(row) + assert np.squeeze(row[0]).shape == (3,) + assert np.squeeze(row[0]).tolist() == [0, 1, 0] + + assert isinstance(row[1], pd.DataFrame) + assert row[1].shape == (1, 1) + assert row[1].keys() == ["label"] + assert row[1]["label"].tolist() == ["0"] 
+ + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("return_sparse_X", [True, False]) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_uneven_soma_and_result_batches( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, + return_sparse_X: bool, +) -> None: + """This is checking that batches are correctly created when they require fetching multiple chunks.""" + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + io_batch_size=2, + use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, + ) + row_iter = iter(exp_data_pipe) + + X_batch, obs_batch = next(row_iter) + + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + assert X_batch.todense()[0].tolist() == [[0, 1, 0]] + else: + assert isinstance(X_batch, np.ndarray) + assert X_batch[0].tolist() == [0, 1, 0] + + assert isinstance(obs_batch, pd.DataFrame) + assert X_batch.shape[0] == obs_batch.shape[0] + assert X_batch.shape == (3, 3) + assert obs_batch.shape == (3, 1) + assert ["label"] == obs_batch.keys() + assert obs_batch["label"].tolist() == ["0", "1", "2"] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_batching__all_batches_full_size( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + 
batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["0", "1", "2"] + + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert batch[1].keys() == ["label"] + assert batch[1]["label"].tolist() == ["3", "4", "5"] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (range(100_000_000, 100_000_003), 3, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_unique_soma_joinids( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid", "label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in exp_data_pipe] + ) + assert len(np.unique(soma_joinids)) == len(soma_joinids) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(5, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_batching__partial_final_batch_size( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + next(batch_iter) + batch = next(batch_iter) + assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]] + + with pytest.raises(StopIteration): + next(batch_iter) + + 
+@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_batching__exactly_one_batch( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert batch[1]["label"].tolist() == ["0", "1", "2"] + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_batching__empty_query_result( + PipeClass: PipeClassType, + soma_experiment: Experiment, + use_eager_fetch: bool, +) -> None: + with soma_experiment.axis_query( + measurement_name="RNA", obs_query=soma.AxisQuery(coords=([],)) + ) as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + with pytest.raises(StopIteration): + next(batch_iter) + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_sparse_output__non_batched( + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = 
PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + return_sparse_X=True, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[0], sparse.csr_matrix) + assert batch[0].todense().A.squeeze().tolist() == [0, 1, 0] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_sparse_output__batched( + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + return_sparse_X=True, + use_eager_fetch=use_eager_fetch, + ) + batch_iter = iter(exp_data_pipe) + + batch = next(batch_iter) + assert isinstance(batch[0], sparse.csr_matrix) + assert batch[0].todense().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [ + (10, 1, pytorch_x_value_gen, use_eager_fetch) + for use_eager_fetch in (True, False) + ], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_batching__partial_soma_batches_are_concatenated( + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + exp_data_pipe = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + # set SOMA batch read size such that PyTorch batches will span the tail and head of two SOMA batches + io_batch_size=4, + use_eager_fetch=use_eager_fetch, + ) + + full_result = list(exp_data_pipe) + + assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1] + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(6, 3, 
pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize( + "world_size,rank", + [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_distributed__returns_data_partition_for_rank( + PipeClass: PipeClassType, + soma_experiment: Experiment, + obs_range: int, + world_size: int, + rank: int, +) -> None: + """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode, + using mocks to avoid having to do real PyTorch distributed setup.""" + + with ( + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + ) + full_result = list(iter(dp)) + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + + expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ + 0 : obs_range // world_size + ].tolist() + assert sorted(soma_joinids) == expected_joinids + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen", + [(12, 3, pytorch_x_value_gen), (13, 3, pytorch_x_value_gen)], +) +@pytest.mark.parametrize( + "world_size,rank,num_workers,worker_id", + [ + (3, 1, 2, 0), + (3, 1, 2, 1), + ], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test_distributed_and_multiprocessing__returns_data_partition_for_rank( + PipeClass: PipeClassType, + soma_experiment: Experiment, + obs_range: int, + world_size: int, + rank: int, + num_workers: int, + worker_id: int, +) -> None: + """Tests pytorch._partition_obs_joinids() behavior 
in a simulated PyTorch distributed processing mode and + DataLoader multiprocessing mode, using mocks to avoid having to do distributed pytorch + setup or real DataLoader multiprocessing.""" + + with ( + patch("torch.utils.data.get_worker_info") as mock_get_worker_info, + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): + mock_get_worker_info.return_value = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=1234 + ) + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size + + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + ) + + full_result = list(iter(dp)) + + soma_joinids = np.concatenate( + [t[1]["soma_joinid"].to_numpy() for t in full_result] + ) + + expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ + 0 : obs_range // world_size + ] + expected_joinids = np.array_split(expected_joinids, num_workers)[worker_id] + assert sorted(soma_joinids) == expected_joinids.tolist() + + +@pytest.mark.parametrize( + "obs_range,var_range,X_value_gen,use_eager_fetch", + [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], +) +@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +def test__X_tensor_dtype_matches_X_matrix( + PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool +) -> None: + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = PipeClass( + query, + X_name="raw", + obs_column_names=["label"], + batch_size=3, + use_eager_fetch=use_eager_fetch, + ) + data = next(iter(dp)) + + assert data[0].dtype == np.float32 + + +def test_batched() -> None: + from tiledbsoma_ml.pytorch import _batched + + assert 
list(_batched(range(6), 1)) == list((i,) for i in range(6)) + assert list(_batched(range(6), 2)) == [(0, 1), (2, 3), (4, 5)] + assert list(_batched(range(6), 3)) == [(0, 1, 2), (3, 4, 5)] + assert list(_batched(range(6), 4)) == [(0, 1, 2, 3), (4, 5)] + assert list(_batched(range(6), 5)) == [(0, 1, 2, 3, 4), (5,)] + assert list(_batched(range(6), 6)) == [(0, 1, 2, 3, 4, 5)] + assert list(_batched(range(6), 7)) == [(0, 1, 2, 3, 4, 5)] + + # bogus batch value + with pytest.raises(ValueError): + list(_batched([0, 1], 0)) + with pytest.raises(ValueError): + list(_batched([2, 3], -1)) + + +def test_splits() -> None: + from tiledbsoma_ml.pytorch import _splits + + assert _splits(10, 1).tolist() == [0, 10] + assert _splits(10, 3).tolist() == [0, 4, 7, 10] + assert _splits(10, 4).tolist() == [0, 3, 6, 8, 10] + assert _splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert _splits(10, 11).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10] + + # bad number of sections + with pytest.raises(ValueError): + _splits(10, 0) + with pytest.raises(ValueError): + _splits(10, -1) From ec181fd459464ba22c57284edae4bdc211d61c76 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 25 Sep 2024 11:03:13 -0400 Subject: [PATCH 02/25] Fix GHA PR trigger --- .github/workflows/python-tiledbsoma-ml.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 9b6911d..6d185a8 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -1,6 +1,12 @@ name: python-tiledbsoma-ml on: + pull_request: + branches: ["**"] + paths-ignore: ['scripts/**'] + push: + branches: [main] + paths-ignore: ['scripts/**'] workflow_dispatch: jobs: From 9325a4f39fbcb617f2efde0b3d6c5900f27f4073 Mon Sep 17 00:00:00 2001 From: Bruce Martin Date: Wed, 25 Sep 2024 09:00:00 -0700 Subject: [PATCH 03/25] add GHA workflows for basic CI (PR 4 of N) (#7) * add GHA workflows; 
remove temp exclusion from pre-commit config * GHA tweak, attempt to trigger PR runs --------- Co-authored-by: Ryan Williams --- .github/workflows/python-tiledbsoma-ml.yml | 59 ++++++++++++++++++- .../python-tilledbsoma-ml-compat.yml | 44 +++++++++++++- .pre-commit-config.yaml | 3 - 3 files changed, 97 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python-tiledbsoma-ml.yml b/.github/workflows/python-tiledbsoma-ml.yml index 6d185a8..4a1e9f1 100644 --- a/.github/workflows/python-tiledbsoma-ml.yml +++ b/.github/workflows/python-tiledbsoma-ml.yml @@ -1,4 +1,4 @@ -name: python-tiledbsoma-ml +name: python-tiledbsoma-ml CI on: pull_request: @@ -10,7 +10,60 @@ on: workflow_dispatch: jobs: - job: + lint: runs-on: ubuntu-latest steps: - # Empty job; placeholder GHA + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Restore pre-commit cache + uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} + + - name: Install pre-commit + run: pip -v install pre-commit + + - name: Run pre-commit hooks on all files + run: pre-commit run -v -a + + tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install prereqs + run: | + pip install --upgrade pip wheel pytest pytest-cov setuptools + pip install . + + - name: Run tests + run: pytest -v --cov=src --cov-report=xml tests + + build: + # for now, just do a test build to ensure that it works + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Do build + run: | + pip install --upgrade build pip wheel setuptools setuptools-scm + python -m build . 
diff --git a/.github/workflows/python-tilledbsoma-ml-compat.yml b/.github/workflows/python-tilledbsoma-ml-compat.yml index a83ed31..026502a 100644 --- a/.github/workflows/python-tilledbsoma-ml-compat.yml +++ b/.github/workflows/python-tilledbsoma-ml-compat.yml @@ -1,9 +1,47 @@ name: python-tiledbsoma-ml past tiledbsoma compat # Latest tiledbsoma version covered by another workflow + on: + pull_request: + branches: ["*"] + paths-ignore: + - "scripts/**" + - "notebooks/**" + push: + branches: [main] + paths-ignore: + - "scripts/**" + - "notebooks/**" workflow_dispatch: jobs: - job: - runs-on: ubuntu-latest + unit_tests: + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] # could add 'macos-latest', but the matrix is already huge... + python-version: ["3.9", "3.10", "3.11"] # TODO: add 3.12 when tiledbsoma releases wheels for it. + pkg-version: + - "tiledbsoma~=1.9.0 'numpy<2.0.0'" + - "tiledbsoma~=1.10.0 'numpy<2.0.0'" + - "tiledbsoma~=1.11.0" + - "tiledbsoma~=1.12.0" + - "tiledbsoma~=1.13.0" + - "tiledbsoma~=1.14.0" + + runs-on: ${{ matrix.os }} + steps: - # Empty job; placeholder GHA + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install prereqs + run: | + pip install --upgrade pip pytest setuptools + pip install ${{ matrix.pkg-version }} . 
+ + - name: Run tests + run: pytest -v tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0f39b5..6385ca9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,6 @@ repos: rev: "24.8.0" hooks: - id: black - exclude: 'apis/' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.6.5 @@ -11,7 +10,6 @@ repos: - id: ruff name: "ruff for tiledbsoma_ml" args: ["--config=pyproject.toml"] - exclude: 'apis/' - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.11.2 @@ -23,4 +21,3 @@ repos: - attrs - numpy - pandas-stubs>=2 - exclude: 'apis/' From d9c0ef410967056120a395f827d318c8f1300cf9 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 25 Sep 2024 12:25:59 -0400 Subject: [PATCH 04/25] trigger `python-tilledbsoma-ml-compat.yml` GHA on all PR branches --- .github/workflows/python-tilledbsoma-ml-compat.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-tilledbsoma-ml-compat.yml b/.github/workflows/python-tilledbsoma-ml-compat.yml index 026502a..742cc05 100644 --- a/.github/workflows/python-tilledbsoma-ml-compat.yml +++ b/.github/workflows/python-tilledbsoma-ml-compat.yml @@ -2,7 +2,7 @@ name: python-tiledbsoma-ml past tiledbsoma compat # Latest tiledbsoma version co on: pull_request: - branches: ["*"] + branches: ["**"] paths-ignore: - "scripts/**" - "notebooks/**" From 9d9e20d2eaa660064012d593e907df1902aa958c Mon Sep 17 00:00:00 2001 From: bkmartinjr Date: Tue, 1 Oct 2024 07:56:39 -0700 Subject: [PATCH 05/25] cleanup based on PR feedback --- tests/test_pytorch.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 1e00c5c..9142125 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -18,22 +18,13 @@ from scipy.sparse import coo_matrix, spmatrix from tiledbsoma import Experiment, _factory from tiledbsoma._collection import CollectionBase +from 
torch.utils.data._utils.worker import WorkerInfo -# Conditionally import torch, as it will not be available in all test environments. -# This supports the pytest `ml` mark, which can be used to disable all PyTorch-dependent -# tests. -try: - from torch.utils.data._utils.worker import WorkerInfo - - from tiledbsoma_ml.pytorch import ( - ExperimentAxisQueryIterable, - ExperimentAxisQueryIterableDataset, - ExperimentAxisQueryIterDataPipe, - ) -except ImportError: - # This should only occur when not running `ml`-marked tests - pass - +from tiledbsoma_ml.pytorch import ( + ExperimentAxisQueryIterable, + ExperimentAxisQueryIterableDataset, + ExperimentAxisQueryIterDataPipe, +) # These control which classes are tested (for most, but not all tests). # Centralized to allow easy add/delete of specific test parameters. @@ -196,7 +187,6 @@ def test_non_batched( else: assert isinstance(row[0], np.ndarray) - print(row) assert np.squeeze(row[0]).shape == (3,) assert np.squeeze(row[0]).tolist() == [0, 1, 0] From 1ca11e137c5244fd19d3b2e87341e910d2a4df21 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Mon, 30 Sep 2024 23:16:14 -0400 Subject: [PATCH 06/25] use 3.11 in precommit tools --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7901253..416e46f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ show_error_codes = true ignore_missing_imports = true warn_unreachable = true strict = true -python_version = 3.9 +python_version = '3.11' plugins = "numpy.typing.mypy_plugin" [tool.ruff] @@ -72,5 +72,5 @@ lint.select = ["E", "F", "B", "I"] lint.ignore = ["E501"] # line too long lint.extend-select = ["I001"] # unsorted-imports fix = true -target-version = "py39" +target-version = "py311" line-length = 120 From 0031292b11e5f25e746f3916a0fb6d4f92109b00 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Mon, 30 Sep 2024 23:17:27 -0400 Subject: [PATCH 07/25] type tweaks --- 
src/tiledbsoma_ml/pytorch.py | 36 ++++++++++++------------------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 32d90b3..e394e0a 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -16,17 +16,16 @@ from itertools import islice from math import ceil from typing import ( - TYPE_CHECKING, Any, ContextManager, Dict, + Generator, Iterable, Iterator, Sequence, Tuple, TypeVar, Union, - cast, ) import attrs @@ -38,24 +37,15 @@ import torch import torchdata from somacore.query._eager_iter import EagerIterator as _EagerIterator -from typing_extensions import TypeAlias logger = logging.getLogger("tiledbsoma_ml.pytorch") _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) -if TYPE_CHECKING: - # Python 3.8 does not support subscripting types, so work-around by - # restricting this to when we are running a type checker. TODO: remove - # the conditional when Python 3.8 support is dropped. - NDArrayNumber: TypeAlias = npt.NDArray[np.number[Any]] - XDatum: TypeAlias = Union[NDArrayNumber, sparse.csr_matrix] -else: - NDArrayNumber: TypeAlias = np.ndarray - XDatum: TypeAlias = Union[np.ndarray, sparse.csr_matrix] - -XObsDatum: TypeAlias = Tuple[XDatum, pd.DataFrame] +NDArrayNumber = npt.NDArray[np.number[Any]] +XDatum = Union[NDArrayNumber, sparse.csr_matrix] +XObsDatum = Tuple[XDatum, pd.DataFrame] """Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, which pairs a slice of ``X`` rows with a cooresponding slice of ``obs``. 
In the default case, the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs`` @@ -86,12 +76,11 @@ def create(cls, experiment: soma.Experiment) -> "_ExperimentLocator": ) @contextmanager - def open_experiment(self) -> Iterator[soma.Experiment]: + def open_experiment(self) -> Generator[soma.Experiment, None, None]: context = soma.SOMATileDBContext(tiledb_config=self.tiledb_config) - with soma.Experiment.open( + yield soma.Experiment.open( self.uri, tiledb_timestamp=self.tiledb_timestamp_ms, context=context - ) as exp: - yield exp + ) class ExperimentAxisQueryIterable(Iterable[XObsDatum]): @@ -398,17 +387,16 @@ def _io_batch_iter( ) # Now that X read is potentially in progress (in eager mode), go fetch obs data - # - obs_io_batch = cast( - pd.DataFrame, + # fmt: off + obs_io_batch = ( obs.read(coords=(obs_coords,), column_names=obs_column_names) .concat() .to_pandas() .set_index("soma_joinid") .reindex(obs_coords, copy=False) - .reset_index(), - ) - obs_io_batch = obs_io_batch[self.obs_column_names] + .reset_index() # demote "soma_joinid" to a column + [self.obs_column_names] + ) # fmt: on del obs_indexer, obs_coords, X_tbl gc.collect() From db26af402b22218a7b0981bda1556322120cbd6c Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Mon, 30 Sep 2024 23:19:35 -0400 Subject: [PATCH 08/25] docstr/comment nits --- src/tiledbsoma_ml/pytorch.py | 37 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index e394e0a..cf1abba 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -47,7 +47,7 @@ XDatum = Union[NDArrayNumber, sparse.csr_matrix] XObsDatum = Tuple[XDatum, pd.DataFrame] """Return type of ``ExperimentAxisQueryIterableDataset`` and ``ExperimentAxisQueryIterDataPipe``, -which pairs a slice of ``X`` rows with a cooresponding slice of ``obs``. 
In the default case, +which pairs a slice of ``X`` rows with a corresponding slice of ``obs``. In the default case, the datum is a tuple of :class:`numpy.ndarray` and :class:`pandas.DataFrame` (for ``X`` and ``obs`` respectively). If the object is created with ``return_sparse_X`` as True, the ``X`` slice is returned as a :class:`scipy.sparse.csr_matrix`. If the ``batch_size`` is 1, the :class:`numpy.ndarray` @@ -58,7 +58,7 @@ class _ExperimentLocator: """State required to open the Experiment. - Necessary as we will likely be invoked across multiple processes. + Serializable across multiple processes. Private implementation class. """ @@ -111,8 +111,8 @@ def __init__( Construct a new ``ExperimentAxisQueryIterable``, suitable for use with :class:`torch.utils.data.DataLoader`. The resulting iterator will produce a tuple containing associated slices of ``X`` and ``obs`` data, as - a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a - Pandas :class:`pandas.DataFrame`, respectively. + a NumPy :class:`numpy.ndarray` (or optionally, :class:`scipy.sparse.csr_matrix`) and a Pandas + :class:`pandas.DataFrame`, respectively. Args: query: @@ -125,16 +125,16 @@ def __init__( batch_size: The number of rows of ``X`` and ``obs`` data to return in each iteration. Defaults to ``1``. A value of ``1`` will result in :class:`torch.Tensor` of rank 1 being returned (a single row); larger values will - result in :class:`torch.Tensor`\ s of rank 2 (multiple rows). Note that a ``batch_size`` of 1 allows - this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` - batching, but higher performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s + result in :class:`torch.Tensor`s of rank 2 (multiple rows). 
Note that a ``batch_size`` of 1 allows + this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` batching, but higher + performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s ``batch_size`` parameter to ``None``. io_batch_size: The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts maximum memory utilization, larger values provide better read performance, but require more memory. return_sparse_X: - If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will - return ``X`` data as a :class:`numpy.ndarray`. + If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the + default), will return ``X`` data as a :class:`numpy.ndarray`. use_eager_fetch: Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made available for processing via the iterator. This allows network (or filesystem) requests to be made in @@ -155,7 +155,7 @@ def __init__( super().__init__() - # Anything set in the instance needs to be picklable for multi-process DataLoaders + # Anything set in the instance needs to be pickle-able for multi-process DataLoaders self.experiment_locator = _ExperimentLocator.create(query.experiment) self.layer_name = X_name self.measurement_name = query.measurement_name @@ -228,7 +228,7 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: def _init_once(self, exp: soma.Experiment | None = None) -> None: """One-time per worker initialization. - All operations be idempotent in order to support pipe reset(). + All operations should be idempotent in order to support pipe reset(). Private method. 
""" @@ -564,9 +564,9 @@ class ExperimentAxisQueryIterableDataset( >>> import tiledbsoma >>> import tiledbsoma_ml >>> with tiledbsoma.Experiment.open("my_experiment_path") as exp: - with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: - ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query) - dataloader = torch.utils.data.DataLoader(ds) + ... with exp.axis_query(measurement_name="RNA", obs_query=tiledbsoma.AxisQuery(value_filter="tissue_type=='lung'")) as query: + ... ds = tiledbsoma_ml.ExperimentAxisQueryIterableDataset(query) + ... dataloader = torch.utils.data.DataLoader(ds) >>> data = next(iter(dataloader)) >>> data (array([0., 0., 0., ..., 0., 0., 0.], dtype=float32), @@ -711,7 +711,7 @@ def shape(self) -> Tuple[int, int]: def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: - """For `total_length` points, compute start/stop offsets that split the length into roughly equal sizes. + """For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes. A total_length of L, split into N sections, will return L%N sections of size L//N+1, and the remainder as size L//N. 
This results in the same split as numpy.array_split, @@ -738,11 +738,10 @@ def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]: if sys.version_info >= (3, 12): _batched = itertools.batched - else: def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: - """Same as the Python 3.12+ itertools.batched -- polyfill for old Python versions.""" + """Same as the Python 3.12+ ``itertools.batched`` -- polyfill for old Python versions.""" if n < 1: raise ValueError("n must be at least one") it = iter(iterable) @@ -751,13 +750,13 @@ def _batched(iterable: Iterable[_T_co], n: int) -> Iterator[Tuple[_T_co, ...]]: def _get_distributed_world_rank() -> Tuple[int, int]: - """Return tuple containing equivalent of torch.distributed world size and rank.""" + """Return tuple containing equivalent of ``torch.distributed`` world size and rank.""" world_size, rank = 1, 0 if "RANK" in os.environ and "WORLD_SIZE" in os.environ: world_size = int(os.environ["WORLD_SIZE"]) rank = int(os.environ["RANK"]) elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ: - # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There + # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There # is a NODE_RANK for the node's rank, but no way to tell the local node's # world. So computing a global rank is impossible(?). Using LOCAL_RANK as a # proxy, which works fine on a single-CPU box. 
TODO: could throw/error From 579727ad0a96ce911b8e3f4a7c52d0485e4d0c44 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Mon, 30 Sep 2024 23:19:52 -0400 Subject: [PATCH 09/25] f-string debugging --- src/tiledbsoma_ml/pytorch.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index cf1abba..7af1e57 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -217,10 +217,9 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: ].copy() if logger.isEnabledFor(logging.DEBUG): + partition_size = sum([len(chunk) for chunk in obs_partition_joinids]) logger.debug( - f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, " - f"n_workers={n_workers}, " - f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}" + f"Process {os.getpid()} {rank=}, {world_size=}, {worker_id=}, n_workers={n_workers}, {partition_size=}" ) return iter(obs_partition_joinids) @@ -280,7 +279,7 @@ def __iter__(self) -> Iterator[XObsDatum]: world_size, rank = _get_distributed_world_rank() n_workers, worker_id = _get_worker_world_rank() logger.debug( - f"Iterator created rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}" + f"Iterator created {rank=}, {world_size=}, {worker_id=}, {n_workers=}" ) with self.experiment_locator.open_experiment() as exp: From 35c19253136fc4c4558ca44c725a9728bac96700 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 2 Oct 2024 16:00:47 -0400 Subject: [PATCH 10/25] docstring updates, attempt to make `shape`/`len` more precise --- src/tiledbsoma_ml/pytorch.py | 84 +++++++++++++++++------------------- 1 file changed, 40 insertions(+), 44 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 7af1e57..ff0ff3a 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -84,14 +84,14 @@ def open_experiment(self) -> 
Generator[soma.Experiment, None, None]: class ExperimentAxisQueryIterable(Iterable[XObsDatum]): - """An :class:`Iterator` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as + """An :class:`Iterable` which reads ``X`` and ``obs`` data from a :class:`tiledbsoma.Experiment`, as selected by a user-specified :class:`tiledbsoma.ExperimentAxisQuery`. Each step of the iterator - produces equal sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and + produces a batch containing equal-sized ``X`` and ``obs`` data, in the form of a :class:`numpy.ndarray` and :class:`pandas.DataFrame`, respectively. Private base class for subclasses of :class:`torch.utils.data.IterableDataset` and :class:`torchdata.datapipes.iter.IterDataPipe`. Refer to :class:`ExperimentAxisQueryIterableDataset` - and `ExperimentAxisQueryIterDataPipe` for more details on usage. + and :class:`ExperimentAxisQueryIterDataPipe` for more details on usage. Lifecycle: experimental @@ -136,14 +136,10 @@ def __init__( If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the default), will return ``X`` data as a :class:`numpy.ndarray`. use_eager_fetch: - Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is made - available for processing via the iterator. This allows network (or filesystem) requests to be made in - parallel with client-side processing of the SOMA data, potentially improving overall performance at the - cost of doubling memory utilization. Defaults to ``True``. - - Returns: - An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to - a :class:`torch.utils.data.DataLoader` instance. + Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk is + made available for processing via the iterator. 
This allows network (or filesystem) requests to be made + in parallel with client-side processing of the SOMA data, potentially improving overall performance at + the cost of doubling memory utilization. Defaults to ``True``. Raises: ``ValueError`` on various unsupported or malformed parameter values. @@ -300,16 +296,16 @@ def __iter__(self) -> Iterator[XObsDatum]: yield from _mini_batch_iter def __len__(self) -> int: - """Return the approximate number of batches this iterable will produce. If run in the context of :class:`torch.distributed` or - as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) - count will reflect the size of the data partition assigned to the active process. + """Return the number of batches this iterable will produce. If run in the context of :class:`torch.distributed` + or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the + batch count will reflect the size of the data partition assigned to the active process. See important caveats in the PyTorch [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) documentation regarding ``len(dataloader)``, which also apply to this class. Returns: - An ``int``. + ``int`` (Number of batches). Lifecycle: experimental @@ -318,13 +314,13 @@ def __len__(self) -> int: @property def shape(self) -> Tuple[int, int]: - """Get the approximate shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. - This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode - (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect - the size of the data partition assigned to the active process. 
+ """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. + + If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), + the number of batches will reflect the size of the data partition assigned to the active process. Returns: - A tuple of two ``int`` values: number of obs, number of vars. + A tuple of two ``int`` values: number of batches, number of vars. Lifecycle: experimental @@ -332,11 +328,17 @@ def shape(self) -> Tuple[int, int]: self._init_once() assert self._obs_joinids is not None assert self._var_joinids is not None - world_size, _ = _get_distributed_world_rank() - n_workers, _ = _get_worker_world_rank() - partition_len = len(self._obs_joinids) // world_size // n_workers - div, rem = divmod(partition_len, self.batch_size) - return div + bool(rem), len(self._var_joinids) + world_size, rank = _get_distributed_world_rank() + n_workers, worker_id = _get_worker_world_rank() + # Every "distributed" process must receive the same number of "obs" rows; the last ≤world_size may be dropped + # (see _create_obs_joinids_partition). + obs_per_proc = len(self._obs_joinids) // world_size + obs_per_worker, obs_rem = divmod(obs_per_proc, n_workers) + # obs rows assigned to this worker process + n_worker_obs = obs_per_worker + bool(worker_id < obs_rem) + n_batches, rem = divmod(n_worker_obs, self.batch_size) + # (num batches this worker will produce, num features) + return n_batches + bool(rem), len(self._var_joinids) def __getitem__(self, index: int) -> XObsDatum: raise NotImplementedError( @@ -349,11 +351,9 @@ def _io_batch_iter( X: soma.SparseNDArray, obs_joinid_iter: Iterator[npt.NDArray[np.int64]], ) -> Iterator[Tuple[sparse.csr_matrix, pd.DataFrame]]: - """Iterate over IO batches, i.e., SOMA query/read, producing a tuple of - (X: csr_array, obs: DataFrame). 
+ """Iterate over IO batches, i.e., SOMA query reads, producing tuples of ``(X: csr_array, obs: DataFrame)``. - obs joinids read are controlled by the obs_joinid_iter. Iterator results will - be reindexed. + ``obs`` joinids read are controlled by the ``obs_joinid_iter``. Iterator results will be reindexed. Private method. """ @@ -475,7 +475,7 @@ class ExperimentAxisQueryIterDataPipe( torch.utils.data.dataset.Dataset[XObsDatum] ], ): - """A :class:`torch.utils.data.IterableDataset` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. + """A :class:`torchdata.datapipes.iter.IterDataPipe` implementation that loads from a :class:`tiledbsoma.SOMAExperiment`. This class is based upon the now-deprecated :class:`torchdata.datapipes` API, and should only be used for legacy code. See [GitHub issue #1196](https://github.com/pytorch/data/issues/1196) and the @@ -534,7 +534,7 @@ def __len__(self) -> int: Lifecycle: deprecated """ - return self._exp_iter.__len__() + return len(self._exp_iter) @property def shape(self) -> Tuple[int, int]: @@ -640,10 +640,6 @@ def __init__( parallel with client-side processing of the SOMA data, potentially improving overall performance at the cost of doubling memory utilization. Defaults to ``True``. - Returns: - An ``iterable``, which can be iterated over using the Python ``iter()`` statement, or passed directly to - a :class:`torch.data.utils.DataLoader` instance. - Raises: ``ValueError`` on various unsupported or malformed parameter values. @@ -663,7 +659,8 @@ def __init__( ) def __iter__(self) -> Iterator[XObsDatum]: - """Create Iterator yielding tuples of :class:`numpy.ndarray` and :class:`pandas.DataFrame`. + """Create ``Iterator`` yielding "mini-batch" tuples of :class:`numpy.ndarray` (or :class:`scipy.csr_matrix`) and + :class:`pandas.DataFrame`. 
Returns: ``iterator`` @@ -678,30 +675,29 @@ def __iter__(self) -> Iterator[XObsDatum]: yield X, obs def __len__(self) -> int: - """Return approximate number of batches this iterable will produce. + """Return number of batches this iterable will produce. See important caveats in the PyTorch [:class:`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) documentation regarding ``len(dataloader)``, which also apply to this class. Returns: - An ``int``. + ``int`` (number of batches). Lifecycle: experimental """ - return self._exp_iter.__len__() + return len(self._exp_iter) @property def shape(self) -> Tuple[int, int]: - """Get the shape of the data that will be returned by this :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`. + """Return the number of batches and features that will be yielded from this :class:`tiledbsoma_ml.ExperimentAxisQueryIterable`. - This is the number of obs (cell) and var (feature) counts in the returned data. If used in multiprocessing mode - (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), the obs (cell) count will reflect - the size of the partition of the data assigned to the active process. + If used in multiprocessing mode (i.e. :class:`torch.utils.data.DataLoader` instantiated with num_workers > 0), + the number of batches will reflect the size of the data partition assigned to the active process. Returns: - A tuple of ``int``s, for obs and var counts, respectively. + A tuple of two ``int`` values: number of batches, number of vars. 
Lifecycle: experimental From 9c70271dd392703f3f06c698982212dab9fcd825 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Wed, 2 Oct 2024 17:57:58 -0400 Subject: [PATCH 11/25] test parametrize `use_eager_fetch` --- tests/test_pytorch.py | 61 +++++++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 9142125..a62e515 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -151,9 +151,10 @@ def soma_experiment( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("return_sparse_X", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_non_batched( @@ -197,9 +198,10 @@ def test_non_batched( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("return_sparse_X", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_uneven_soma_and_result_batches( @@ -239,9 +241,10 @@ def test_uneven_soma_and_result_batches( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__all_batches_full_size( PipeClass: PipeClassType, @@ -273,12 +276,10 @@ def 
test_batching__all_batches_full_size( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [ - (range(100_000_000, 100_000_003), 3, pytorch_x_value_gen, use_eager_fetch) - for use_eager_fetch in (True, False) - ], + "obs_range,var_range,X_value_gen", + [(range(100_000_000, 100_000_003), 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_unique_soma_joinids( PipeClass: PipeClassType, @@ -301,9 +302,10 @@ def test_unique_soma_joinids( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(5, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(5, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__partial_final_batch_size( PipeClass: PipeClassType, @@ -329,9 +331,10 @@ def test_batching__partial_final_batch_size( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(3, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__exactly_one_batch( PipeClass: PipeClassType, @@ -357,9 +360,10 @@ def test_batching__exactly_one_batch( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__empty_query_result( PipeClass: PipeClassType, @@ -383,9 +387,10 
@@ def test_batching__empty_query_result( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_sparse_output__non_batched( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool @@ -406,9 +411,10 @@ def test_sparse_output__non_batched( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_sparse_output__batched( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool @@ -430,12 +436,10 @@ def test_sparse_output__batched( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [ - (10, 1, pytorch_x_value_gen, use_eager_fetch) - for use_eager_fetch in (True, False) - ], + "obs_range,var_range,X_value_gen", + [(10, 1, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test_batching__partial_soma_batches_are_concatenated( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool @@ -562,9 +566,10 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen,use_eager_fetch", - [(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)], + "obs_range,var_range,X_value_gen", + [(6, 3, pytorch_x_value_gen)], ) +@pytest.mark.parametrize("use_eager_fetch", [True, 
False]) @pytest.mark.parametrize("PipeClass", PipeClassImplementation) def test__X_tensor_dtype_matches_X_matrix( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool From a2ad15004acb4688b3b58ba4d3d58ca59f29e4ce Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:31:08 -0400 Subject: [PATCH 12/25] reindex batches --- src/tiledbsoma_ml/pytorch.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index ff0ff3a..51d3cca 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -440,7 +440,9 @@ def _mini_batch_iter( ) result = ( X_datum, - obs_io_batch.iloc[iob_idx : iob_idx + mini_batch_size], + obs_io_batch.iloc[ + iob_idx : iob_idx + mini_batch_size + ].reset_index(drop=True), ) iob_idx += len(result[1]) else: @@ -455,7 +457,12 @@ def _mini_batch_iter( ) result = ( X_datum, - pd.concat([result[1], obs_io_batch.iloc[0:to_take]]), + pd.concat( + [result[1], obs_io_batch.iloc[0:to_take]], + # Index `obs_batch` from 0 to N-1, instead of disjoint, concatenated pieces of IO batches' + # indices + ignore_index=True, + ), ) iob_idx += to_take From 66e99ded86cca00f59574d2a2171869d94b62e5e Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:31:20 -0400 Subject: [PATCH 13/25] comment updates --- src/tiledbsoma_ml/pytorch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py index 51d3cca..6f629b0 100644 --- a/src/tiledbsoma_ml/pytorch.py +++ b/src/tiledbsoma_ml/pytorch.py @@ -446,7 +446,7 @@ def _mini_batch_iter( ) iob_idx += len(result[1]) else: - # use any remnant from previous IO batch + # Use any remnant from previous IO batch to_take = min(mini_batch_size - len(result[1]), iob_len - iob_idx) X_datum = ( sparse.vstack([result[0], X_io_batch[0:to_take]]) @@ -531,7 +531,7 @@ def __iter__(self) -> Iterator[XObsDatum]: batch_size 
= self._exp_iter.batch_size for X, obs in self._exp_iter: if batch_size == 1: - X = X[0] + X = X[0] # This is a no-op for `csr_matrix`s yield X, obs def __len__(self) -> int: @@ -678,7 +678,7 @@ def __iter__(self) -> Iterator[XObsDatum]: batch_size = self._exp_iter.batch_size for X, obs in self._exp_iter: if batch_size == 1: - X = X[0] + X = X[0] # This is a no-op for `csr_matrix`s yield X, obs def __len__(self) -> int: From bfbe849bcdef025fbf9feb00b9f4dbb8b219122f Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:32:04 -0400 Subject: [PATCH 14/25] rm deprecated `typing.List` usages --- tests/test_pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index a62e515..9ecc858 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -6,7 +6,7 @@ from __future__ import annotations import pathlib -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, Optional, Sequence, Union from unittest.mock import patch import numpy as np @@ -61,17 +61,17 @@ def pytorch_seq_x_value_gen(obs_range: range, var_range: range) -> spmatrix: @pytest.fixture -def X_layer_names() -> List[str]: +def X_layer_names() -> list[str]: return ["raw"] @pytest.fixture -def obsp_layer_names() -> Optional[List[str]]: +def obsp_layer_names() -> Optional[list[str]]: return None @pytest.fixture -def varp_layer_names() -> Optional[List[str]]: +def varp_layer_names() -> Optional[list[str]]: return None From 85ec49205046d6bbc533338d17b8bb07da9b4bb6 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:33:12 -0400 Subject: [PATCH 15/25] `XValueGen`, `Path` type aliases --- tests/test_pytorch.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 9ecc858..1f2a9f3 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -5,7 +5,7 @@ from __future__ import annotations -import 
pathlib +from pathlib import Path from typing import Callable, Optional, Sequence, Union from unittest.mock import patch @@ -38,6 +38,7 @@ ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset, ) +XValueGen = Callable[[range, range], spmatrix] def pytorch_x_value_gen(obs_range: range, var_range: range) -> spmatrix: @@ -103,7 +104,7 @@ def add_sparse_array( key: str, obs_range: range, var_range: range, - value_gen: Callable[[range, range], spmatrix], + value_gen: XValueGen, ) -> None: a = coll.add_new_sparse_ndarray( key, type=pa.float32(), shape=(obs_range.stop, var_range.stop) @@ -114,10 +115,10 @@ def add_sparse_array( @pytest.fixture(scope="function") def soma_experiment( - tmp_path: pathlib.Path, + tmp_path: Path, obs_range: Union[int, range], var_range: Union[int, range], - X_value_gen: Callable[[range, range], sparse.spmatrix], + X_value_gen: XValueGen, obsp_layer_names: Sequence[str], varp_layer_names: Sequence[str], ) -> soma.Experiment: From 0353bf170e44bb7ecc1aeb405d4dc8786c5c9447 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:37:22 -0400 Subject: [PATCH 16/25] `s/PipeClassImplementation/PipeClasses/g` --- tests/test_pytorch.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 1f2a9f3..606714d 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -33,7 +33,7 @@ ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset, ] -PipeClassImplementation = ( +PipeClasses = ( ExperimentAxisQueryIterable, ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset, @@ -157,7 +157,7 @@ def soma_experiment( ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("return_sparse_X", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_non_batched( PipeClass: PipeClassType, soma_experiment: 
Experiment, @@ -204,7 +204,7 @@ def test_non_batched( ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("return_sparse_X", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_uneven_soma_and_result_batches( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -246,7 +246,7 @@ def test_uneven_soma_and_result_batches( [(6, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__all_batches_full_size( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -281,7 +281,7 @@ def test_batching__all_batches_full_size( [(range(100_000_000, 100_000_003), 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_unique_soma_joinids( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -307,7 +307,7 @@ def test_unique_soma_joinids( [(5, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__partial_final_batch_size( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -336,7 +336,7 @@ def test_batching__partial_final_batch_size( [(3, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__exactly_one_batch( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -365,7 +365,7 @@ def test_batching__exactly_one_batch( [(6, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", 
[True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__empty_query_result( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -392,7 +392,7 @@ def test_batching__empty_query_result( [(6, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_sparse_output__non_batched( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool ) -> None: @@ -416,7 +416,7 @@ def test_sparse_output__non_batched( [(6, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_sparse_output__batched( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool ) -> None: @@ -441,7 +441,7 @@ def test_sparse_output__batched( [(10, 1, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__partial_soma_batches_are_concatenated( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool ) -> None: @@ -469,7 +469,7 @@ def test_batching__partial_soma_batches_are_concatenated( "world_size,rank", [(3, 0), (3, 1), (3, 2), (2, 0), (2, 1)], ) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test_distributed__returns_data_partition_for_rank( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -518,7 +518,7 @@ def test_distributed__returns_data_partition_for_rank( (3, 1, 2, 1), ], ) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def 
test_distributed_and_multiprocessing__returns_data_partition_for_rank( PipeClass: PipeClassType, soma_experiment: Experiment, @@ -571,7 +571,7 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( [(6, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClassImplementation) +@pytest.mark.parametrize("PipeClass", PipeClasses) def test__X_tensor_dtype_matches_X_matrix( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool ) -> None: From d41af0e946e2c57c033547563b0b775a8cd6ebae Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:47:55 -0400 Subject: [PATCH 17/25] update `test_non_batched`, rm redundant `test_sparse_output__non_batched` --- tests/test_pytorch.py | 68 +++++++++++++------------------------------ 1 file changed, 21 insertions(+), 47 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 606714d..13c54b9 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -14,6 +14,7 @@ import pyarrow as pa import pytest import tiledbsoma as soma +from pandas._testing import assert_frame_equal from scipy import sparse from scipy.sparse import coo_matrix, spmatrix from tiledbsoma import Experiment, _factory @@ -164,7 +165,7 @@ def test_non_batched( use_eager_fetch: bool, return_sparse_X: bool, ) -> None: - # batch_size should default to 1 + """Check batches of size 1 (the default)""" with soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( query, @@ -173,29 +174,26 @@ def test_non_batched( use_eager_fetch=use_eager_fetch, return_sparse_X=return_sparse_X, ) - assert type(exp_data_pipe.shape) is tuple - assert len(exp_data_pipe.shape) == 2 assert exp_data_pipe.shape == (6, 3) - - row_iter = iter(exp_data_pipe) - - row = next(row_iter) - - if return_sparse_X: - # sparse slices remain 2D, always - assert isinstance(row[0], sparse.csr_matrix) - assert row[0].shape 
== (1, 3) - assert row[0].todense().tolist() == [[0, 1, 0]] - - else: - assert isinstance(row[0], np.ndarray) - assert np.squeeze(row[0]).shape == (3,) - assert np.squeeze(row[0]).tolist() == [0, 1, 0] - - assert isinstance(row[1], pd.DataFrame) - assert row[1].shape == (1, 1) - assert row[1].keys() == ["label"] - assert row[1]["label"].tolist() == ["0"] + batch_iter = iter(exp_data_pipe) + for idx, (X_batch, obs_batch) in enumerate(batch_iter): + expected_X = [0, 1, 0] if idx % 2 == 0 else [1, 0, 1] + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + # Sparse slices are always 2D + assert X_batch.shape == (1, 3) + assert X_batch.todense().tolist() == [expected_X] + else: + assert isinstance(X_batch, np.ndarray) + if PipeClass is ExperimentAxisQueryIterable: + assert X_batch.shape == (1, 3) + assert X_batch.tolist() == [expected_X] + else: + # ExperimentAxisQueryIterData{Pipe,set} "squeeze" dense single-row batches + assert X_batch.shape == (3,) + assert X_batch.tolist() == expected_X + + assert_frame_equal(obs_batch, pd.DataFrame({"label": [str(idx)]})) @pytest.mark.parametrize( @@ -387,30 +385,6 @@ def test_batching__empty_query_result( next(batch_iter) -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", - [(6, 3, pytorch_x_value_gen)], -) -@pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClasses) -def test_sparse_output__non_batched( - PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool -) -> None: - with soma_experiment.axis_query(measurement_name="RNA") as query: - exp_data_pipe = PipeClass( - query, - X_name="raw", - obs_column_names=["label"], - return_sparse_X=True, - use_eager_fetch=use_eager_fetch, - ) - batch_iter = iter(exp_data_pipe) - - batch = next(batch_iter) - assert isinstance(batch[0], sparse.csr_matrix) - assert batch[0].todense().A.squeeze().tolist() == [0, 1, 0] - - @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", [(6, 3, 
pytorch_x_value_gen)], From 7e8bb6ffe64bafd29d97b7ef6ee52ad809f1e5da Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:48:41 -0400 Subject: [PATCH 18/25] update `test_uneven_soma_and_result_batches` --- tests/test_pytorch.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 13c54b9..72f96d9 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -209,7 +209,7 @@ def test_uneven_soma_and_result_batches( use_eager_fetch: bool, return_sparse_X: bool, ) -> None: - """This is checking that batches are correctly created when they require fetching multiple chunks.""" + """Check that batches are correctly created when they require fetching multiple chunks.""" with soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( query, @@ -220,23 +220,28 @@ def test_uneven_soma_and_result_batches( use_eager_fetch=use_eager_fetch, return_sparse_X=return_sparse_X, ) - row_iter = iter(exp_data_pipe) - - X_batch, obs_batch = next(row_iter) + assert exp_data_pipe.shape == (2, 3) + batch_iter = iter(exp_data_pipe) + X_batch, obs_batch = next(batch_iter) + assert X_batch.shape == (3, 3) if return_sparse_X: assert isinstance(X_batch, sparse.csr_matrix) - assert X_batch.todense()[0].tolist() == [[0, 1, 0]] + X_batch = X_batch.todense() else: assert isinstance(X_batch, np.ndarray) - assert X_batch[0].tolist() == [0, 1, 0] + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) - assert isinstance(obs_batch, pd.DataFrame) - assert X_batch.shape[0] == obs_batch.shape[0] + X_batch, obs_batch = next(batch_iter) assert X_batch.shape == (3, 3) - assert obs_batch.shape == (3, 1) - assert ["label"] == obs_batch.keys() - assert obs_batch["label"].tolist() == ["0", "1", "2"] + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = 
X_batch.todense() + else: + assert isinstance(X_batch, np.ndarray) + assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4", "5"]})) @pytest.mark.parametrize( From fc768fd8c59a07fc6d2403ed1ee22446eedf4ddd Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:51:28 -0400 Subject: [PATCH 19/25] update `test_batching__all_batches_full_size`, rm redundant `test_sparse_output__batched` --- tests/test_pytorch.py | 50 +++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 33 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 72f96d9..e5c3cc5 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -249,11 +249,13 @@ def test_uneven_soma_and_result_batches( [(6, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("return_sparse_X", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__all_batches_full_size( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, + return_sparse_X: bool, ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( @@ -262,18 +264,24 @@ def test_batching__all_batches_full_size( obs_column_names=["label"], batch_size=3, use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, ) batch_iter = iter(exp_data_pipe) + assert exp_data_pipe.shape == (2, 3) - batch = next(batch_iter) - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1].keys() == ["label"] - assert batch[1]["label"].tolist() == ["0", "1", "2"] + X_batch, obs_batch = next(batch_iter) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) - batch = next(batch_iter) - 
assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] - assert batch[1].keys() == ["label"] - assert batch[1]["label"].tolist() == ["3", "4", "5"] + X_batch, obs_batch = next(batch_iter) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0], [1, 0, 1]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4", "5"]})) with pytest.raises(StopIteration): next(batch_iter) @@ -384,37 +392,13 @@ def test_batching__empty_query_result( batch_size=3, use_eager_fetch=use_eager_fetch, ) + assert exp_data_pipe.shape == (0, 3) batch_iter = iter(exp_data_pipe) with pytest.raises(StopIteration): next(batch_iter) -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", - [(6, 3, pytorch_x_value_gen)], -) -@pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClasses) -def test_sparse_output__batched( - PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool -) -> None: - with soma_experiment.axis_query(measurement_name="RNA") as query: - exp_data_pipe = PipeClass( - query, - X_name="raw", - obs_column_names=["label"], - batch_size=3, - return_sparse_X=True, - use_eager_fetch=use_eager_fetch, - ) - batch_iter = iter(exp_data_pipe) - - batch = next(batch_iter) - assert isinstance(batch[0], sparse.csr_matrix) - assert batch[0].todense().tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - - @pytest.mark.parametrize( "obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)], From 9cd03db5e71559fe0c4d008f1d34844843e87e14 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:53:52 -0400 Subject: [PATCH 20/25] update `test_soma_joinids`, rm redundant `test__X_tensor_dtype_matches_X_matrix` --- tests/test_pytorch.py | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index e5c3cc5..a8379dc 
100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -5,6 +5,7 @@ from __future__ import annotations +from functools import partial from pathlib import Path from typing import Callable, Optional, Sequence, Union from unittest.mock import patch @@ -27,6 +28,8 @@ ExperimentAxisQueryIterDataPipe, ) +assert_array_equal = partial(np.testing.assert_array_equal, strict=True) + # These control which classes are tested (for most, but not all tests). # Centralized to allow easy add/delete of specific test parameters. PipeClassType = Union[ @@ -293,7 +296,7 @@ def test_batching__all_batches_full_size( ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClasses) -def test_unique_soma_joinids( +def test_soma_joinids( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, @@ -306,11 +309,12 @@ def test_unique_soma_joinids( batch_size=3, use_eager_fetch=use_eager_fetch, ) + assert exp_data_pipe.shape == (1, 3) soma_joinids = np.concatenate( [batch[1]["soma_joinid"].to_numpy() for batch in exp_data_pipe] ) - assert len(np.unique(soma_joinids)) == len(soma_joinids) + assert_array_equal(soma_joinids, np.arange(100_000_000, 100_000_003)) @pytest.mark.parametrize( @@ -529,28 +533,6 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank( assert sorted(soma_joinids) == expected_joinids.tolist() -@pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", - [(6, 3, pytorch_x_value_gen)], -) -@pytest.mark.parametrize("use_eager_fetch", [True, False]) -@pytest.mark.parametrize("PipeClass", PipeClasses) -def test__X_tensor_dtype_matches_X_matrix( - PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool -) -> None: - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = PipeClass( - query, - X_name="raw", - obs_column_names=["label"], - batch_size=3, - use_eager_fetch=use_eager_fetch, - ) - data = next(iter(dp)) - - assert data[0].dtype == 
np.float32 - - def test_batched() -> None: from tiledbsoma_ml.pytorch import _batched From b8e2c835c1d6627b119e850cf77b4d44cc92e479 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:54:22 -0400 Subject: [PATCH 21/25] update `test_batching__partial_final_batch_size` --- tests/test_pytorch.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index a8379dc..493c7eb 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -322,11 +322,13 @@ def test_soma_joinids( [(5, 3, pytorch_x_value_gen)], ) @pytest.mark.parametrize("use_eager_fetch", [True, False]) +@pytest.mark.parametrize("return_sparse_X", [True, False]) @pytest.mark.parametrize("PipeClass", PipeClasses) def test_batching__partial_final_batch_size( PipeClass: PipeClassType, soma_experiment: Experiment, use_eager_fetch: bool, + return_sparse_X: bool, ) -> None: with soma_experiment.axis_query(measurement_name="RNA") as query: exp_data_pipe = PipeClass( @@ -335,12 +337,18 @@ def test_batching__partial_final_batch_size( obs_column_names=["label"], batch_size=3, use_eager_fetch=use_eager_fetch, + return_sparse_X=return_sparse_X, ) + assert exp_data_pipe.shape == (2, 3) batch_iter = iter(exp_data_pipe) next(batch_iter) - batch = next(batch_iter) - assert batch[0].tolist() == [[1, 0, 1], [0, 1, 0]] + X_batch, obs_batch = next(batch_iter) + if return_sparse_X: + assert isinstance(X_batch, sparse.csr_matrix) + X_batch = X_batch.todense() + assert X_batch.tolist() == [[1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["3", "4"]})) with pytest.raises(StopIteration): next(batch_iter) From a46e10adda55d23dee36726eccfbfc0f59739499 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:54:56 -0400 Subject: [PATCH 22/25] update `test_batching__exactly_one_batch` --- tests/test_pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_pytorch.py 
b/tests/test_pytorch.py index 493c7eb..6c73cc3 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -373,11 +373,11 @@ def test_batching__exactly_one_batch( batch_size=3, use_eager_fetch=use_eager_fetch, ) + assert exp_data_pipe.shape == (1, 3) batch_iter = iter(exp_data_pipe) - - batch = next(batch_iter) - assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] - assert batch[1]["label"].tolist() == ["0", "1", "2"] + X_batch, obs_batch = next(batch_iter) + assert X_batch.tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + assert_frame_equal(obs_batch, pd.DataFrame({"label": ["0", "1", "2"]})) with pytest.raises(StopIteration): next(batch_iter) From 0f76f03088b614fabd6b12a27917cda84f445fa7 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:55:30 -0400 Subject: [PATCH 23/25] update var names: `batch`/`batches` --- tests/test_pytorch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 6c73cc3..1d0ce35 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -431,9 +431,9 @@ def test_batching__partial_soma_batches_are_concatenated( use_eager_fetch=use_eager_fetch, ) - full_result = list(exp_data_pipe) + batches = list(exp_data_pipe) - assert [len(batch[0]) for batch in full_result] == [3, 3, 3, 1] + assert [len(batch[0]) for batch in batches] == [3, 3, 3, 1] @pytest.mark.parametrize( @@ -471,9 +471,9 @@ def test_distributed__returns_data_partition_for_rank( obs_column_names=["soma_joinid"], io_batch_size=2, ) - full_result = list(iter(dp)) + batches = list(iter(dp)) soma_joinids = np.concatenate( - [t[1]["soma_joinid"].to_numpy() for t in full_result] + [batch[1]["soma_joinid"].to_numpy() for batch in batches] ) expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ From ffaccdd010d8cd36919d776d476181ed06e10ded Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:56:42 -0400 Subject: [PATCH 24/25] update 
`test_distributed_and_multiprocessing__returns_data_partition_for_rank` --- tests/test_pytorch.py | 84 ++++++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 40 deletions(-) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 1d0ce35..2bf4b4c 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -482,63 +482,67 @@ def test_distributed__returns_data_partition_for_rank( assert sorted(soma_joinids) == expected_joinids +# fmt: off @pytest.mark.parametrize( - "obs_range,var_range,X_value_gen", - [(12, 3, pytorch_x_value_gen), (13, 3, pytorch_x_value_gen)], -) -@pytest.mark.parametrize( - "world_size,rank,num_workers,worker_id", + "obs_range,var_range,X_value_gen,world_size,num_workers,splits", [ - (3, 1, 2, 0), - (3, 1, 2, 1), + (12, 3, pytorch_x_value_gen, 3, 2, [[0, 2, 4], [4, 6, 8], [ 8, 10, 12]]), + (13, 3, pytorch_x_value_gen, 3, 2, [[0, 2, 4], [5, 7, 9], [ 9, 11, 13]]), + (15, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 5], [5, 9, 10], [10, 14, 15]]), + (16, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 5], [6, 10, 11], [11, 15, 16]]), + (18, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [6, 10, 12], [12, 16, 18]]), + (19, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [7, 11, 13], [13, 17, 19]]), + (20, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 6], [7, 11, 13], [14, 18, 20]]), + (21, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 7], [7, 11, 14], [14, 18, 21]]), + (25, 3, pytorch_x_value_gen, 3, 2, [[0, 4, 8], [9, 13, 17], [17, 21, 25]]), + (27, 3, pytorch_x_value_gen, 3, 2, [[0, 6, 9], [9, 15, 18], [18, 24, 27]]), ], ) -@pytest.mark.parametrize("PipeClass", PipeClasses) +# fmt: on def test_distributed_and_multiprocessing__returns_data_partition_for_rank( - PipeClass: PipeClassType, soma_experiment: Experiment, - obs_range: int, world_size: int, - rank: int, num_workers: int, - worker_id: int, + splits: list[list[int]], ) -> None: """Tests pytorch._partition_obs_joinids() behavior in a simulated PyTorch distributed processing mode and DataLoader 
multiprocessing mode, using mocks to avoid having to do distributed pytorch setup or real DataLoader multiprocessing.""" - with ( - patch("torch.utils.data.get_worker_info") as mock_get_worker_info, - patch("torch.distributed.is_initialized") as mock_dist_is_initialized, - patch("torch.distributed.get_rank") as mock_dist_get_rank, - patch("torch.distributed.get_world_size") as mock_dist_get_world_size, - ): - mock_get_worker_info.return_value = WorkerInfo( - id=worker_id, num_workers=num_workers, seed=1234 - ) - mock_dist_is_initialized.return_value = True - mock_dist_get_rank.return_value = rank - mock_dist_get_world_size.return_value = world_size - - with soma_experiment.axis_query(measurement_name="RNA") as query: - dp = PipeClass( - query, - X_name="raw", - obs_column_names=["soma_joinid"], - io_batch_size=2, + for rank in range(world_size): + proc_splits = splits[rank] + for worker_id in range(num_workers): + expected_joinids = list( + range(proc_splits[worker_id], proc_splits[worker_id + 1]) ) + with ( + patch("torch.utils.data.get_worker_info") as mock_get_worker_info, + patch("torch.distributed.is_initialized") as mock_dist_is_initialized, + patch("torch.distributed.get_rank") as mock_dist_get_rank, + patch("torch.distributed.get_world_size") as mock_dist_get_world_size, + ): + mock_get_worker_info.return_value = WorkerInfo( + id=worker_id, num_workers=num_workers, seed=1234 + ) + mock_dist_is_initialized.return_value = True + mock_dist_get_rank.return_value = rank + mock_dist_get_world_size.return_value = world_size - full_result = list(iter(dp)) + with soma_experiment.axis_query(measurement_name="RNA") as query: + dp = ExperimentAxisQueryIterable( + query, + X_name="raw", + obs_column_names=["soma_joinid"], + io_batch_size=2, + ) - soma_joinids = np.concatenate( - [t[1]["soma_joinid"].to_numpy() for t in full_result] - ) + batches = list(iter(dp)) - expected_joinids = np.array_split(np.arange(obs_range), world_size)[rank][ - 0 : obs_range // world_size - 
] - expected_joinids = np.array_split(expected_joinids, num_workers)[worker_id] - assert sorted(soma_joinids) == expected_joinids.tolist() + soma_joinids = np.concatenate( + [batch[1]["soma_joinid"].to_numpy() for batch in batches] + ).tolist() + + assert soma_joinids == expected_joinids def test_batched() -> None: From 7801854bf17ea9cefc883e534462a1e2b49d8832 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Thu, 3 Oct 2024 15:56:50 -0400 Subject: [PATCH 25/25] add `test_splits` case --- tests/test_pytorch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 2bf4b4c..8cbae43 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -567,6 +567,7 @@ def test_splits() -> None: from tiledbsoma_ml.pytorch import _splits assert _splits(10, 1).tolist() == [0, 10] + assert _splits(10, 2).tolist() == [0, 5, 10] assert _splits(10, 3).tolist() == [0, 4, 7, 10] assert _splits(10, 4).tolist() == [0, 3, 6, 8, 10] assert _splits(10, 10).tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]