Skip to content

Commit

Permalink
add experiment_dataloader helper API
Browse files Browse the repository at this point in the history
  • Loading branch information
bkmartinjr authored and ryan-williams committed Oct 3, 2024
1 parent d123979 commit 3d8bf2a
Show file tree
Hide file tree
Showing 3 changed files with 283 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/tiledbsoma_ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from .pytorch import (
ExperimentAxisQueryIterableDataset,
ExperimentAxisQueryIterDataPipe,
experiment_dataloader,
)

__version__ = "0.1.0-dev"

__all__ = [
"ExperimentAxisQueryIterDataPipe",
"ExperimentAxisQueryIterableDataset",
"experiment_dataloader",
]
87 changes: 87 additions & 0 deletions src/tiledbsoma_ml/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,73 @@ def shape(self) -> Tuple[int, int]:
return self._exp_iter.shape


def experiment_dataloader(
ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
**dataloader_kwargs: Any,
) -> torch.utils.data.DataLoader:
"""Factory method for :class:`torch.utils.data.DataLoader`. This method can be used to safely instantiate a
:class:`torch.utils.data.DataLoader` that works with :class:`tiledbsoma_ml.ExperimentAxisQueryIterableDataset`
or :class:`tiledbsoma_ml.ExperimentAxisQueryIterDataPipe`.
Several :class:`torch.utils.data.DataLoader` constructor parameters are not applicable, or are non-performant,
when using loaders from this module, including ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``.
Specifying any of these parameters will result in an error.
Refer to ``https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader`` for more information on
:class:`torch.utils.data.DataLoader` parameters.
Args:
ds:
A :class:`torch.utils.data.IterableDataset` or a :class:`torchdata.datapipes.iter.IterDataPipe`. May
include chained data pipes.
**dataloader_kwargs:
Additional keyword arguments to pass to the :class:`torch.utils.data.DataLoader` constructor,
except for ``shuffle``, ``batch_size``, ``sampler``, and ``batch_sampler``, which are not
supported when using data loaders in this module.
Returns:
A :class:`torch.utils.data.DataLoader`.
Raises:
ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or ``batch_sampler`` params
are passed as keyword arguments.
Lifecycle:
experimental
"""
unsupported_dataloader_args = [
"shuffle",
"batch_size",
"sampler",
"batch_sampler",
]
if set(unsupported_dataloader_args).intersection(dataloader_kwargs.keys()):
raise ValueError(
f"The {','.join(unsupported_dataloader_args)} DataLoader parameters are not supported"
)

if dataloader_kwargs.get("num_workers", 0) > 0:
_init_multiprocessing()

if "collate_fn" not in dataloader_kwargs:
dataloader_kwargs["collate_fn"] = _collate_noop

return torch.utils.data.DataLoader(
ds,
batch_size=None, # batching is handled by upstream iterator
shuffle=False, # shuffling is handled by upstream iterator
**dataloader_kwargs,
)


def _collate_noop(datum: _T) -> _T:
"""Noop collation for use with a dataloader instance.
Private.
"""
return datum


