Co-authored-by: pablo-gar <[email protected]>
Showing 8 changed files with 332 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Census trained scVI model | ||
|
||
## Training Pipeline | ||
|
||
This directory contains a set of scripts that can be used to train [scVI](https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html) on the whole Census data, and to generate its latent space representation embeddings. | ||
|
||
The model can be trained separately on each experiment (`homo_sapiens` and `mus_musculus`), and produces separate artifacts. | ||
|
||
In order to run the training pipeline, three separate files are provided: | ||
|
||
[scvi-prepare.py](scvi-prepare.py) | ||
|
||
This file prepares an AnnData file that can be fed to the scVI trainer directly. The preparation strategy is: | ||
|
||
1. Take all the primary cells from the Census with a gene count greater than or equal to 300 (`is_primary_data == True and nnz >= 300`). | ||
1. Extract the top 8000 highly variable genes (using the Census `highly_variable_genes` function). Those are serialized to `hv_genes.pkl` as a pickled pandas Series (a boolean mask over genes). | ||
1. Create a `batch` column by concatenating the `[dataset_id, assay, suspension_type, donor_id]` covariates listed under `batch_key`. | ||
|
||
The output of this file is an `anndata_model.h5ad` file. | ||
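The batch label from the last step can be sketched in isolation. This is a minimal illustration, not the script itself: it uses plain dictionaries instead of the pandas columns that `scvi-prepare.py` operates on, and the cell record and its values are made up.

```python
from functools import reduce

# Covariates listed under `batch_key` in scvi-config.yaml.
BATCH_KEY = ["dataset_id", "assay", "suspension_type", "donor_id"]


def batch_label(cell, keys=BATCH_KEY):
    """Concatenate the covariate values of one cell into a single label."""
    # Same reduction as the script: cast each value to str, then join with `+`.
    return reduce(lambda a, b: a + b, [str(cell[k]) for k in keys])


# Hypothetical cell record; the values are invented for illustration.
cell = {
    "dataset_id": "ds1",
    "assay": "10x 3' v3",
    "suspension_type": "cell",
    "donor_id": "D42",
}
label = batch_label(cell)
```

Note that the values are joined without a separator, as in the script, so two different covariate combinations could in principle produce the same label.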
|
||
[scvi-train.py](scvi-train.py) | ||
|
||
This file takes the AnnData file from the previous step and trains an scVI model on it. See [scvi-config.yaml](scvi-config.yaml) for an up-to-date list of model and training parameters. | ||
|
||
The resulting model weights are saved to an `scvi.model` directory. Tensorboard logs are also available as part of the output. | ||
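How the `multi_gpu` flag from the config maps to the training strategy can be summarized as a small helper. This is a sketch that mirrors the branch in `scvi-train.py`; the function name `select_strategy` is hypothetical and does not exist in the script.

```python
def select_strategy(multi_gpu, devices):
    """Return the (strategy, devices) pair handed to `model.train()`."""
    if multi_gpu:
        # Distributed data parallel over the device ids listed in the config.
        return "ddp_find_unused_parameters_true", devices
    # Single-device training: the configured device list is ignored.
    return "auto", 1


single = select_strategy(False, [0, 1, 2, 3])
multi = select_strategy(True, [0, 1, 2, 3])
```

With the shipped config (`multi_gpu: False`), training runs on one device even though four device ids are listed.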
|
||
[scvi-create-latent-update.py](scvi-create-latent-update.py) | ||
|
||
This file takes the previously generated model and obtains the latent space representation to generate cell embeddings. The generation strategy is: | ||
|
||
1. Take all the cells from the Census (since we want to generate embeddings for every cell). | ||
1. Take the same highly variable genes from the `prepare` step. | ||
1. Generate an AnnData with the same properties as the `prepare` step. | ||
1. Call the `scvi.model.SCVI.load_query_data()` function on this AnnData. This makes it possible to work on a dataset that has more cells than the one the model was trained on (which is required so that the model doesn't need to be re-trained from scratch on each Census version). A further pass of training is possible, but we just set `is_trained = True` to skip it. | ||
1. Call `get_latent_representation()` to generate the embeddings. | ||
1. Save the final h5ad file, the embeddings and the cell index as part of the output. | ||
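The saved outputs can be consumed roughly as follows. This sketch assumes that row `i` of `latent.npy` corresponds to entry `i` of `latent-idx.npy`; the helper `embedding_for` is hypothetical, and the arrays are small stand-ins written to a temporary directory rather than real pipeline output.

```python
import os
import tempfile

import numpy as np


def embedding_for(joinid, idx, latent):
    """Return the latent vector for one soma_joinid, assuming row-aligned arrays."""
    positions = np.flatnonzero(idx == joinid)
    if positions.size == 0:
        raise KeyError(f"soma_joinid {joinid} not found")
    return latent[positions[0]]


# Stand-in data: 3 cells with 4-dimensional embeddings (the config uses n_latent = 200).
with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, "latent-idx.npy"), np.array([10, 11, 12]))
    np.save(
        os.path.join(tmp, "latent.npy"),
        np.arange(12, dtype=np.float32).reshape(3, 4),
    )

    # Load both arrays fully into memory before the directory is removed.
    idx = np.load(os.path.join(tmp, "latent-idx.npy"))
    latent = np.load(os.path.join(tmp, "latent.npy"))

vector = embedding_for(11, idx, latent)
```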
|
||
## Selection of model parameters | ||
|
||
The final selection of parameters for the training phase was based on a hyperparameter search, as described in the [CELLxGENE Discover Census scvi-tools initial autotune report](https://github.com/YosefLab/census-scvi/blob/main/experiments/autotune/notebooks/2023_09_autotune_report.ipynb). | ||
|
||
## Environment setup | ||
|
||
The training has been performed on an AWS EC2 machine (instance type: g4dn.12xlarge), running on Ubuntu 20.04. Run [scvi-init.sh](scvi-init.sh) to set up the environment required to run the pipeline. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
census: | ||
organism: | ||
"mus_musculus" | ||
obs_query: # Use if you want to train on a subset of the data | ||
null | ||
obs_query_model: # Required when loading data for model training. Do not change. | ||
'is_primary_data == True and nnz >= 300' | ||
hvg: | ||
top_n_hvg: | ||
8000 | ||
hvg_batch: | ||
[suspension_type, assay] | ||
anndata: | ||
batch_key: | ||
[dataset_id, assay, suspension_type, donor_id] | ||
model_filename: | ||
anndata_model.h5ad | ||
model: | ||
filename: "scvi.model" | ||
n_hidden: 512 | ||
n_latent: 200 | ||
n_layers: 1 | ||
dropout_rate: 0.1 | ||
train: | ||
max_epochs: 20 | ||
batch_size: 1048 | ||
train_size: 0.95 | ||
early_stopping: True | ||
trainer: | ||
early_stopping_patience: 2 | ||
early_stopping_monitor: validation_loss # should be validation_loss - see https://github.com/chanzuckerberg/cellxgene-census/issues/777#issuecomment-1743196837 | ||
check_val_every_n_epoch: 1 | ||
multi_gpu: False | ||
num_workers: 4 | ||
devices: [0, 1, 2, 3] | ||
training_plan: | ||
lr: 1.0e-4 |
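The effect of `early_stopping_patience: 2` on a run monitored through `validation_loss` can be illustrated with a simplified model of patience-based early stopping. This is a sketch only: the real callback has further options (for example a minimum improvement threshold) that are ignored here, the function name is hypothetical, and the loss values are invented.

```python
def early_stop_epoch(val_losses, patience=2):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("inf")
    waited = 0  # consecutive validation checks without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            waited = 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None


# Invented validation losses, one per epoch (check_val_every_n_epoch: 1).
losses = [10.0, 9.0, 9.5, 9.2, 8.0]
stop_at = early_stop_epoch(losses, patience=2)
```

In this example the loss fails to improve on its best value at two consecutive checks, so training stops before the later improvement is ever seen. With such a small patience, `max_epochs: 20` acts as an upper bound rather than the expected run length.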
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import functools | ||
import gc | ||
|
||
import cellxgene_census | ||
import numpy as np | ||
import pandas as pd | ||
import scvi | ||
import tiledbsoma as soma | ||
import yaml | ||
|
||
file = "scvi-config.yaml" | ||
|
||
if __name__ == "__main__": | ||
with open(file) as f: | ||
config = yaml.safe_load(f) | ||
|
||
census = cellxgene_census.open_soma(census_version="latest") | ||
|
||
census_config = config.get("census") | ||
experiment_name = census_config.get("organism") | ||
obs_value_filter = census_config.get("obs_query") | ||
|
||
hv = pd.read_pickle("hv_genes.pkl") | ||
hv_idx = hv[hv].index | ||
|
||
if obs_value_filter is not None: | ||
obs_query = soma.AxisQuery(value_filter=obs_value_filter) | ||
else: | ||
obs_query = None | ||
|
||
query = census["census_data"][experiment_name].axis_query( | ||
measurement_name="RNA", | ||
obs_query=obs_query, | ||
var_query=soma.AxisQuery(coords=(list(hv_idx),)), | ||
) | ||
|
||
adata_config = config["anndata"] | ||
batch_key = adata_config.get("batch_key") | ||
ad_filename = adata_config.get("model_filename") | ||
|
||
print("Converting to AnnData") | ||
ad = query.to_anndata(X_name="raw") | ||
ad.obs["batch"] = functools.reduce(lambda a, b: a + b, [ad.obs[c].astype(str) for c in batch_key]) | ||
|
||
ad.var.set_index("feature_id", inplace=True) | ||
|
||
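# Use the soma_joinid values themselves; the DataFrame index is only a positional RangeIndex. | ||
idx = query.obs(column_names=["soma_joinid"]).concat().to_pandas()["soma_joinid"].to_numpy() | ||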
|
||
del census, query, hv, hv_idx | ||
gc.collect() | ||
|
||
model_config = config.get("model") | ||
model_filename = model_config.get("filename") | ||
n_latent = model_config.get("n_latent") | ||
|
||
scvi.model.SCVI.prepare_query_anndata(ad, model_filename) | ||
|
||
vae_q = scvi.model.SCVI.load_query_data( | ||
ad, | ||
model_filename, | ||
) | ||
# Uncomment #1 if you want to do a forward pass with an additional training epoch. | ||
# Uncomment #2 if you want to do a forward pass without additional training. | ||
|
||
# vae_q.train(max_epochs=1, plan_kwargs=dict(weight_decay=0.0)) # 1 | ||
vae_q.is_trained = True # 2 | ||
latent = vae_q.get_latent_representation() | ||
|
||
ad.write_h5ad("anndata-full.h5ad", compression="gzip") | ||
|
||
del vae_q, ad | ||
gc.collect() | ||
|
||
with open("latent-idx.npy", "wb") as f: | ||
np.save(f, idx) | ||
|
||
with open("latent.npy", "wb") as f: | ||
np.save(f, latent) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/bin/bash | ||
|
||
# Can be used to bootstrap a g4dn.* instance with scvi-tools and cellxgene-census | ||
|
||
export DEBIAN_FRONTEND=noninteractive | ||
|
||
sudo apt -y update | ||
sudo apt -y install python3-pip | ||
sudo add-apt-repository -y ppa:deadsnakes/ppa | ||
sudo apt -y update | ||
sudo apt -y install python3.11 | ||
sudo apt -y install python3.11-venv | ||
curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11 | ||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python2 1 | ||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.8 2 | ||
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.11 3 | ||
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 | ||
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 2 | ||
sudo update-alternatives --config python | ||
sudo update-alternatives --config python3 | ||
sudo cp /usr/lib/python3/dist-packages/apt_pkg.cpython-38-x86_64-linux-gnu.so /usr/lib/python3/dist-packages/apt_pkg.so | ||
|
||
sudo apt -y install libnvidia-gl-535 libnvidia-common-535 libnvidia-compute-535 libnvidia-encode-535 libnvidia-decode-535 nvidia-compute-utils-535 libnvidia-fbc1-535 nvidia-driver-535 | ||
|
||
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html | ||
pip install pathlib torch click ray hyperopt | ||
pip install git+https://github.com/scverse/scvi-tools.git | ||
|
||
|
||
# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 | ||
# pip install nvidia-cusolver-cu11 | ||
|
||
pip install scikit-misc | ||
pip install git+https://github.com/chanzuckerberg/cellxgene-census#subdirectory=api/python/cellxgene_census |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import functools | ||
|
||
import cellxgene_census | ||
import tiledbsoma as soma | ||
import yaml | ||
from cellxgene_census.experimental.pp import highly_variable_genes | ||
|
||
file = "scvi-config.yaml" | ||
|
||
if __name__ == "__main__": | ||
with open(file) as f: | ||
config = yaml.safe_load(f) | ||
|
||
census = cellxgene_census.open_soma(census_version="latest") | ||
|
||
census_config = config.get("census") | ||
experiment_name = census_config.get("organism") | ||
obs_query = census_config.get("obs_query") | ||
obs_query_model = census_config.get("obs_query_model") | ||
|
||
if obs_query is None: | ||
obs_value_filter = obs_query_model | ||
else: | ||
obs_value_filter = f"{obs_query} and {obs_query_model}" | ||
|
||
query = census["census_data"][experiment_name].axis_query( | ||
measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter) | ||
) | ||
|
||
hvg_config = config.get("hvg") | ||
top_n_hvg = hvg_config.get("top_n_hvg") | ||
hvg_batch = hvg_config.get("hvg_batch") | ||
min_genes = hvg_config.get("min_genes") | ||
|
||
print("Starting hvg selection") | ||
|
||
hvgs_df = highly_variable_genes(query, n_top_genes=top_n_hvg, batch_key=hvg_batch) | ||
|
||
hv = hvgs_df.highly_variable | ||
|
||
hv.to_pickle("hv_genes.pkl") | ||
hv_idx = hv[hv].index | ||
|
||
query = census["census_data"][experiment_name].axis_query( | ||
measurement_name="RNA", | ||
obs_query=soma.AxisQuery(value_filter=obs_value_filter), | ||
var_query=soma.AxisQuery(coords=(list(hv_idx),)), | ||
) | ||
|
||
print("Converting to AnnData") | ||
ad = query.to_anndata(X_name="raw") | ||
|
||
adata_config = config["anndata"] | ||
batch_key = adata_config.get("batch_key") | ||
filename = adata_config.get("model_filename") | ||
|
||
ad.obs["batch"] = functools.reduce(lambda a, b: a + b, [ad.obs[c].astype(str) for c in batch_key]) | ||
ad.var.set_index("feature_id", inplace=True) | ||
|
||
print("AnnData conversion completed. Saving...") | ||
ad.write_h5ad(filename, compression="gzip") | ||
print("AnnData saved") |
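The way this script combines the optional `obs_query` with the mandatory `obs_query_model` can be isolated as a helper. This is a sketch: the function name `build_obs_filter` is hypothetical, and the `tissue_general == 'lung'` filter is only an example of a user-supplied query.

```python
# Mandatory filter from scvi-config.yaml (`obs_query_model`).
MODEL_FILTER = "is_primary_data == True and nnz >= 300"


def build_obs_filter(obs_query, obs_query_model=MODEL_FILTER):
    """Combine the optional user filter with the mandatory model filter."""
    if obs_query is None:
        return obs_query_model
    # Both conditions must hold, so they are joined with `and`.
    return f"{obs_query} and {obs_query_model}"


default_filter = build_obs_filter(None)
subset_filter = build_obs_filter("tissue_general == 'lung'")
```

As in the script, the user filter is not wrapped in parentheses, so a query that itself contains `or` may not combine with the model filter as intended.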
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import anndata as ad | ||
import scvi | ||
import yaml | ||
from lightning.pytorch.loggers import TensorBoardLogger | ||
|
||
file = "scvi-config.yaml" | ||
|
||
if __name__ == "__main__": | ||
with open(file) as f: | ||
config = yaml.safe_load(f) | ||
|
||
print("Start SCVI run") | ||
|
||
# scvi settings | ||
scvi.settings.seed = 0 | ||
print("Last run with scvi-tools version:", scvi.__version__) | ||
|
||
adata_config = config["anndata"] | ||
filename = adata_config.get("model_filename") | ||
|
||
adata = ad.read_h5ad(filename) | ||
|
||
scvi.model.SCVI.setup_anndata(adata, batch_key="batch") | ||
|
||
model_config = config.get("model") | ||
n_hidden = model_config.get("n_hidden") | ||
n_latent = model_config.get("n_latent") | ||
n_layers = model_config.get("n_layers") | ||
dropout_rate = model_config.get("dropout_rate") | ||
filename = model_config.get("filename") | ||
|
||
print("Configure model") | ||
|
||
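# Pass n_hidden and dropout_rate through so the values read from the config are actually used. | ||
model = scvi.model.SCVI(adata, n_hidden=n_hidden, n_layers=n_layers, n_latent=n_latent, dropout_rate=dropout_rate, gene_likelihood="nb", encode_covariates=True) | ||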
|
||
train_config = config.get("train") | ||
max_epochs = train_config.get("max_epochs") | ||
batch_size = train_config.get("batch_size") | ||
train_size = train_config.get("train_size") | ||
early_stopping = train_config.get("early_stopping") | ||
devices = train_config.get("devices") | ||
multi_gpu = train_config.get("multi_gpu", False) | ||
|
||
trainer_config = train_config.get("trainer") | ||
|
||
training_plan_config = config.get("training_plan") | ||
|
||
if multi_gpu: | ||
scvi.settings.dl_num_workers = train_config.get("num_workers") | ||
strategy = "ddp_find_unused_parameters_true" | ||
devices = devices | ||
else: | ||
strategy = "auto" | ||
devices = 1 | ||
|
||
print("Start training model") | ||
|
||
logger = TensorBoardLogger("tb_logs", name="my_model") | ||
|
||
model.train( | ||
max_epochs=max_epochs, | ||
batch_size=batch_size, | ||
train_size=train_size, | ||
early_stopping=early_stopping, | ||
plan_kwargs=training_plan_config, | ||
strategy=strategy, # Required for Multi-GPU training. | ||
devices=devices, | ||
logger=logger, | ||
**trainer_config, | ||
) | ||
|
||
model.save(filename) |