scVI pipeline (#787)
Co-authored-by: pablo-gar <[email protected]>
ebezzi and pablo-gar authored Dec 8, 2023
1 parent d2e7468 commit 7009201
Showing 8 changed files with 332 additions and 0 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -78,6 +78,7 @@ repos:
- numpy
- pandas-stubs
- typing_extensions
- types-PyYAML

- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
44 changes: 44 additions & 0 deletions tools/models/scvi/README.md
@@ -0,0 +1,44 @@
# Census trained scVI model

## Training Pipeline

This directory contains a set of scripts that can be used to train [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html) on the whole of the Census data and to generate cell embeddings from its latent space.

The model can be trained separately on each experiment (`homo_sapiens` and `mus_musculus`), producing separate artifacts.

The training pipeline consists of three scripts:

[scvi-prepare.py](scvi-prepare.py)

This script prepares an AnnData file that can be fed directly to the scVI trainer. The preparation strategy is:

1. Take all the primary cells from the Census with a gene count greater than or equal to 300 (`is_primary_data == True and nnz >= 300`).
1. Extract the top 8000 highly variable genes (using the Census `highly_variable_genes` function). These are serialized to `hv_genes.pkl` as a pickled boolean pandas Series.
1. Create a `batch_key` column by concatenating the `[dataset_id, assay, suspension_type, donor_id]` covariates.

The output of this step is an `anndata_model.h5ad` file.
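
A quick way to sanity-check the outputs of this step is sketched below (illustrative, not part of the pipeline; file names are the defaults from [scvi-config.yaml](scvi-config.yaml)):

```python
# Illustrative sanity check of the prepare outputs.
import anndata as ad
import pandas as pd

hv = pd.read_pickle("hv_genes.pkl")         # boolean Series, True for selected HVGs
adata = ad.read_h5ad("anndata_model.h5ad")  # training-ready AnnData

assert adata.n_vars == int(hv.sum())        # should equal top_n_hvg (8000)
print(adata.obs["batch"].nunique(), "distinct batch values")
```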

[scvi-train.py](scvi-train.py)

This script takes the AnnData file from the previous step and trains an scVI model on it. See [scvi-config.yaml](scvi-config.yaml) for an up-to-date list of model and training parameters.

The resulting model weights are saved to an `scvi.model` directory. TensorBoard logs are also available as part of the output.
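
The saved weights can later be reloaded for downstream work. A minimal sketch, assuming the default `scvi.model` directory and the prepared AnnData from the previous step:

```python
# Minimal sketch: reload the trained scVI model from its saved weights.
import anndata as ad
import scvi

adata = ad.read_h5ad("anndata_model.h5ad")
model = scvi.model.SCVI.load("scvi.model", adata=adata)
print(model)  # prints the model setup and training status
```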

[scvi-create-latent-update.py](scvi-create-latent-update.py)

This script takes the previously trained model and obtains its latent space representation to generate cell embeddings. The generation strategy is:

1. Take all the cells from the Census (since we want to generate embeddings for every cell).
1. Take the same highly variable genes from the `prepare` step.
1. Generate an AnnData with the same properties as in the `prepare` step.
1. Call `scvi.model.SCVI.load_query_data()` on this AnnData. This makes it possible to work with a dataset that has more cells than the one the model was trained on, which is required so that the model doesn't need to be retrained from scratch on each Census version. A further pass of training is possible, but we just set `is_trained = True` to skip it.
1. Call `get_latent_representation()` to generate the embeddings.
1. The final h5ad file, the embeddings, and the cell index are all saved as part of the output (see the sketch below for joining them back together).
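
For example, aligning the saved embeddings with their Census cell ids might look like this (illustrative sketch):

```python
# Illustrative: align saved embeddings with their Census soma_joinids.
import numpy as np
import pandas as pd

idx = np.load("latent-idx.npy")  # one soma_joinid per embedding row
latent = np.load("latent.npy")   # shape (n_cells, n_latent)

embeddings = pd.DataFrame(latent, index=pd.Index(idx, name="soma_joinid"))
print(embeddings.shape)
```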

## Selection of model parameters

The final selection of parameters for the training phase was based on a hyperparameter search, as described in the [CELLxGENE Discover Census scvi-tools initial autotune report](https://github.com/YosefLab/census-scvi/blob/main/experiments/autotune/notebooks/2023_09_autotune_report.ipynb).

## Environment setup

The training was performed on an AWS EC2 machine (instance type: g4dn.12xlarge) running Ubuntu 20.04. Run [scvi-init.sh](scvi-init.sh) to set up the environment required to run the pipeline.
37 changes: 37 additions & 0 deletions tools/models/scvi/scvi-config.yaml
@@ -0,0 +1,37 @@
census:
  organism: "mus_musculus"
  obs_query: null  # Use if you want to train on a subset of the Census data
  obs_query_model: 'is_primary_data == True and nnz >= 300'  # Required when loading data for model training. Do not change.
hvg:
  top_n_hvg: 8000
  hvg_batch: [suspension_type, assay]
anndata:
  batch_key: [dataset_id, assay, suspension_type, donor_id]
  model_filename: anndata_model.h5ad
model:
  filename: "scvi.model"
  n_hidden: 512
  n_latent: 200
  n_layers: 1
  dropout_rate: 0.1
train:
  max_epochs: 20
  batch_size: 1048
  train_size: 0.95
  early_stopping: True
  trainer:
    early_stopping_patience: 2
    early_stopping_monitor: validation_loss  # should be validation_loss - see https://github.com/chanzuckerberg/cellxgene-census/issues/777#issuecomment-1743196837
    check_val_every_n_epoch: 1
  multi_gpu: False
  num_workers: 4
  devices: [0, 1, 2, 3]
training_plan:
  lr: 1.0e-4
78 changes: 78 additions & 0 deletions tools/models/scvi/scvi-create-latent-update.py
@@ -0,0 +1,78 @@
import functools
import gc

import cellxgene_census
import numpy as np
import pandas as pd
import scvi
import tiledbsoma as soma
import yaml

file = "scvi-config.yaml"

if __name__ == "__main__":
    with open(file) as f:
        config = yaml.safe_load(f)

    census = cellxgene_census.open_soma(census_version="latest")

    census_config = config.get("census")
    experiment_name = census_config.get("organism")
    obs_value_filter = census_config.get("obs_query")

    # Highly variable genes selected during the prepare step (boolean Series)
    hv = pd.read_pickle("hv_genes.pkl")
    hv_idx = hv[hv].index

    if obs_value_filter is not None:
        obs_query = soma.AxisQuery(value_filter=obs_value_filter)
    else:
        obs_query = None

    query = census["census_data"][experiment_name].axis_query(
        measurement_name="RNA",
        obs_query=obs_query,
        var_query=soma.AxisQuery(coords=(list(hv_idx),)),
    )

    adata_config = config["anndata"]
    batch_key = adata_config.get("batch_key")
    ad_filename = adata_config.get("model_filename")

    print("Converting to AnnData")
    ad = query.to_anndata(X_name="raw")
    # Recreate the same concatenated batch covariate used at training time
    ad.obs["batch"] = functools.reduce(lambda a, b: a + b, [ad.obs[c].astype(str) for c in batch_key])

    ad.var.set_index("feature_id", inplace=True)

    # Cell index: the soma_joinid of every cell in the query, in row order
    idx = query.obs(column_names=["soma_joinid"]).concat().to_pandas()["soma_joinid"].to_numpy()

    del census, query, hv, hv_idx
    gc.collect()

    model_config = config.get("model")
    model_filename = model_config.get("filename")
    n_latent = model_config.get("n_latent")

    scvi.model.SCVI.prepare_query_anndata(ad, model_filename)

    vae_q = scvi.model.SCVI.load_query_data(
        ad,
        model_filename,
    )
    # Uncomment #1 if you want to do a forward pass with an additional training epoch.
    # Uncomment #2 if you want to do a forward pass without additional training.

    # vae_q.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0))  # 1
    vae_q.is_trained = True  # 2
    latent = vae_q.get_latent_representation()

    ad.write_h5ad("anndata-full.h5ad", compression="gzip")

    del vae_q, ad
    gc.collect()

    with open("latent-idx.npy", "wb") as f:
        np.save(f, idx)

    with open("latent.npy", "wb") as f:
        np.save(f, latent)
34 changes: 34 additions & 0 deletions tools/models/scvi/scvi-init.sh
@@ -0,0 +1,34 @@
#!/bin/bash

# Can be used to bootstrap a g4dn.* instance with scvi-tools and cellxgene-census

export DEBIAN_FRONTEND=noninteractive

sudo apt -y update
sudo apt -y install python3-pip
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt -y update
sudo apt -y install python3.11
sudo apt -y install python3.11-venv
curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
sudo update-alternatives --install /usr/bin/python python /usr/bin/python2 1
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 3
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2
sudo update-alternatives --config python
sudo update-alternatives --config python3
sudo cp /usr/lib/python3/dist-packages/apt_pkg.cpython-38-x86_64-linux-gnu.so /usr/lib/python3/dist-packages/apt_pkg.so

sudo apt -y install libnvidia-gl-535 libnvidia-common-535 libnvidia-compute-535 libnvidia-encode-535 libnvidia-decode-535 nvidia-compute-utils-535 libnvidia-fbc1-535 nvidia-driver-535

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install pathlib torch click ray hyperopt
pip install git+https://github.com/scverse/scvi-tools.git


# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# pip install nvidia-cusolver-cu11

pip install scikit-misc
pip install git+https://github.com/chanzuckerberg/cellxgene-census#subdirectory=api/python/cellxgene_census
62 changes: 62 additions & 0 deletions tools/models/scvi/scvi-prepare.py
@@ -0,0 +1,62 @@
import functools

import cellxgene_census
import tiledbsoma as soma
import yaml
from cellxgene_census.experimental.pp import highly_variable_genes

file = "scvi-config.yaml"

if __name__ == "__main__":
    with open(file) as f:
        config = yaml.safe_load(f)

    census = cellxgene_census.open_soma(census_version="latest")

    census_config = config.get("census")
    experiment_name = census_config.get("organism")
    obs_query = census_config.get("obs_query")
    obs_query_model = census_config.get("obs_query_model")

    # The model filter (primary cells, nnz >= 300) is always applied;
    # an optional user-supplied filter is AND-ed in.
    if obs_query is None:
        obs_value_filter = obs_query_model
    else:
        obs_value_filter = f"{obs_query} and {obs_query_model}"

    query = census["census_data"][experiment_name].axis_query(
        measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter)
    )

    hvg_config = config.get("hvg")
    top_n_hvg = hvg_config.get("top_n_hvg")
    hvg_batch = hvg_config.get("hvg_batch")
    min_genes = hvg_config.get("min_genes")

    print("Starting hvg selection")

    hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch)

    hv = hvgs_df.highly_variable

    # Persist the boolean HVG mask so the latent-generation step can reuse it
    hv.to_pickle("hv_genes.pkl")
    hv_idx = hv[hv].index

    # Re-query, now restricted to the selected highly variable genes
    query = census["census_data"][experiment_name].axis_query(
        measurement_name="RNA",
        obs_query=soma.AxisQuery(value_filter=obs_value_filter),
        var_query=soma.AxisQuery(coords=(list(hv_idx),)),
    )

    print("Converting to AnnData")
    ad = query.to_anndata(X_name="raw")

    adata_config = config["anndata"]
    batch_key = adata_config.get("batch_key")
    filename = adata_config.get("model_filename")

    # Concatenate the batch covariates into the single batch column used by scVI
    ad.obs["batch"] = functools.reduce(lambda a, b: a + b, [ad.obs[c].astype(str) for c in batch_key])
    ad.var.set_index("feature_id", inplace=True)

    print("AnnData conversion completed. Saving...")
    ad.write_h5ad(filename, compression="gzip")
    print("AnnData saved")
72 changes: 72 additions & 0 deletions tools/models/scvi/scvi-train.py
@@ -0,0 +1,72 @@
import anndata as ad
import scvi
import yaml
from lightning.pytorch.loggers import TensorBoardLogger

file = "scvi-config.yaml"

if __name__ == "__main__":
    with open(file) as f:
        config = yaml.safe_load(f)

    print("Start SCVI run")

    # scvi settings
    scvi.settings.seed = 0
    print("Last run with scvi-tools version:", scvi.__version__)

    adata_config = config["anndata"]
    filename = adata_config.get("model_filename")

    adata = ad.read_h5ad(filename)

    scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

    model_config = config.get("model")
    n_hidden = model_config.get("n_hidden")
    n_latent = model_config.get("n_latent")
    n_layers = model_config.get("n_layers")
    dropout_rate = model_config.get("dropout_rate")
    filename = model_config.get("filename")

    print("Configure model")

    # Pass all configured architecture hyperparameters through to the model
    model = scvi.model.SCVI(
        adata,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        gene_likelihood="nb",
        encode_covariates=True,
    )

    train_config = config.get("train")
    max_epochs = train_config.get("max_epochs")
    batch_size = train_config.get("batch_size")
    train_size = train_config.get("train_size")
    early_stopping = train_config.get("early_stopping")
    devices = train_config.get("devices")
    multi_gpu = train_config.get("multi_gpu", False)

    trainer_config = train_config.get("trainer")

    training_plan_config = config.get("training_plan")

    if multi_gpu:
        scvi.settings.dl_num_workers = train_config.get("num_workers")
        strategy = "ddp_find_unused_parameters_true"
    else:
        strategy = "auto"
        devices = 1

    print("Start training model")

    logger = TensorBoardLogger("tb_logs", name="my_model")

    model.train(
        max_epochs=max_epochs,
        batch_size=batch_size,
        train_size=train_size,
        early_stopping=early_stopping,
        plan_kwargs=training_plan_config,
        strategy=strategy,  # Required for multi-GPU training.
        devices=devices,
        logger=logger,
        **trainer_config,
    )

    model.save(filename)
4 changes: 4 additions & 0 deletions tools/pyproject.toml
@@ -9,6 +9,10 @@ warn_unreachable = true
strict = true
plugins = "numpy.typing.mypy_plugin"

[[tool.mypy.overrides]]
module = "tools.models.scvi.*"
ignore_errors = true

[tool.ruff]
select = ["E", "F", "B", "I"]
ignore = ["E501", "E402", "C408", ]
