From e11f4fd53e47125161782d1e8d24036130180563 Mon Sep 17 00:00:00 2001
From: Dorien <41797896+dorien-er@users.noreply.github.com>
Date: Thu, 5 Dec 2024 10:49:55 +0100
Subject: [PATCH] Cell type annotation: scGPT workflow (#832)

Co-authored-by: DriesSchaumont <5946712+DriesSchaumont@users.noreply.github.com>
Co-authored-by: Sarah <68101181+SarahOuologuem@users.noreply.github.com>
Co-authored-by: Vladimir Shitov <35199218+VladimirShitov@users.noreply.github.com>
Co-authored-by: Jakub Majercik <57993790+jakubmajercik@users.noreply.github.com>
Co-authored-by: Robrecht Cannoodt
---
 CHANGELOG.md                                  |  18 +-
 resources_test_scripts/scgpt.sh               |  32 ++
 src/scgpt/binning/config.vsh.yaml             |  16 +-
 src/scgpt/binning/script.py                   |  36 +-
 src/scgpt/binning/test.py                     |  25 +-
 src/scgpt/cell_type_annotation/script.py      | 381 +++++++++---------
 src/scgpt/cross_check_genes/config.vsh.yaml   |  11 +-
 src/scgpt/cross_check_genes/script.py         |  37 +-
 src/scgpt/cross_check_genes/test.py           |  34 +-
 src/scgpt/pad_tokenize/config.vsh.yaml        |  20 +-
 src/scgpt/pad_tokenize/script.py              |  32 +-
 src/scgpt/pad_tokenize/test.py                |  72 +---
 .../scgpt_annotation/config.vsh.yaml          | 195 +++++++++
 .../scgpt_annotation/integration_test.sh      |  15 +
 .../annotation/scgpt_annotation/main.nf       | 112 +++++
 .../scgpt_annotation/nextflow.config          |  10 +
 .../annotation/scgpt_annotation/test.nf       |  55 +++
 .../annotation/scgpt/config.vsh.yaml          |  40 ++
 .../test_workflows/annotation/scgpt/script.py |  39 ++
 19 files changed, 857 insertions(+), 323 deletions(-)
 create mode 100644 src/workflows/annotation/scgpt_annotation/config.vsh.yaml
 create mode 100755 src/workflows/annotation/scgpt_annotation/integration_test.sh
 create mode 100644 src/workflows/annotation/scgpt_annotation/main.nf
 create mode 100644 src/workflows/annotation/scgpt_annotation/nextflow.config
 create mode 100644 src/workflows/annotation/scgpt_annotation/test.nf
 create mode 100644 src/workflows/test_workflows/annotation/scgpt/config.vsh.yaml
 create mode 100644 src/workflows/test_workflows/annotation/scgpt/script.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a05fcb1241c..dece4d32675 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,9 +1,25 @@
 # openpipelines x.x.x
 
-# MINOR CHANGES
+## BREAKING CHANGES
+
+* Several components under `src/scgpt` (`cross_check_genes`, `pad_tokenize`, `binning`) now process the input (query) datasets differently. Instead of subsetting datasets based on genes in the model vocabulary and/or highly variable genes, these components require an input .var column with a boolean mask specifying this information. The results are written back to the original input data, preserving the dataset structure (PR #832).
+
+## NEW FUNCTIONALITY
+
+* `scgpt/cell_type_annotation` component update: Added support for multi-processing (PR #832).
+
+## MINOR CHANGES
 
 * Several components (cleanup): remove workaround for being able to use shared utility functions with Nextflow Fusion (PR #920).
 
+* `workflows/annotation/scgpt_annotation` workflow: Added a scGPT transformer-based cell type annotation workflow (PR #832).
+
+* `scgpt/cross_check_genes` component update: Highly variable genes are now cross-checked based on the boolean mask in `var_input`. The filtering information is stored in the `--output_var_filter` .var field instead of subsetting the dataset (PR #832).
+
+* `scgpt/binning` component update: This component now requires the `--var_input` parameter to provide gene filtering information. Binned data is written to the `--output_obsm_binned_counts` .obsm field in the original input data (PR #832).
+
+* `scgpt/pad_tokenize` component update: Genes are padded and tokenized based on filtering information in `--var_input` and `--input_obsm_binned_counts` (PR #832).
+
 # openpipelines 2.0.0-rc.2
 
 ## BUG FIXES
diff --git a/resources_test_scripts/scgpt.sh b/resources_test_scripts/scgpt.sh
index ead65138dd0..f6cd89a14e1 100644
--- a/resources_test_scripts/scgpt.sh
+++ b/resources_test_scripts/scgpt.sh
@@ -11,6 +11,12 @@ OUT=resources_test/$ID
 # create foundational model directory
 foundation_model_dir="$OUT/source"
 mkdir -p "$foundation_model_dir"
+export foundation_model_dir
+
+# create finetuned model directory
+finetuned_model_dir="$OUT/finetuned_model"
+mkdir -p "$finetuned_model_dir"
+export finetuned_model_dir
 
 # install gdown if necessary
 # Check whether gdown is available
@@ -19,6 +25,13 @@ if ! command -v gdown &> /dev/null; then
   exit 1
 fi
 
+# install torch if necessary
+# Check whether torch is importable
+if ! python -c "import torch" &> /dev/null; then
+  echo "This script requires torch. Please make sure it is installed in the Python environment on your PATH."
+  exit 1
+fi
+
 echo "> Downloading scGPT foundation model (full_human)"
 # download foundational model files (full_human)
 # https://drive.google.com/drive/folders/1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y
@@ -26,6 +39,25 @@ gdown '1H3E_MJ-Dl36AQV6jLbna2EdvgPaqvqcC' -O "${foundation_model_dir}/vocab.json"
 gdown '1hh2zGKyWAx3DyovD30GStZ3QlzmSqdk1' -O "${foundation_model_dir}/args.json"
 gdown '14AebJfGOUF047Eg40hk57HCtrb0fyDTm' -O "${foundation_model_dir}/best_model.pt"
 
+echo "> Converting to finetuned model format"
+python < np.ndarray:
     digits = np.ceil(digits)
     smallest_dtype = np.min_scalar_type(digits.max().astype(np.uint)) # Already checked for non-negative values
     digits = digits.astype(smallest_dtype)
-    
+
     return digits
 
@@ -78,32 +87,31 @@ def _digitize(x: np.ndarray, bins: np.ndarray) -> np.ndarray:
             "this is expected. You can use the `filter_cell_by_counts` "
             "arg to filter out all zero rows."
         )
-        
+
         # Add binned_rows and bin_edges as all 0
         # np.stack will upcast the dtype later
         binned_rows.append(np.zeros_like(non_zero_row, dtype=np.int8))
         bin_edges.append(np.array([0] * n_bins))
         continue
-    
+
     # Binning of non-zero values
     bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
     non_zero_digits = _digitize(non_zero_row, bins)
     assert non_zero_digits.min() >= 1
     assert non_zero_digits.max() <= n_bins - 1
     binned_rows.append(non_zero_digits)
-    
+
     bin_edges.append(np.concatenate([[0], bins]))
 
 # Create new CSR matrix
 logger.info("Creating a new CSR matrix of the binned count values")
-binned_layer = csr_matrix((np.concatenate(binned_rows, casting="same_kind"),
+binned_counts = csr_matrix((np.concatenate(binned_rows, casting="same_kind"),
                            layer_data.indices,
                            layer_data.indptr), shape=layer_data.shape)
 
 # Set binned values and bin edges layers to adata object
-adata.layers[par["binned_layer"]] = binned_layer
-adata.obsm["bin_edges"] = np.stack(bin_edges)
-
+input_adata.obsm[par["output_obsm_binned_counts"]] = binned_counts
+input_adata.obsm["bin_edges"] = np.stack(bin_edges)
+
 # Write mudata output
 logger.info("Writing output data")
-mdata.mod[par["modality"]] = adata
-mdata.write(par["output"], compression=par["output_compression"])
\ No newline at end of file
+mdata.write_h5mu(par["output"], compression=par["output_compression"])
diff --git a/src/scgpt/binning/test.py b/src/scgpt/binning/test.py
index a54422f6c54..0c5f5faaf9c 100644
--- a/src/scgpt/binning/test.py
+++ b/src/scgpt/binning/test.py
@@ -14,15 +14,16 @@
 
 def test_binning(run_component, tmp_path):
-
-    input_file_path = f"{meta['resources_dir']}/Kim2020_Lung_subset.h5mu"
+
+    input_file_path = f"{meta['resources_dir']}/Kim2020_Lung_subset_preprocessed.h5mu"
     output_file_path = tmp_path / "Kim2020_Lung_subset_binned.h5mu"
 
     run_component([
         "--input", input_file_path,
         "--modality", "rna",
-        "--binned_layer", "binned",
+        "--output_obsm_binned_counts", "binned_counts",
         "--n_input_bins", "51",
+        "--var_input", "filter_with_hvg",
        "--output", output_file_path
     ])
 
@@ -31,21 +32,19 @@ def test_binning(run_component, tmp_path):
     output_adata = output_mdata.mod["rna"]
 
     # Check presence of binning layers
-    assert "bin_edges" in output_adata.obsm.keys()
-    assert "binned" in output_adata.layers.keys()
-
+    assert {"bin_edges", "binned_counts"}.issubset(output_adata.obsm.keys()), "Binning obsm fields were not added."
+
     # Check bin edges
     bin_edges = output_adata.obsm["bin_edges"]
     assert all(bin_edges[:, 0] == 0)
     assert bin_edges.shape[1] == 51
-    assert all(all(i>=0) for i in bin_edges)
-
+    assert all(all(i >= 0) for i in bin_edges)
+
     # Check binned values
-    binned_values = output_adata.layers["binned"]
+    binned_values = output_adata.obsm["binned_counts"]
     assert issparse(binned_values)
-    assert binned_values.shape == output_adata.X.shape
     assert (binned_values.data <= 51).all(axis=None)
-
-
+
+
 if __name__ == '__main__':
-    sys.exit(pytest.main([__file__]))
\ No newline at end of file
+    sys.exit(pytest.main([__file__]))
diff --git a/src/scgpt/cell_type_annotation/script.py b/src/scgpt/cell_type_annotation/script.py
index df2bd68c13a..9a2d79436af 100644
--- a/src/scgpt/cell_type_annotation/script.py
+++ b/src/scgpt/cell_type_annotation/script.py
@@ -1,5 +1,6 @@
 import sys
 import json
+from multiprocessing import freeze_support
 import os
 import mudata as mu
 from typing import Dict
@@ -11,12 +12,13 @@
 from scgpt.model import TransformerModel
 from scgpt.tokenizer.gene_tokenizer import GeneVocab
 from scgpt.utils import set_seed
+from tqdm import tqdm
 
 ## VIASH START
 par = {
     'input': r'resources_test/scgpt/test_resources/Kim2020_Lung_subset_tokenized.h5mu',
     'modality': r'rna',
-    'model': r'resources_test/scgpt/source/best_model.pt',
+    'model': r'resources_test/scgpt/finetuned_model/best_model.pt',
     'model_config': r'resources_test/scgpt/source/args.json',
     'model_vocab': r'resources_test/scgpt/source/vocab.json',
     'obs_batch_label': r'sample',
@@ -24,15 +26,15 @@
     'obsm_tokenized_values': r'values_tokenized',
     'output': r'output.h5mu',
     'output_compression': None,
-    'obs_predicted_cell_class': r'predicted_cell_class',
-    'obs_predicted_cell_label': r'predicted_cell_label',
+    'output_obs_predictions': r'predictions',
+    'output_obs_probability': r'probabilities',
     'dsbn': True,
     'seed': 0,
     'pad_token': "<pad>",
     'pad_value': -2,
     'n_input_bins': 51,
     'batch_size': 64,
-    'finetuned_checkpoints_key': 'mapping_dic',
+    'finetuned_checkpoints_key': 'model_state_dict',
     'label_mapper_key': 'id_to_class'
 }
@@ -52,188 +54,193 @@ def __len__(self):
 
     def __getitem__(self, idx):
         return {k: v[idx] for k, v in self.data.items()}
 
-warnings.filterwarnings('ignore')
-
-# Setting seed
-if par["seed"]:
-    set_seed(par["seed"])
-
-# Setting device
-logger.info(f"Setting device to {'cuda' if torch.cuda.is_available() else 'cpu'}")
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Read in data
-logger.info("Reading in data")
-mdata = mu.read(par["input"])
-input_adata = mdata.mod[par["modality"]]
-adata = input_adata.copy()
-
-# Fetch batch ids for domain-specific batch normalization
-if par["dsbn"] and not par["obs_batch_label"]:
-    raise ValueError("When dsbn is set to True, you are required to provide batch labels (obs_batch_labels).")
-elif par["dsbn"] and par["obs_batch_label"]:
-    logger.info("Fetching batch id's for domain-specific batch normalization")
-    batch_id_cats = adata.obs[par["obs_batch_label"]].astype("category")
-    batch_id_labels = batch_id_cats.cat.codes.values
-    batch_ids = batch_id_labels.tolist()
-    batch_ids = np.array(batch_ids)
-    num_batch_types = len(set(batch_ids))
-elif not par["dsbn"]:
-    # forward pass requires a tensor as input
-    batch_ids = np.zeros(adata.shape[0])
-
-# Vocabulary configuration
-logger.info("Loading model vocabulary")
-special_tokens = [par["pad_token"], "<cls>", "<eoc>"]
-logger.info(f"Loading model vocab from {par['model_vocab']}")
-vocab_file = par["model_vocab"]
-vocab = GeneVocab.from_file(vocab_file)
-[vocab.append_token(s) for s in special_tokens if s not in vocab]
-vocab.set_default_index(vocab[par["pad_token"]])
-ntokens = len(vocab)
-
-# Model configuration
-logger.info("Loading model and configurations")
-model_config_file = par["model_config"]
-with open(model_config_file, "r") as f:
-    model_configs = json.load(f)
-embsize = model_configs["embsize"]
-nhead = model_configs["nheads"]
-d_hid = model_configs["d_hid"]
-nlayers = model_configs["nlayers"]
-
-# Ensure the provided model has the correct architecture
-logger.info("Loading model")
-model_file = par["model"]
-model_dict = torch.load(model_file, map_location=device)
-for k, v in {
-    "--finetuned_checkpoints_key": par["finetuned_checkpoints_key"],
-    "--label_mapper_key": par["label_mapper_key"],
-    }.items():
-    if v not in model_dict.keys():
-        raise KeyError(f"The key '{v}' provided for '{k}' could not be found in the provided --model file. The finetuned model file for cell type annotation requires valid keys for the checkpoints and the label mapper.")
-pretrained_dict = model_dict[par["finetuned_checkpoints_key"]]
-
-# Label mapper configuration
-logger.info("Loading label mapper")
-label_mapper = model_dict[par["label_mapper_key"]]
-cell_type_mapper = {int(k): v for k, v in label_mapper.items()}
-n_cls = len(cell_type_mapper)
-
-# Model instatiation
-logger.info("Instantiating model")
-model = TransformerModel(
-    ntokens,
-    d_model=embsize, # self.encoder (GenEncoder), self.value_encoder (ContinuousValueEncoder), self.transformerencoder(TransformerEncoderLayer)
-    nhead=nhead, # self.transformer_encoder(TransformerEncoderLayer)
-    d_hid=d_hid, # self.transformer_encoder(TransformerEncoderLayer)
-    nlayers=nlayers, # self.transformer_encoder(TransformerEncoderLayer), self.cls_decoder
-    nlayers_cls=3, # self.cls_decoder
-    n_cls=n_cls, # self.cls_decoder
-    vocab=vocab,
-    dropout=0.2, # self.transformer_encoder
-    pad_token=par["pad_token"],
-    pad_value=par["pad_value"],
-    do_mvc=False,
-    do_dab=False,
-    use_batch_labels=par["dsbn"],
-    num_batch_labels=num_batch_types if par["dsbn"] else None,
-    domain_spec_batchnorm=par["dsbn"],
-    input_emb_style="continuous",
-    n_input_bins=par["n_input_bins"],
-    cell_emb_style="cls", # required for cell-type annotation
-    use_fast_transformer=False, #TODO: parametrize when GPU is available
-    fast_transformer_backend="flash", #TODO: parametrize when GPU is available
-    pre_norm=False, #TODO: parametrize when GPU is available
-)
-
-
-# Load model params
-logger.info(f"Loading model params from {model_file}")
-try:
-    model.load_state_dict(pretrained_dict)
-except RuntimeError:
-    logger.info("only load params that are in the model and match the size")
-    model_dict = model.state_dict()
-    pretrained_dict = {
-        k: v
-        for k, v in pretrained_dict.items()
-        if k in model_dict and v.shape == model_dict[k].shape
+
+def main():
+    # Setting seed
+    if par["seed"]:
+        set_seed(par["seed"])
+
+    # Setting device
+    logger.info(f"Setting device to {'cuda' if torch.cuda.is_available() else 'cpu'}")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Read in data
+    logger.info("Reading in data")
+    mdata = mu.read(par["input"])
+    input_adata = mdata.mod[par["modality"]]
+    adata = input_adata.copy()
+
+    # Fetch batch ids for domain-specific batch normalization
+    if par["dsbn"] and not par["obs_batch_label"]:
+        raise ValueError("When dsbn is set to True, you are required to provide batch labels (--obs_batch_label).")
+    elif par["dsbn"] and par["obs_batch_label"]:
+        logger.info("Fetching batch id's for domain-specific batch normalization")
+        batch_id_cats = adata.obs[par["obs_batch_label"]].astype("category")
+        batch_id_labels = batch_id_cats.cat.codes.values
+        batch_ids = batch_id_labels.tolist()
+        batch_ids = np.array(batch_ids)
+        num_batch_types = len(set(batch_ids))
+    elif not par["dsbn"]:
+        # forward pass requires a tensor as input
+        batch_ids = np.zeros(adata.shape[0])
+
+    # Vocabulary configuration
+    logger.info("Loading model vocabulary")
+    special_tokens = [par["pad_token"], "<cls>", "<eoc>"]
+    logger.info(f"Loading model vocab from {par['model_vocab']}")
+    vocab_file = par["model_vocab"]
+    vocab = GeneVocab.from_file(vocab_file)
+    [vocab.append_token(s) for s in special_tokens if s not in vocab]
+    vocab.set_default_index(vocab[par["pad_token"]])
+    ntokens = len(vocab)
+
+    # Model configuration
+    logger.info("Loading model and configurations")
+    model_config_file = par["model_config"]
+    with open(model_config_file, "r") as f:
+        model_configs = json.load(f)
+    embsize = model_configs["embsize"]
+    nhead = model_configs["nheads"]
+    d_hid = model_configs["d_hid"]
+    nlayers = model_configs["nlayers"]
+
+    # Ensure the provided model has the correct architecture
+    logger.info("Loading model")
+    model_file = par["model"]
+    model_dict = torch.load(model_file, map_location=device)
+    for k, v in {
+        "--finetuned_checkpoints_key": par["finetuned_checkpoints_key"],
+        "--label_mapper_key": par["label_mapper_key"],
+        }.items():
+        if v not in model_dict.keys():
+            raise KeyError(f"The key '{v}' provided for '{k}' could not be found in the provided --model file. The finetuned model file for cell type annotation requires valid keys for the checkpoints and the label mapper.")
+    pretrained_dict = model_dict[par["finetuned_checkpoints_key"]]
+
+    # Label mapper configuration
+    logger.info("Loading label mapper")
+    label_mapper = model_dict[par["label_mapper_key"]]
+    cell_type_mapper = {int(k): v for k, v in label_mapper.items()}
+    n_cls = len(cell_type_mapper)
+
+    # Model instantiation
+    logger.info("Instantiating model")
+    model = TransformerModel(
+        ntokens,
+        d_model=embsize, # self.encoder (GenEncoder), self.value_encoder (ContinuousValueEncoder), self.transformer_encoder(TransformerEncoderLayer)
+        nhead=nhead, # self.transformer_encoder(TransformerEncoderLayer)
+        d_hid=d_hid, # self.transformer_encoder(TransformerEncoderLayer)
+        nlayers=nlayers, # self.transformer_encoder(TransformerEncoderLayer), self.cls_decoder
+        nlayers_cls=3, # self.cls_decoder
+        n_cls=n_cls, # self.cls_decoder
+        vocab=vocab,
+        dropout=0.2, # self.transformer_encoder
+        pad_token=par["pad_token"],
+        pad_value=par["pad_value"],
+        do_mvc=False,
+        do_dab=False,
+        use_batch_labels=par["dsbn"],
+        num_batch_labels=num_batch_types if par["dsbn"] else None,
+        domain_spec_batchnorm=par["dsbn"],
+        input_emb_style="continuous",
+        n_input_bins=par["n_input_bins"],
+        cell_emb_style="cls", # required for cell-type annotation
+        use_fast_transformer=False, #TODO: parametrize when GPU is available
+        fast_transformer_backend="flash", #TODO: parametrize when GPU is available
+        pre_norm=False, #TODO: parametrize when GPU is available
+    )
+
+
+    # Load model params
+    logger.info(f"Loading model params from {model_file}")
+    try:
+        model.load_state_dict(pretrained_dict)
+    except RuntimeError:
+        logger.info("only load params that are in the model and match the size")
+        model_dict = model.state_dict()
+        pretrained_dict = {
+            k: v
+            for k, v in pretrained_dict.items()
+            if k in model_dict and v.shape == model_dict[k].shape
+        }
+        for k, v in pretrained_dict.items():
+            logger.info(f"Loading params {k} with shape {v.shape}")
+        model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)
+
+    model.to(device)
+
+    # Load tokenized gene data
+    logger.info("Loading data for inference")
+    for k, v in {
+        "--obsm_gene_tokens": par["obsm_gene_tokens"],
+        "--obsm_tokenized_values": par["obsm_tokenized_values"],
+        }.items():
+        if v not in adata.obsm.keys():
+            raise KeyError(f"The parameter '{v}' provided for '{k}' could not be found in adata.obsm")
+
+    input_gene_ids = adata.obsm[par["obsm_gene_tokens"]]
+    input_values = adata.obsm[par["obsm_tokenized_values"]]
+
+    data_pt = {
+        "gene_ids": input_gene_ids,
+        "values": input_values,
+        "batch_labels": torch.from_numpy(batch_ids).long(),
+    }
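For reference, the fine-tuned checkpoint that --model points to is a torch-saved dictionary that bundles the weights and the label mapper under the keys given by --finetuned_checkpoints_key and --label_mapper_key. A minimal sketch of assembling such a file, assuming `model` is an already fine-tuned scGPT TransformerModel (the cell type labels below are illustrative, not from this PR):

import torch

# Hypothetical illustration: key names match the component defaults,
# label values are made up; "model" is assumed to be a fine-tuned model.
checkpoint = {
    "model_state_dict": model.state_dict(),     # --finetuned_checkpoints_key
    "id_to_class": {0: "B cell", 1: "T cell"},  # --label_mapper_key
}
torch.save(checkpoint, "best_model.pt")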
-model.to(device)
-
-# Load tokenized gene data
-logger.info("Loading data for inference")
-for k, v in {
-    "--obsm_gene_tokens": par["obsm_gene_tokens"],
-    "--obsm_tokenized_values": par["obsm_tokenized_values"],
-    }.items():
-    if v not in adata.obsm.keys():
-        raise KeyError(f"The parameter '{v}' provided for '{k}' could not be found in adata.obsm")
-
-input_gene_ids = adata.obsm[par["obsm_gene_tokens"]]
-input_values = adata.obsm[par["obsm_tokenized_values"]]
-
-data_pt = {
-    "gene_ids": input_gene_ids,
-    "values": input_values,
-    "batch_labels": torch.from_numpy(batch_ids).long(),
-}
-data_loader = DataLoader(
-    dataset=SeqDataset(data_pt),
-    batch_size=par["batch_size"],
-    num_workers=min(os.cpu_count(), par["batch_size"] // 2),
-    pin_memory=True,
-)
-
-# Inference
-logger.info("Predicting cell type classes")
-model.eval()
-predictions = []
-probabilities = []
-confidences = []
-with torch.no_grad():
-    for batch_data in data_loader:
-        input_gene_ids = batch_data["gene_ids"].to(device)
-        input_values = batch_data["values"].to(device)
-        batch_labels = batch_data["batch_labels"].to(device)
-
-        src_key_padding_mask = input_gene_ids.eq(vocab[par["pad_token"]])
-        with torch.cuda.amp.autocast(enabled=False):
-            output_dict = model(
-                input_gene_ids,
-                input_values,
-                src_key_padding_mask=src_key_padding_mask,
-                batch_labels=batch_labels if par["dsbn"] else None,
-                CLS=True, # Return celltype classification objective output
-                CCE=False,
-                MVC=False,
-                ECS=False,
-            )
-            output_values = output_dict["cls_output"]
-
-        preds = output_values.argmax(1).cpu().numpy()
-        predictions.append(preds)
-
-        probs = functional.softmax(output_values, dim=1).max(1)[0]
-        probabilities.append(probs.cpu().numpy())
-
-predictions = np.concatenate(predictions, axis=0)
-probabilities = np.concatenate(probabilities, axis=0)
-
-# Assign cell type labels to predicted classes
-logger.info("Assigning cell type predictions and probabilities")
-adata.obs["scgpt_class_pred"] = predictions
-adata.obs[par["output_obs_predictions"]] = adata.obs["scgpt_class_pred"].map(lambda x: cell_type_mapper[x])
-adata.obs[par["output_obs_probability"]] = probabilities
-
-# Write output
-logger.info("Writing output data")
-mdata.mod[par["modality"]] = adata
-mdata.write(par["output"], compression=par["output_compression"])
\ No newline at end of file
+    data_loader = DataLoader(
+        dataset=SeqDataset(data_pt),
+        batch_size=par["batch_size"],
+        num_workers=min(os.cpu_count(), par["batch_size"] // 2),
+        pin_memory=True,
+    )
+
+    # Inference
+    logger.info("Predicting cell type classes")
+    model.eval()
+    predictions = []
+    probabilities = []
+    with torch.no_grad():
+        for batch_data in tqdm(data_loader):
+            input_gene_ids = batch_data["gene_ids"].to(device)
+            input_values = batch_data["values"].to(device)
+            batch_labels = batch_data["batch_labels"].to(device)
+
+            src_key_padding_mask = input_gene_ids.eq(vocab[par["pad_token"]])
+            with torch.cuda.amp.autocast(enabled=False):
+                output_dict = model(
+                    input_gene_ids,
+                    input_values,
+                    src_key_padding_mask=src_key_padding_mask,
+                    batch_labels=batch_labels if par["dsbn"] else None,
+                    CLS=True, # Return celltype classification objective output
+                    CCE=False,
+                    MVC=False,
+                    ECS=False,
+                )
+                output_values = output_dict["cls_output"]
+
+            preds = output_values.argmax(1).cpu().numpy()
+            predictions.append(preds)
+
+            probs = functional.softmax(output_values, dim=1).max(1)[0]
+            probabilities.append(probs.cpu().numpy())
+
+    predictions = np.concatenate(predictions, axis=0)
+    probabilities = np.concatenate(probabilities, axis=0)
+
+    # Assign cell type labels to predicted classes
+    logger.info("Assigning cell type predictions and probabilities")
+    adata.obs["scgpt_class_pred"] = predictions
+    adata.obs[par["output_obs_predictions"]] = adata.obs["scgpt_class_pred"].map(lambda x: cell_type_mapper[x])
+    adata.obs[par["output_obs_probability"]] = probabilities
+
+    # Write output
+    logger.info("Writing output data")
+    mdata.mod[par["modality"]] = adata
+    mdata.write(par["output"], compression=par["output_compression"])
+
+
+if __name__ == '__main__':
+    freeze_support()
+    warnings.filterwarnings("ignore")
+    main()
diff --git a/src/scgpt/cross_check_genes/config.vsh.yaml b/src/scgpt/cross_check_genes/config.vsh.yaml
index c1bc34adf35..7ce8bd55de6 100644
--- a/src/scgpt/cross_check_genes/config.vsh.yaml
+++ b/src/scgpt/cross_check_genes/config.vsh.yaml
@@ -42,6 +42,10 @@ argument_groups:
         required: false
         description: |
           The name of the adata.var column containing gene names. By default the .var index will be used.
+      - name: "--var_input"
+        type: string
+        required: false
+        description: ".var column containing highly variable genes. If provided, will only cross-check HVG-filtered genes with the model vocabulary."
   - name: Outputs
     arguments:
       - name: "--output"
@@ -56,6 +60,11 @@ argument_groups:
         choices: ["gzip", "lzf"]
         required: false
         example: "gzip"
+      - name: "--output_var_filter"
+        type: string
+        default: "id_in_vocab"
+        description: In which .var slot to store a boolean array specifying which genes are kept based on HVG selection and the model vocabulary.
+
   - name: Arguments
     arguments:
       - name: "--pad_token"
@@ -71,7 +80,7 @@ resources:
 test_resources:
   - type: python_script
     path: test.py
-  - path: /resources_test/scgpt/test_resources/Kim2020_Lung_subset.h5mu
+  - path: /resources_test/scgpt/test_resources/Kim2020_Lung_subset_preprocessed.h5mu
   - path: /resources_test/scgpt/source/vocab.json
 
 engines:
diff --git a/src/scgpt/cross_check_genes/script.py b/src/scgpt/cross_check_genes/script.py
index 4214a2fa4e6..360e984547b 100644
--- a/src/scgpt/cross_check_genes/script.py
+++ b/src/scgpt/cross_check_genes/script.py
@@ -1,16 +1,22 @@
 import sys
 import mudata as mu
-import numpy as np
 from scgpt.tokenizer.gene_tokenizer import GeneVocab
 
 ## VIASH START
 par = {
-    "input": "resources_test/scgpt/test_resources/Kim2020_Lung_subset.h5mu",
+    "input": "resources_test/scgpt/test_resources/Kim2020_Lung_subset_preprocessed.h5mu",
     "output": "output.h5mu",
     "modality": "rna",
     "input_var_gene_names": None,
+    "output_var_filter": "id_in_vocab",
     "pad_token": "<pad>",
-    "vocab_file": "resources_test/scgpt/source/vocab.json"
+    "var_input": "filter_with_hvg",
+    "vocab_file": "resources_test/scgpt/source/vocab.json",
+    "output_compression": None
+}
+
+meta = {
+    "resources_dir": "src/utils"
 }
 ## VIASH END
 
@@ -31,7 +37,7 @@
     genes = adata.var.index.astype(str).tolist()
 elif par["input_var_gene_names"] not in adata.var.columns:
     raise ValueError(f"Gene name column '{par['input_var_gene_names']}' not found in .mod['{par['modality']}'].obs.")
-else: 
+else:
     genes = adata.var[par["input_var_gene_names"]].astype(str).tolist()
 
 # Cross-check genes with pre-trained model
@@ -40,17 +46,18 @@
 vocab = GeneVocab.from_file(vocab_file)
 [vocab.append_token(s) for s in special_tokens if s not in vocab]
 
-# vocab.append_token([s for s in special_tokens if s not in vocab])
-
-logger.info("Filtering genes based on model vocab")
-adata.var["id_in_vocab"] = [1 if gene in vocab else -1 for gene in genes]
-
-gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
-
-logger.info("Subsetting input data based on genes present in model vocab")
-adata = adata[:, adata.var["id_in_vocab"] >= 0]
-
-mudata.mod[par["modality"]] = adata
+if par["var_input"]:
+    logger.info("Filtering genes based on model vocab and HVG")
+    filter_with_hvg = adata.var[par["var_input"]].tolist()
+    gene_filter_mask = [1 if gene in vocab and hvg else 0 for gene, hvg in zip(genes, filter_with_hvg)]
+    logger.info(f"Total number of genes after HVG present in model vocab: {str(sum(gene_filter_mask))}")
+else:
+    logger.info("Filtering genes based on model vocab")
+    gene_filter_mask = [1 if gene in vocab else 0 for gene in genes]
+    logger.info(f"Total number of genes present in model vocab: {str(sum(gene_filter_mask))}")
 
 logger.info(f"Writing to {par['output']}")
+adata.var[par["output_var_filter"]] = gene_filter_mask
+adata.var[par["output_var_filter"]] = adata.var[par["output_var_filter"]].astype("bool")
+mudata.mod[par["modality"]] = adata
 mudata.write_h5mu(par["output"], compression=par["output_compression"])
diff --git a/src/scgpt/cross_check_genes/test.py b/src/scgpt/cross_check_genes/test.py
index 9af6ad36624..61e1e07cd67 100644
--- a/src/scgpt/cross_check_genes/test.py
+++ b/src/scgpt/cross_check_genes/test.py
@@ -12,9 +12,10 @@
 }
 ## VIASH END
 
-input_path = meta["resources_dir"] + "/Kim2020_Lung_subset.h5mu"
+input_path = meta["resources_dir"] + "/Kim2020_Lung_subset_preprocessed.h5mu"
 vocab_path = meta["resources_dir"] + "/vocab.json"
 
+
 def test_cross_check(run_component, random_path):
     output_path = random_path(extension="h5mu")
     args = [
@@ -25,17 +26,31 @@ def test_cross_check(run_component, random_path):
         "--output_compression", "gzip"
     ]
     run_component(args)
-    
+
     output_mudata = read_h5mu(output_path)
-    input_mudata = read_h5mu(input_path)
-    
+
     # Check added columns
     assert {"gene_name", "id_in_vocab"}.issubset(set(output_mudata.mod["rna"].var.columns)), "Gene columns were not added."
     # Check if genes were filtered
-    assert all(output_mudata.mod["rna"].var["id_in_vocab"] == 1), "Genes were not filtered."
-    # Check if number of observations is the same
-    assert output_mudata.mod["rna"].n_obs == input_mudata.mod["rna"].n_obs, "Number of observations changed."
-    assert output_mudata.n_obs == input_mudata.n_obs, "Number of observations changed."
+    assert sum(output_mudata.mod["rna"].var["id_in_vocab"]) != len(output_mudata.mod["rna"].var["id_in_vocab"]), "Genes were not filtered."
+
+    output_hvg_path = random_path(extension="h5mu")
+    args_hvg = [
+        "--input", input_path,
+        "--output", output_hvg_path,
+        "--modality", "rna",
+        "--var_input", "filter_with_hvg",
+        "--vocab_file", vocab_path,
+        "--output_compression", "gzip"
+    ]
+
+    run_component(args_hvg)
+
+    output_mudata_hvg = read_h5mu(output_hvg_path)
+    # Check if genes were filtered based on HVG
+    assert sum(output_mudata_hvg.mod["rna"].var["id_in_vocab"]) != len(output_mudata_hvg.mod["rna"].var["id_in_vocab"]), "Genes were not filtered."
+    assert sum(output_mudata.mod["rna"].var["id_in_vocab"]) != sum(output_mudata_hvg.mod["rna"].var["id_in_vocab"]), "Genes were not filtered based on HVG."
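The boolean mask written by cross_check_genes is meant to be consumed with plain AnnData indexing rather than a pre-subset dataset, e.g. (a sketch; the file and column names follow the defaults used in these tests):

import mudata as mu

mdata = mu.read_h5mu("output.h5mu")
adata = mdata.mod["rna"]

# True marks genes present in the model vocabulary
# (and highly variable, when --var_input was provided).
n_kept = int(adata.var["id_in_vocab"].sum())
adata_filtered = adata[:, adata.var["id_in_vocab"]].copy()
print(f"{n_kept} of {adata.n_vars} genes kept")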
+
+
 def test_cross_check_invalid_gene_layer_raises(run_component, random_path):
     output_path = random_path(extension="h5mu")
@@ -50,6 +65,7 @@ def test_cross_check_invalid_gene_layer_raises(run_component, random_path):
             run_component(args)
     assert re.search(r"ValueError: Gene name column 'dummy_var' not found in .mod\['rna'\]\.obs\.",
                      err.value.stdout.decode('utf-8'))
-    
+
+
 if __name__ == '__main__':
     sys.exit(pytest.main([__file__]))
\ No newline at end of file
diff --git a/src/scgpt/pad_tokenize/config.vsh.yaml b/src/scgpt/pad_tokenize/config.vsh.yaml
index 6f662ba9c05..ab40b4e9e96 100644
--- a/src/scgpt/pad_tokenize/config.vsh.yaml
+++ b/src/scgpt/pad_tokenize/config.vsh.yaml
@@ -31,17 +31,22 @@ argument_groups:
         example: vocab.json
         description: |
           Path to model vocabulary file.
-      - name: "--input_layer"
+      - name: "--var_gene_names"
         type: string
-        default: "binned"
         required: false
         description: |
-          The name of the layer to be padded and tokenized.
-      - name: "--var_gene_names"
+          The name of the .var column containing gene names. When no gene names column is provided, the .var index will be used.
+      - name: "--var_input"
         type: string
+        default: "id_in_vocab"
+        description: |
+          The name of the adata.var column containing the boolean mask for vocabulary-cross-checked and/or highly variable genes.
+      - name: "--input_obsm_binned_counts"
+        type: string
+        default: "binned_counts"
         required: false
         description: |
-          The name of the .var column containing gene names. When no gene_name_layer is provided, the .var index will be used.
+          The name of the .obsm field containing the binned counts to be padded and tokenized.
 
   - name: Outputs
     arguments:
@@ -98,6 +103,7 @@ resources:
   - type: python_script
     path: script.py
   - path: /src/utils/setup_logger.py
+  - path: /src/utils/subset_vars.py
 test_resources:
   - type: python_script
     path: test.py
@@ -113,9 +119,7 @@ engines:
       packages:
         - scgpt==0.2.1
        - ipython~=8.5.0
-  test_setup:
-    - type: python
-      __merge__: [ /src/base/requirements/viashpy.yaml ]
+  __merge__: [ /src/base/requirements/python_test_setup.yaml, .]
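The three .obsm matrices written by pad_tokenize are aligned element-wise, which can be sanity-checked directly on the output. A sketch, assuming the default obsm keys and pad value from this config:

import mudata as mu

mdata = mu.read_h5mu("Kim2020_Lung_subset_tokenized.h5mu")
adata = mdata.mod["rna"]

gene_ids = adata.obsm["gene_id_tokens"]    # vocabulary indices per cell
values = adata.obsm["values_tokenized"]    # binned counts, --pad_value at padded positions
padding_mask = adata.obsm["padding_mask"]  # True where a position is padding

assert gene_ids.shape == values.shape == padding_mask.shape
assert (values[padding_mask] == -2).all()  # -2 is the default --pad_value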
 
 runners:
   - type: executable
   - type: nextflow
diff --git a/src/scgpt/pad_tokenize/script.py b/src/scgpt/pad_tokenize/script.py
index c9422edc40e..48dc924a3be 100644
--- a/src/scgpt/pad_tokenize/script.py
+++ b/src/scgpt/pad_tokenize/script.py
@@ -8,24 +8,33 @@
 
 ## VIASH START
 par = {
-    "input": "resources_test/scgpt/test_resources/Kim2020_Lung_preprocessed.h5mu",
+    "input": "resources_test/scgpt/test_resources/Kim2020_Lung_subset_binned.h5mu",
     "model_vocab": "resources_test/scgpt/source/vocab.json",
-    "output": "resources_test/scgpt/test_resources/Kim2020_Lung_tokenized.h5mu",
+    "output": "resources_test/scgpt/test_resources/Kim2020_Lung_subset_tokenized.h5mu",
     "pad_token": "<pad>",
     "pad_value": -2,
     "modality": "rna",
-    "input_layer": "X_binned",
+    "input_obsm_binned_counts": "binned_counts",
     "max_seq_len": None,
     "var_gene_names": None,
     "obsm_gene_tokens": "gene_id_tokens",
     "obsm_tokenized_values": "values_tokenized",
     "obsm_padding_mask": "padding_mask",
-    "output_compression": None
+    "output_compression": None,
+    "var_input": "id_in_vocab"
 }
+
+meta = {
+    "resources_dir": "src/utils/"
+}
+
+# mdata = mu.read(par["input"])
+# mdata.mod["rna"].obsm["binned_counts"] = mdata.mod["rna"].layers["binned"]
+# mdata.write_h5mu(par["input"])
 ## VIASH END
 
 sys.path.append(meta["resources_dir"])
 from setup_logger import setup_logger
+from subset_vars import subset_vars
 logger = setup_logger()
 
 logger.info("Reading in data")
@@ -35,6 +44,8 @@
 input_adata = mdata.mod[par["modality"]]
 adata = input_adata.copy()
 
+adata = subset_vars(adata, par["var_input"])
+
 # Set padding specs
 pad_token = par["pad_token"]
 special_tokens = [pad_token, "<cls>", "<eoc>"]
@@ -43,9 +54,9 @@
 logger.info("Fetching counts and gene names")
 # Fetch counts
 all_counts = (
-    adata.layers[par["input_layer"]].A
-    if issparse(adata.layers[par["input_layer"]])
-    else adata.layers[par["input_layer"]]
+    adata.obsm[par["input_obsm_binned_counts"]].toarray()
+    if issparse(adata.obsm[par["input_obsm_binned_counts"]])
+    else adata.obsm[par["input_obsm_binned_counts"]]
 )
 
 # Fetching gene names
@@ -92,9 +103,8 @@
 padding_mask = all_gene_ids.eq(vocab[pad_token])
 
 logger.info("Writing output data")
-adata.obsm[par["obsm_gene_tokens"]] = all_gene_ids.numpy()
-adata.obsm[par["obsm_tokenized_values"]] = all_values.numpy()
-adata.obsm[par["obsm_padding_mask"]] = padding_mask.numpy()
+input_adata.obsm[par["obsm_gene_tokens"]] = all_gene_ids.numpy()
+input_adata.obsm[par["obsm_tokenized_values"]] = all_values.numpy()
+input_adata.obsm[par["obsm_padding_mask"]] = padding_mask.numpy()
 
-mdata.mod[par["modality"]] = adata
 mdata.write(par["output"], compression=par["output_compression"])
diff --git a/src/scgpt/pad_tokenize/test.py b/src/scgpt/pad_tokenize/test.py
index 41a8d1bc5be..c1ab1a1a9e9 100644
--- a/src/scgpt/pad_tokenize/test.py
+++ b/src/scgpt/pad_tokenize/test.py
@@ -3,7 +3,6 @@
 import mudata as mu
 import numpy as np
 from scgpt.tokenizer.gene_tokenizer import GeneVocab
-from scgpt.preprocess import Preprocessor
 
 ## VIASH START
 meta = {
@@ -14,69 +13,26 @@
 }
 ## VIASH END
 
-input = f"{meta['resources_dir']}/scgpt/test_resources/Kim2020_Lung_subset.h5mu"
+input_file = f"{meta['resources_dir']}/scgpt/test_resources/Kim2020_Lung_subset_binned.h5mu"
 vocab_file = f"{meta['resources_dir']}/scgpt/source/vocab.json"
-input_file = mu.read(input)
+vocab = GeneVocab.from_file(vocab_file)
 
-## START TEMPORARY WORKAROUND DATA PREPROCESSING
-#TODO: Remove this workaround once scGPT preproc modules are implemented
-# Read in data
-adata = input_file.mod["rna"]
 
-# Set tokens for integration
-pad_token = "<pad>"
-special_tokens = [pad_token, "<cls>", "<eoc>"]
+@pytest.fixture
+def binned_h5mu(random_h5mu_path):
+    binned_h5mu_path = random_h5mu_path()
+    mdata = mu.read(input_file)
+    adata = mdata.mod["rna"]
+    adata.obsm["binned_counts"] = adata.layers["binned"]
+    mdata.write(binned_h5mu_path)
+    return binned_h5mu_path
 
-# Make batch a category column
-adata.obs["str_batch"] = adata.obs["sample"].astype(str)
-batch_id_labels = adata.obs["str_batch"].astype("category").cat.codes.values
-adata.obs["batch_id"] = batch_id_labels
-adata.var["gene_name"] = adata.var.index.tolist()
 
-# Load model vocab
-vocab = GeneVocab.from_file(vocab_file)
-for s in special_tokens:
-    if s not in vocab:
-        vocab.append_token(s)
-
-# Cross-check genes with pre-trained model
-genes = adata.var["gene_name"].tolist()
-adata.var["id_in_vocab"] = [
-    1 if gene in vocab else -1 for gene in adata.var["gene_name"]
-    ]
-gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
-adata = adata[:, adata.var["id_in_vocab"] >= 0]
-
-# Preprocess data
-preprocessor = Preprocessor(
-    use_key="X",
-    filter_gene_by_counts=3,
-    filter_cell_by_counts=False,
-    normalize_total=10000,
-    result_normed_key="X_normed",
-    log1p=True,
-    result_log1p_key="X_log1p",
-    subset_hvg=1200,
-    hvg_flavor="seurat_v3",
-    binning=51,
-    result_binned_key="binned",
-    )
-
-preprocessor(adata, batch_key="str_batch")
-
-# copy results to mudata
-input_file.mod["rna"] = adata
-
-## END TEMPORARY WORKAROUND DATA PREPROCESSING
-
-
-def test_integration_pad_tokenize(run_component, tmp_path):
+def test_integration_pad_tokenize(run_component, tmp_path, binned_h5mu):
     output = tmp_path / "Kim2020_Lung_tokenized.h5mu"
-    input_preprocessed = f"{meta['resources_dir']}/scgpt/test_resources/Kim2020_Lung_preprocessed.h5mu"
-    input_file.write(input_preprocessed)
 
     run_component([
-        "--input", input_preprocessed,
+        "--input", binned_h5mu,
         "--output", output,
         "--modality", "rna",
         "--obsm_gene_tokens", "gene_id_tokens",
@@ -84,7 +40,7 @@ def test_integration_pad_tokenize(run_component, tmp_path):
         "--obsm_padding_mask", "padding_mask",
         "--pad_token", "<pad>",
         "--pad_value", "-2",
-        "--input_layer", "binned",
+        "--input_obsm_binned_counts", "binned_counts",
         "--model_vocab", vocab_file
     ])
 
@@ -96,7 +52,7 @@ def test_integration_pad_tokenize(run_component, tmp_path):
     padding_mask = output_adata.obsm["padding_mask"]
 
     # check output dimensions
-    ## nr of genes that are tokenized 
+    ## nr of genes that are tokenized
     assert gene_ids.shape[1] <= output_adata.var.shape[0] + 1, "gene_ids shape[1] is higher than adata.var.shape[0] (n_hvg + 1)"
     assert values.shape[1] <= output_adata.var.shape[0] + 1, "values shape[1] is higher than adata.var.shape[0] (n_hvg + 1)"
     assert padding_mask.shape[1] <= output_adata.var.shape[0] + 1, "padding_mask shape[1] is higher than adata.var.shape[0] (n_hvg + 1)"
diff --git a/src/workflows/annotation/scgpt_annotation/config.vsh.yaml b/src/workflows/annotation/scgpt_annotation/config.vsh.yaml
new file mode 100644
index 00000000000..20e84ec79f5
--- /dev/null
+++ b/src/workflows/annotation/scgpt_annotation/config.vsh.yaml
@@ -0,0 +1,195 @@
+name: "scgpt_annotation"
+namespace: "workflows/annotation"
+description: |
+  Cell type annotation workflow using scGPT.
+  The workflow takes a pre-processed h5mu file as query input, and performs
+  - subsetting for HVG
+  - cross-checking of genes with the model vocabulary
+  - binning of gene counts
+  - padding and tokenizing of genes
+  - transformer-based cell type prediction
+  Note that cell-type prediction using scGPT is only possible with a fine-tuned scGPT model.
+authors:
+  - __merge__: /src/authors/dorien_roosen.yaml
+    roles: [ author, maintainer ]
+  - __merge__: /src/authors/elizabeth_mlynarski.yaml
+    roles: [ contributor ]
+  - __merge__: /src/authors/weiwei_schultz.yaml
+    roles: [ contributor ]
+
+argument_groups:
+  - name: "Query input"
+    arguments:
+      - name: "--id"
+        required: true
+        type: string
+        description: ID of the sample.
+        example: foo
+      - name: "--input"
+        type: file
+        required: true
+        description: Path to the input file.
+        example: input.h5mu
+      - name: "--modality"
+        type: string
+        default: "rna"
+        required: false
+      - name: "--input_layer"
+        type: string
+        required: false
+        description: |
+          Mudata layer (key from layers) to use as input data for HVG subsetting and binning; if not specified, X is used.
+      - name: "--input_var_gene_names"
+        type: string
+        required: false
+        description: |
+          The .var field in the input (query) dataset containing gene names; if not provided, the .var index will be used.
+      - name: "--input_obs_batch_label"
+        type: string
+        required: true
+        description: |
+          The .obs field in the input (query) dataset containing the batch labels.
+
+  - name: "Model input"
+    arguments:
+      - name: "--model"
+        type: file
+        required: true
+        example: best_model.pt
+        description: |
+          The scGPT model file.
+          Must be a fine-tuned model that contains keys for the checkpoints (--finetuned_checkpoints_key) and the cell type label mapper (--label_mapper_key).
+      - name: "--model_config"
+        type: file
+        required: true
+        example: args.json
+        description: |
+          The scGPT model configuration file.
+      - name: "--model_vocab"
+        type: file
+        required: true
+        example: vocab.json
+        description: |
+          The scGPT model vocabulary file.
+      - name: "--finetuned_checkpoints_key"
+        type: string
+        default: model_state_dict
+        description: |
+          Key in the model file containing the pre-trained checkpoints.
+      - name: "--label_mapper_key"
+        type: string
+        default: id_to_class
+        description: |
+          Key in the model file containing the cell type class to label mapper dictionary.
+
+  - name: "Outputs"
+    arguments:
+      - name: "--output"
+        type: file
+        required: true
+        direction: output
+        description: Output file path.
+        example: output.h5mu
+      - name: "--output_compression"
+        type: string
+        example: "gzip"
+        required: false
+        choices: ["gzip", "lzf"]
+        description: |
+          The compression algorithm to use for the output h5mu file.
+      - name: "--output_obs_predictions"
+        type: string
+        default: "scgpt_pred"
+        required: false
+        description: |
+          The name of the adata.obs column to write predicted cell type labels to.
+      - name: "--output_obs_probability"
+        type: string
+        default: "scgpt_probability"
+        required: false
+        description: |
+          The name of the adata.obs column to write the probabilities of the predicted cell type labels to.
+
+  - name: "Padding arguments"
+    arguments:
+      - name: "--pad_token"
+        type: string
+        default: "<pad>"
+        required: false
+        description: |
+          Token used for padding.
+      - name: "--pad_value"
+        type: integer
+        default: -2
+        required: false
+        description: |
+          The value of the padding token.
+
+  - name: "HVG subset arguments"
+    arguments:
+      - name: "--n_hvg"
+        type: integer
+        default: 1200
+        description: |
+          Number of highly variable genes to subset for.
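The --n_hvg value is handed to the highly_variable_features_scanpy dependency, which flags rather than subsets the top features. Outside the workflow, the closest plain-scanpy equivalent is roughly the sketch below (an approximation, not the component's exact implementation; the seurat_v3 flavor expects raw counts in the data it is given):

import mudata as mu
import scanpy as sc

# Flag (don't subset) the top highly variable genes; the resulting boolean
# column plays the role of the "scgpt_filter_with_hvg" mask used downstream.
mdata = mu.read_h5mu("Kim2020_Lung_subset_preprocessed.h5mu")
adata = mdata.mod["rna"]
sc.pp.highly_variable_genes(adata, n_top_genes=1200, flavor="seurat_v3")
hvg_mask = adata.var["highly_variable"]
print(f"flagged {int(hvg_mask.sum())} of {adata.n_vars} genes")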
+
+  - name: "Tokenization arguments"
+    arguments:
+      - name: "--max_seq_len"
+        type: integer
+        required: false
+        description: |
+          The maximum sequence length of the tokenized data.
+
+  - name: "Embedding arguments"
+    arguments:
+      - name: --dsbn
+        type: boolean
+        default: true
+        description: |
+          Apply domain-specific batch normalization.
+      - name: "--batch_size"
+        type: integer
+        default: 64
+        min: 1
+        description: |
+          The batch size to be used for embedding inference.
+
+  - name: "Binning arguments"
+    arguments:
+      - name: "--n_input_bins"
+        type: integer
+        default: 51
+        required: false
+        min: 1
+        description: |
+          The number of bins to discretize the data into; when no value is provided, the data will not be binned.
+      - name: "--seed"
+        type: integer
+        min: 0
+        required: false
+        description: |
+          Seed for random number generation used for binning. If not set, no seed is used.
+
+resources:
+  - type: nextflow_script
+    path: main.nf
+    entrypoint: run_wf
+
+test_resources:
+  - type: nextflow_script
+    path: test.nf
+    entrypoint: test_wf
+  - path: /resources_test/scgpt
+
+dependencies:
+  - name: scgpt/cross_check_genes
+  - name: scgpt/binning
+  - name: feature_annotation/highly_variable_features_scanpy
+  - name: filter/do_filter
+  - name: scgpt/pad_tokenize
+  - name: scgpt/cell_type_annotation
+    alias: scgpt_celltype_annotation
+
+runners:
+  - type: nextflow
diff --git a/src/workflows/annotation/scgpt_annotation/integration_test.sh b/src/workflows/annotation/scgpt_annotation/integration_test.sh
new file mode 100755
index 00000000000..14575014e9c
--- /dev/null
+++ b/src/workflows/annotation/scgpt_annotation/integration_test.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+# get the root of the directory
+REPO_ROOT=$(git rev-parse --show-toplevel)
+
+# ensure that the command below is run from the root of the repository
+cd "$REPO_ROOT"
+
+nextflow run . \
+  -main-script src/workflows/annotation/scgpt_annotation/test.nf \
+  -resume \
+  -profile docker,no_publish \
+  -entry test_wf \
+  -c src/workflows/utils/labels_ci.config \
+  -c src/workflows/utils/integration_tests.config
diff --git a/src/workflows/annotation/scgpt_annotation/main.nf b/src/workflows/annotation/scgpt_annotation/main.nf
new file mode 100644
index 00000000000..e7b47703171
--- /dev/null
+++ b/src/workflows/annotation/scgpt_annotation/main.nf
@@ -0,0 +1,112 @@
+workflow run_wf {
+
+  take:
+    input_ch
+
+  main:
+    output_ch = input_ch
+      // Set aside the output for this workflow to avoid conflicts
+      | map {id, state ->
+        def new_state = state + ["workflow_output": state.output]
+        [id, new_state]
+      }
+      // Annotate the mudata object with highly variable genes.
+      | highly_variable_features_scanpy.run(
+        fromState: [
+          "input": "input",
+          "layer": "input_layer",
+          "modality": "modality",
+          "n_top_features": "n_hvg",
+        ],
+        args: [
+          "var_name_filter": "scgpt_filter_with_hvg",
+          "flavor": "seurat_v3"
+        ],
+        toState: ["input": "output"]
+      )
+      // Check whether the genes are part of the provided vocabulary.
+      // Genes are flagged with a boolean mask instead of being subsetted.
+      | cross_check_genes.run(
+        fromState: [
+          "input": "input",
+          "modality": "modality",
+          "vocab_file": "model_vocab",
+          "input_var_gene_names": "input_var_gene_names",
+          "output": "output",
+          "pad_token": "pad_token"
+        ],
+        args: [
+          "var_input": "scgpt_filter_with_hvg",
+          "output_var_filter": "scgpt_cross_checked_genes"
+        ],
+        toState: ["input": "output"]
+      )
+      // Bins the data into a fixed number of bins.
+      | binning.run(
+        fromState: [
+          "input": "input",
+          "modality": "modality",
+          "input_layer": "input_layer",
+          "n_input_bins": "n_input_bins",
+          "output": "output",
+          "seed": "seed"
+        ],
+        args: [
+          "output_obsm_binned_counts": "binned_counts",
+          "var_input": "scgpt_cross_checked_genes"
+        ],
+        toState: ["input": "output"]
+      )
+      // Padding and tokenization of gene count values.
+      | pad_tokenize.run(
+        fromState: [
+          "input": "input",
+          "modality": "modality",
+          "model_vocab": "model_vocab",
+          "var_gene_names": "input_var_gene_names",
+          "pad_token": "pad_token",
+          "pad_value": "pad_value",
+          "max_seq_len": "max_seq_len",
+          "output": "output"
+        ],
+        args: [
+          "input_obsm_binned_counts": "binned_counts",
+          "obsm_gene_tokens": "gene_id_tokens",
+          "obsm_tokenized_values": "values_tokenized",
+          "obsm_padding_mask": "padding_mask",
+          "var_input": "scgpt_cross_checked_genes"
+        ],
+        toState: ["input": "output"]
+      )
+      // scGPT decoder-based cell type annotation.
+      | scgpt_celltype_annotation.run(
+        fromState: [
+          "model": "model",
+          "model_vocab": "model_vocab",
+          "model_config": "model_config",
+          "label_mapper_key": "label_mapper_key",
+          "finetuned_checkpoints_key": "finetuned_checkpoints_key",
+          "input": "input",
+          "modality": "modality",
+          "obs_batch_label": "input_obs_batch_label",
+          "pad_token": "pad_token",
+          "pad_value": "pad_value",
+          "n_input_bins": "n_input_bins",
+          "dsbn": "dsbn",
+          "batch_size": "batch_size",
+          "seed": "seed",
+          "output_obs_predictions": "output_obs_predictions",
+          "output_obs_probability": "output_obs_probability",
+          "output": "workflow_output",
+          "output_compression": "output_compression"
+        ],
+        args: [
+          "obsm_gene_tokens": "gene_id_tokens",
+          "obsm_tokenized_values": "values_tokenized"
+        ],
+        toState: {id, output, state -> ["output": output.output]}
+      )
+
+  emit:
+    output_ch
+}
diff --git a/src/workflows/annotation/scgpt_annotation/nextflow.config b/src/workflows/annotation/scgpt_annotation/nextflow.config
new file mode 100644
index 00000000000..059100c489c
--- /dev/null
+++ b/src/workflows/annotation/scgpt_annotation/nextflow.config
@@ -0,0 +1,10 @@
+manifest {
+  nextflowVersion = '!>=20.12.1-edge'
+}
+
+params {
+  rootDir = java.nio.file.Paths.get("$projectDir/../../../../").toAbsolutePath().normalize().toString()
+}
+
+// include common settings
+includeConfig("${params.rootDir}/src/workflows/utils/labels.config")
diff --git a/src/workflows/annotation/scgpt_annotation/test.nf b/src/workflows/annotation/scgpt_annotation/test.nf
new file mode 100644
index 00000000000..66213bef4e8
--- /dev/null
+++ b/src/workflows/annotation/scgpt_annotation/test.nf
@@ -0,0 +1,55 @@
+nextflow.enable.dsl=2
+
+include { scgpt_annotation } from params.rootDir + "/target/nextflow/workflows/annotation/scgpt_annotation/main.nf"
+include { scgpt_annotation_test } from params.rootDir + "/target/nextflow/test_workflows/annotation/scgpt_annotation_test/main.nf"
+
+workflow test_wf {
+  resources_test = file("${params.rootDir}/resources_test/scgpt")
+
+  output_ch = Channel.fromList([
+    [
+      id: "simple_execution_test",
+      input: resources_test.resolve("test_resources/Kim2020_Lung_subset_preprocessed.h5mu"),
+      model: resources_test.resolve("finetuned_model/best_model.pt"),
+      model_config: resources_test.resolve("source/args.json"),
+      model_vocab: resources_test.resolve("source/vocab.json"),
+      input_layer: "log_normalized",
+      input_obs_batch_label: "sample",
+      // change default to reduce resource requirements
+      n_hvg: 400,
+      seed: 1
+    ]
+  ])
+  | map{ state -> [state.id, state] }
+  | scgpt_annotation
+  | view { output ->
+    assert output.size() == 2 : "Outputs should contain two elements: [id, state]"
+
+    // check id
+    def id = output[0]
+    assert id.endsWith("_test")
+
+    // check output
+    def state = output[1]
+    assert state instanceof Map : "State should be a map. Found: ${state}"
+    assert state.containsKey("output") : "Output should contain key 'output'."
+    assert state.output.isFile() : "'output' should be a file."
+    assert state.output.toString().endsWith(".h5mu") : "Output file should end with '.h5mu'. Found: ${state.output}"
+
+    "Output: $output"
+  }
+  | scgpt_annotation_test.run(
+    fromState: [
+      "input": "output"
+    ],
+    args: [
+      "n_hvg": 400
+    ]
+  )
+  | toSortedList()
+  | map { output_list ->
+    assert output_list.size() == 1 : "output channel should contain 1 event"
+    assert output_list.collect{it[0]} == ["simple_execution_test"]
+  }
+
+}
diff --git a/src/workflows/test_workflows/annotation/scgpt/config.vsh.yaml b/src/workflows/test_workflows/annotation/scgpt/config.vsh.yaml
new file mode 100644
index 00000000000..6688114db1f
--- /dev/null
+++ b/src/workflows/test_workflows/annotation/scgpt/config.vsh.yaml
@@ -0,0 +1,40 @@
+name: "scgpt_annotation_test"
+namespace: "test_workflows/annotation"
+description: "This component tests the output of the integration test of the scgpt_annotation workflow."
+authors:
+  - __merge__: /src/authors/dorien_roosen.yaml
+argument_groups:
+  - name: Inputs
+    arguments:
+      - name: "--input"
+        type: file
+        required: true
+        description: Path to the h5mu output.
+        example: foo.final.h5mu
+      - name: "--n_hvg"
+        type: integer
+        required: true
+        description: Number of highly variable genes the input file was subset for.
+        example: 400
+resources:
+  - type: python_script
+    path: script.py
+  - path: /src/utils/setup_logger.py
+  - path: /src/base/openpipelinetestutils
+    dest: openpipelinetestutils
+engines:
+  - type: docker
+    image: python:3.12-slim
+    setup:
+      - type: docker
+        copy: ["openpipelinetestutils /opt/openpipelinetestutils"]
+      - type: apt
+        packages:
+          - procps
+      - type: python
+        packages: /opt/openpipelinetestutils
+      - type: python
+        __merge__: [/src/base/requirements/anndata_mudata.yaml, /src/base/requirements/viashpy.yaml, .]
+runners:
+  - type: executable
+  - type: nextflow
\ No newline at end of file
diff --git a/src/workflows/test_workflows/annotation/scgpt/script.py b/src/workflows/test_workflows/annotation/scgpt/script.py
new file mode 100644
index 00000000000..658fc45f27c
--- /dev/null
+++ b/src/workflows/test_workflows/annotation/scgpt/script.py
@@ -0,0 +1,39 @@
+from mudata import read_h5mu
+import numpy as np
+import shutil
+import os
+import sys
+from pathlib import Path
+import pytest
+
+##VIASH START
+par = {
+    "input": "input.h5mu"
+}
+
+meta = {
+    "resources_dir": "resources_test"
+}
+##VIASH END
+
+
+def test_run():
+    input_mudata = read_h5mu(par["input"])
+    expected_obsm = ["gene_id_tokens", "values_tokenized", "padding_mask", "bin_edges", "binned_counts"]
+    expected_var = ["scgpt_filter_with_hvg", "scgpt_cross_checked_genes"]
+    expected_obs = ["scgpt_pred", "scgpt_probability"]
+
+    assert "rna" in list(input_mudata.mod.keys()), "Input should contain rna modality."
+
+    assert all(key in list(input_mudata.mod["rna"].obsm) for key in expected_obsm), f"Input mod['rna'] obsm keys should include: {expected_obsm}, found: {input_mudata.mod['rna'].obsm.keys()}."
+    assert all(key in list(input_mudata.mod["rna"].var) for key in expected_var), f"Input mod['rna'] var columns should include: {expected_var}, found: {input_mudata.mod['rna'].var.keys()}."
+    assert all(key in list(input_mudata.mod["rna"].obs) for key in expected_obs), f"Input mod['rna'] obs columns should include: {expected_obs}, found: {input_mudata.mod['rna'].obs.keys()}."
+    # hvg subsetting is not exact - add 10% to allowed data shape
+    assert input_mudata.mod["rna"].obsm["binned_counts"].shape[1] <= par["n_hvg"] + 0.1 * par["n_hvg"], f"Number of binned genes should be lower than or equal to --n_hvg {par['n_hvg']} (plus a 10% margin), found: {input_mudata.mod['rna'].obsm['binned_counts'].shape[1]}."
+
+
+if __name__ == "__main__":
+    HERE_DIR = Path(__file__).resolve().parent
+    shutil.copyfile(os.path.join(meta['resources_dir'], "openpipelinetestutils", "conftest.py"),
+                    os.path.join(HERE_DIR, "conftest.py"))
+    sys.exit(pytest.main(["--import-mode=importlib"]))
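Once the workflow has run, the annotated query can be inspected directly; the obs columns below correspond to the --output_obs_predictions and --output_obs_probability defaults from the workflow config:

import mudata as mu

mdata = mu.read_h5mu("output.h5mu")
adata = mdata.mod["rna"]

# Predicted label plus the softmax probability of the winning class per cell.
print(adata.obs[["scgpt_pred", "scgpt_probability"]].head())
print(adata.obs["scgpt_pred"].value_counts())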