Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] CensusSCVIDataModule + notebook #1196

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3113626
Support for custom obs encoders
ebezzi Jun 12, 2024
23d000c
Add docstrings
ebezzi Jun 12, 2024
945f14f
Update api/python/cellxgene_census/src/cellxgene_census/experimental/…
ebezzi Jun 12, 2024
31f49a7
Doc changes
ebezzi Jun 13, 2024
1c5e875
merge from main
ebezzi Jun 13, 2024
b9ca1e0
merge from main
ebezzi Jun 13, 2024
f688020
Revert some changes
ebezzi Jun 13, 2024
8ea52b2
Lightning datamodule
ebezzi Jun 17, 2024
222efdd
New notebook
ebezzi Jun 17, 2024
721bac5
Small refactor
ebezzi Jun 18, 2024
d4e1d9b
More refactor
ebezzi Jun 18, 2024
9fa5b53
Partial test upgrades
ebezzi Jun 25, 2024
367cc55
More fixes
ebezzi Jun 25, 2024
99fbc76
Revert dockerfile
ebezzi Jun 25, 2024
ad6b0ec
Explicit columns
ebezzi Jun 26, 2024
b6253f9
Consolidate encoders variable
ebezzi Jun 27, 2024
fc4271f
Add duplicate check
ebezzi Jun 28, 2024
085ec6e
Fix typying issue
ebezzi Jun 28, 2024
5221ce6
Notebook changes
ebezzi Jun 28, 2024
79897be
Rerun notebook
ebezzi Jun 28, 2024
32b99f5
Merge branch 'ebezzi/support-custom-obs-encoders' into ebezzi/census-…
ebezzi Jun 28, 2024
2307216
Some fixes
ebezzi Jun 29, 2024
f4063ac
Update api/python/notebooks/experimental/pytorch_loader_scvi.ipynb
ebezzi Jul 1, 2024
bfdaa00
Update api/python/notebooks/experimental/pytorch_loader_scvi.ipynb
ebezzi Jul 1, 2024
2fcd55c
Update api/python/notebooks/experimental/pytorch_loader_scvi.ipynb
ebezzi Jul 1, 2024
9a63be5
Update api/python/notebooks/experimental/pytorch_loader_scvi.ipynb
ebezzi Jul 1, 2024
36954e0
Update api/python/notebooks/experimental/pytorch_loader_scvi.ipynb
ebezzi Jul 1, 2024
1ddff7e
Update api/python/notebooks/experimental/pytorch_loader_scvi.ipynb
ebezzi Jul 1, 2024
ab0e660
Several changes
ebezzi Jul 1, 2024
77d1a59
Merge branch 'ebezzi/census-scvi-datamodule' of github.com:chanzucker…
ebezzi Jul 1, 2024
110021b
Add mypy ignore for datamodule.py
ebezzi Jul 1, 2024
6edd123
Merge branch 'main' into ebezzi/census-scvi-datamodule
ebezzi Jul 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""An API to facilitate use of PyTorch ML training with data from the CZI Science CELLxGENE Census."""

from .pytorch import ExperimentDataPipe, Stats, experiment_dataloader
from .datamodule import CensusSCVIDataModule
from .pytorch import Encoder, ExperimentDataPipe, Stats, experiment_dataloader

# Public API of the experimental ML package. ``CensusSCVIDataModule`` must be
# imported above for ``from ... import *`` and attribute access to resolve it.
__all__ = [
    "Stats",
    "ExperimentDataPipe",
    "experiment_dataloader",
    "Encoder",
    "CensusSCVIDataModule",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import functools
from typing import Any

import numpy as np
import pandas as pd
import torch
from lightning.pytorch import LightningDataModule

from .pytorch import Encoder, ExperimentDataPipe, experiment_dataloader


class BatchEncoder(Encoder):
    """Encoder that derives a single "batch" label from several obs columns.

    The configured columns are string-concatenated row-wise (no separator) and
    the resulting values are integer-encoded with scikit-learn's
    ``LabelEncoder``.
    """

    def __init__(self, cols: list[str], name: str = "batch"):
        # Imported lazily so sklearn is only required when this encoder is used.
        from sklearn.preprocessing import LabelEncoder

        self.cols = cols
        self._name = name
        self._encoder = LabelEncoder()

    def _join_cols(self, df: pd.DataFrame):
        # Row-wise string concatenation of every configured column.
        per_column = (df[col].astype(str) for col in self.cols)
        return functools.reduce(lambda left, right: left + right, per_column)

    def transform(self, df: pd.DataFrame):
        """Return integer codes for the concatenated batch values of ``df``."""
        return self._encoder.transform(self._join_cols(df))

    def inverse_transform(self, encoded_values: np.ndarray) -> np.ndarray:
        """Map integer codes back to their concatenated string labels."""
        return self._encoder.inverse_transform(encoded_values)

    def fit(self, obs: pd.DataFrame):
        """Fit the underlying label encoder on the unique concatenated values."""
        self._encoder.fit(self._join_cols(obs).unique())

    @property
    def columns(self):
        """Obs columns this encoder reads."""
        return self.cols

    @property
    def name(self) -> str:
        """Name of the encoded output column."""
        return self._name

    @property
    def classes_(self):
        """Classes known to the fitted label encoder."""
        return self._encoder.classes_


class CensusSCVIDataModule(LightningDataModule):
    """Lightning data module for training an scVI model using the ExperimentDataPipe.

    Parameters
    ----------
    *args
        Positional arguments passed to
        :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`.
    batch_keys
        List of obs column names concatenated to form the batch column.
    train_size
        Fraction of data to use for training.
    split_seed
        Seed for data split.
    dataloader_kwargs
        Keyword arguments passed into
        :func:`~cellxgene_census.experimental.ml.pytorch.experiment_dataloader`.
    **kwargs
        Additional keyword arguments passed into
        :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`. Must not include
        ``obs_column_names``.
    """

    # Keys used when splitting the datapipe into train/validation subsets.
    _TRAIN_KEY = "train"
    _VALIDATION_KEY = "validation"

    def __init__(
        self,
        *args,
        batch_keys: list[str] | None = None,
        train_size: float | None = None,
        split_seed: int | None = None,
        # NOTE: was ``dict[str, any]`` — ``any`` is the builtin function, not a
        # type; ``typing.Any`` is the correct annotation.
        dataloader_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__()
        self.datapipe_args = args
        self.datapipe_kwargs = kwargs
        self.batch_keys = batch_keys
        self.train_size = train_size
        self.split_seed = split_seed
        self.dataloader_kwargs = dataloader_kwargs or {}

    @property
    def batch_keys(self) -> list[str]:
        """List of obs column names concatenated to form the batch column."""
        if not hasattr(self, "_batch_keys"):
            raise AttributeError("`batch_keys` not set.")
        return self._batch_keys

    @batch_keys.setter
    def batch_keys(self, value: list[str] | None):
        # ``None`` is rejected too (it is not a list): a batch column is
        # effectively required by this data module.
        if not isinstance(value, list):
            raise ValueError("`batch_keys` must be a list of strings.")
        self._batch_keys = value

    @property
    def obs_column_names(self) -> list[str]:
        """Passed to :class:`~cellxgene_census.experimental.ml.pytorch.ExperimentDataPipe`."""
        # Computed once and cached; derived entirely from ``batch_keys``.
        if hasattr(self, "_obs_column_names"):
            return self._obs_column_names

        obs_column_names = []
        if self.batch_keys is not None:
            obs_column_names.extend(self.batch_keys)

        self._obs_column_names = obs_column_names
        return self._obs_column_names

    @property
    def split_seed(self) -> int:
        """Seed for data split."""
        if not hasattr(self, "_split_seed"):
            raise AttributeError("`split_seed` not set.")
        return self._split_seed

    @split_seed.setter
    def split_seed(self, value: int | None):
        if value is not None and not isinstance(value, int):
            raise ValueError("`split_seed` must be an integer.")
        # ``None`` falls back to a deterministic default seed of 0.
        self._split_seed = value or 0

    @property
    def train_size(self) -> float:
        """Fraction of data to use for training."""
        if not hasattr(self, "_train_size"):
            raise AttributeError("`train_size` not set.")
        return self._train_size

    @train_size.setter
    def train_size(self, value: float | None):
        if value is not None and not isinstance(value, float):
            raise ValueError("`train_size` must be a float.")
        elif value is not None and (value < 0.0 or value > 1.0):
            raise ValueError("`train_size` must be between 0.0 and 1.0.")
        # ``None`` means "use everything for training" (no validation split).
        self._train_size = value or 1.0

    @property
    def validation_size(self) -> float:
        """Fraction of data to use for validation.

        Defined as the complement of ``train_size``.
        """
        if not hasattr(self, "_train_size"):
            raise AttributeError("`validation_size` not available.")
        return 1.0 - self.train_size

    @property
    def weights(self) -> dict[str, float]:
        """Passed to :meth:`~cellxgene_census.experimental.ml.ExperimentDataPipe.random_split`."""
        if not hasattr(self, "_weights"):
            self._weights = {self._TRAIN_KEY: self.train_size}
            # Only request a validation partition when there is data for it.
            if self.validation_size > 0.0:
                self._weights[self._VALIDATION_KEY] = self.validation_size
        return self._weights

    @property
    def datapipe(self) -> ExperimentDataPipe:
        """Experiment data pipe (created lazily on first access)."""
        if not hasattr(self, "_datapipe"):
            encoder = BatchEncoder(self.obs_column_names)
            self._datapipe = ExperimentDataPipe(
                *self.datapipe_args,
                encoders=[encoder],
                **self.datapipe_kwargs,
            )
        return self._datapipe

    def setup(self, stage: str | None = None):
        """Set up the train and validation data pipes.

        ``stage`` is accepted for Lightning compatibility and ignored.
        """
        datapipes = self.datapipe.random_split(weights=self.weights, seed=self.split_seed)
        self._train_datapipe = datapipes[0]
        if self.validation_size > 0.0:
            self._validation_datapipe = datapipes[1]
        else:
            self._validation_datapipe = None

    def train_dataloader(self):
        """Training data loader."""
        return experiment_dataloader(self._train_datapipe, **self.dataloader_kwargs)

    def val_dataloader(self):
        """Validation data loader, or ``None`` when there is no validation split."""
        if self._validation_datapipe is not None:
            return experiment_dataloader(self._validation_datapipe, **self.dataloader_kwargs)

    @property
    def n_obs(self) -> int:
        """Number of observations in the query.

        Necessary in scvi-tools to compute a heuristic of ``max_epochs``.
        """
        return self.datapipe.shape[0]

    @property
    def n_vars(self) -> int:
        """Number of features in the query.

        Necessary in scvi-tools to initialize the actual layers in the model.
        """
        return self.datapipe.shape[1]

    @property
    def n_batch(self) -> int:
        """Number of unique batches (after concatenation of ``batch_keys``).

        Necessary in scvi-tools so that the model knows how to one-hot encode batches.
        """
        return self.get_n_classes("batch")

    def get_n_classes(self, key: str) -> int:
        """Return the number of classes for a given obs column."""
        return len(self.datapipe.obs_encoders[key].classes_)

    def on_before_batch_transfer(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        dataloader_idx: int,
    ) -> dict[str, torch.Tensor | None]:
        """Format the datapipe output with registry keys for scvi-tools."""
        X, obs = batch

        X_KEY: str = "X"
        BATCH_KEY: str = "batch"
        LABELS_KEY: str = "labels"

        return {
            X_KEY: X,
            BATCH_KEY: obs,
            LABELS_KEY: None,
        }
Loading
Loading