def _splits(total_length: int, sections: int) -> npt.NDArray[np.intp]:
"""For ``total_length`` points, compute start/stop offsets that split the length into roughly equal sizes.
Expand Down Expand Up @@ -784,3 +851,23 @@ def _get_worker_world_rank() -> Tuple[int, int]:
num_workers = worker_info.num_workers
worker = worker_info.id
return num_workers, worker


def _init_multiprocessing() -> None:
"""Ensures use of "spawn" for starting child processes with multiprocessing.
Forked processes are known to be problematic:
https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks
Also, CUDA does not support forked child processes:
https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing
Private.
"""
orig_start_method = torch.multiprocessing.get_start_method()
if orig_start_method != "spawn":
if orig_start_method:
logger.warning(
"switching torch multiprocessing start method from "
f'"{torch.multiprocessing.get_start_method()}" to "spawn"'
)
torch.multiprocessing.set_start_method("spawn", force=True)
195 changes: 194 additions & 1 deletion tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from functools import partial
from pathlib import Path
from typing import Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from unittest.mock import patch

import numpy as np
import numpy.typing as npt
import pandas as pd
import pyarrow as pa
import pytest
Expand All @@ -26,6 +27,7 @@
ExperimentAxisQueryIterable,
ExperimentAxisQueryIterableDataset,
ExperimentAxisQueryIterDataPipe,
experiment_dataloader,
)

assert_array_equal = partial(np.testing.assert_array_equal, strict=True)
Expand Down Expand Up @@ -436,6 +438,37 @@ def test_batching__partial_soma_batches_are_concatenated(
assert [len(batch[0]) for batch in batches] == [3, 3, 3, 1]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_multiprocessing__returns_full_result(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
) -> None:
"""Tests the ExperimentAxisQueryIterDataPipe provides all data, as collected from multiple processes that are managed by a
PyTorch DataLoader with multiple workers configured."""
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["soma_joinid", "label"],
io_batch_size=3, # two chunks, one per worker
)
# Note we're testing the ExperimentAxisQueryIterDataPipe via a DataLoader, since this is what sets up the multiprocessing
dl = experiment_dataloader(dp, num_workers=2)

full_result = list(iter(dl))

soma_joinids = np.concatenate(
[t[1]["soma_joinid"].to_numpy() for t in full_result]
)
assert sorted(soma_joinids) == list(range(6))


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen",
[(6, 3, pytorch_x_value_gen), (7, 3, pytorch_x_value_gen)],
Expand Down Expand Up @@ -545,6 +578,166 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
assert soma_joinids == expected_joinids


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(3, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__non_batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
data = [row for row in dl]
assert all(d[0].shape == (3,) for d in data)
assert all(d[1].shape == (1, 1) for d in data)

row = data[0]
assert row[0].tolist() == [0, 1, 0]
assert row[1]["label"].tolist() == ["0"]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[(6, 3, pytorch_x_value_gen, use_eager_fetch) for use_eager_fetch in (True, False)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__batched(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
batch_size=3,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
data = [row for row in dl]

batch = data[0]
assert batch[0].tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
assert batch[1].to_numpy().tolist() == [[0], [1], [2]]


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,use_eager_fetch",
[
(10, 3, pytorch_x_value_gen, use_eager_fetch)
for use_eager_fetch in (True, False)
],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__batched_length(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
use_eager_fetch: bool,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
batch_size=3,
use_eager_fetch=use_eager_fetch,
)
dl = experiment_dataloader(dp)
assert len(dl) == len(list(dl))


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen,batch_size",
[(10, 3, pytorch_x_value_gen, batch_size) for batch_size in (1, 3, 10)],
)
@pytest.mark.parametrize(
"PipeClass",
(ExperimentAxisQueryIterDataPipe, ExperimentAxisQueryIterableDataset),
)
def test_experiment_dataloader__collate_fn(
PipeClass: ExperimentAxisQueryIterDataPipe | ExperimentAxisQueryIterableDataset,
soma_experiment: Experiment,
batch_size: int,
) -> None:
def collate_fn(
batch_size: int, data: Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]
) -> Tuple[npt.NDArray[np.number[Any]], pd.DataFrame]:
assert isinstance(data, tuple)
assert len(data) == 2
assert isinstance(data[0], np.ndarray) and isinstance(data[1], pd.DataFrame)
if batch_size > 1:
assert data[0].shape[0] == data[1].shape[0]
assert data[0].shape[0] <= batch_size
else:
assert data[0].ndim == 1
assert data[1].shape[1] <= batch_size
return data

with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = PipeClass(
query,
X_name="raw",
obs_column_names=["label"],
batch_size=batch_size,
)
dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size))
assert len(list(dl)) > 0


@pytest.mark.parametrize(
"obs_range,var_range,X_value_gen", [(10, 1, pytorch_x_value_gen)]
)
def test__pytorch_splitting(
soma_experiment: Experiment,
) -> None:
with soma_experiment.axis_query(measurement_name="RNA") as query:
dp = ExperimentAxisQueryIterDataPipe(
query,
X_name="raw",
obs_column_names=["label"],
)
# function not available for IterableDataset, yet....
dp_train, dp_test = dp.random_split(
weights={"train": 0.7, "test": 0.3}, seed=1234
)
dl = experiment_dataloader(dp_train)

all_rows = list(iter(dl))
assert len(all_rows) == 7


def test_experiment_dataloader__unsupported_params__fails() -> None:
with patch(
"tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe"
) as dummy_exp_data_pipe:
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, shuffle=True)
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, batch_size=3)
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, batch_sampler=[])
with pytest.raises(ValueError):
experiment_dataloader(dummy_exp_data_pipe, sampler=[])


def test_batched() -> None:
from tiledbsoma_ml.pytorch import _batched

Expand Down

0 comments on commit 3d8bf2a

Please sign in to comment.