# Fully Integrate SCDL into Geneformer (#480)
## Summary
In this PR we refactor the Geneformer `SingleCellDataset` class to
integrate the `SingleCellMemMapDataset` (SCDL). The goal is to
streamline the dataset class and improve its readability.
## Details
We make the following changes:
- Input format:
  - `SingleCellDataset` now assumes that the input path to the data is a directory in the `SingleCellMemMap` format.
  - `SingleCellDataModule` now assumes that the train, validation, and test input paths point to directories in the `SingleCellMemMap` format.
- Get item:
  - `_get_item()` now leverages the `get_row` function from SCDL, which eliminates the need to store and parse information in `metadata.json`.
- Error handling for genes not in the tokenizer vocabulary:
  - We add an optional parameter to `SingleCellDataset` and `SingleCellDataModule` called `bypass_tokenizer_vocab`, which defaults to `False`. By default, we therefore throw an error if a gene ID is not in the tokenizer vocabulary. Users who want to bypass this check can set `bypass_tokenizer_vocab` to `True`.
- Error handling for genes with zero expression values:
  - We throw an invalid-input error when a cell has no gene expression values (i.e. `sc_dataset.scdl.get_item()` returns `[]` for the gene data value).
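The two error-handling behaviors described above can be sketched in plain Python. This is a hypothetical stand-in for illustration only, not the BioNeMo implementation; the real dataset operates on SCDL rows and the Geneformer tokenizer:

```python
def tokenize_cell(row, tokenizer_vocab, bypass_tokenizer_vocab=False):
    """Sketch of the per-cell checks (illustrative, not the BioNeMo code).

    `row` is a list of (gene_id, expression) pairs, as an SCDL-style
    `get_row` would return; an empty list means the cell has no
    expression values and is treated as invalid input.
    """
    if not row:
        # Cells with zero expression values are invalid input.
        raise ValueError("Invalid input: cell has no gene expression values")
    token_ids = []
    for gene_id, _expression in row:
        if gene_id not in tokenizer_vocab:
            if not bypass_tokenizer_vocab:
                # Default behavior: unknown gene IDs are an error.
                raise ValueError(f"Gene {gene_id} is not in the tokenizer vocabulary")
            continue  # bypass: skip genes the tokenizer does not recognize
        token_ids.append(tokenizer_vocab[gene_id])
    return token_ids
```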
   
## Usage
The main change from a user perspective is the need to convert
single-cell h5ad files (or directories of h5ad files) to the
`SingleCellMemMap` format.
1) For a single h5ad file, e.g. `data.h5ad`, they can simply run the
following, where `output_path` is the path the `SingleCellMemMap`
directory should be written to:
```python
SingleCellMemMapDataset(output_path, "data.h5ad")
```
2) For a directory of h5ad files, they can simply run the
`convert_h5ad_to_scdl` script (more information is available in the SCDL
README).
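For the directory case, the one-output-directory-per-file layout can be sketched with a small planning helper. This is an illustrative helper only (`plan_conversions` is not part of SCDL); in practice the `convert_h5ad_to_scdl` script, or a `SingleCellMemMapDataset(output_path, h5ad_path)` call per file, performs the actual conversion:

```python
from pathlib import Path


def plan_conversions(h5ad_dir: str, out_root: str) -> dict:
    """Map each .h5ad file in `h5ad_dir` to its own SCDL output directory.

    Hypothetical helper for illustration: it only shows where each
    converted dataset would be written, without doing any conversion.
    """
    return {p: Path(out_root) / p.stem for p in sorted(Path(h5ad_dir).glob("*.h5ad"))}
```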

## Testing
We test that the updated SingleCellDataset produces the same output as
the old dataset on synthetic samples and samples from the cellxsmall
dataset. We also test for Megatron compatibility (as this dataset uses
the MultiEpochDatasampler / Epoch Index) and for correct error handling
of the above cases.
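Multi-epoch sampling of this kind addresses each sample as an (epoch, index) pair, so per-sample randomness (e.g. masking) can differ across epochs while staying reproducible for a fixed base seed. A minimal sketch of the idea, with illustrative names and arbitrary mixing constants rather than the NeMo/Megatron or BioNeMo API:

```python
import random
from typing import List, NamedTuple


class EpochIndex(NamedTuple):
    """(epoch, sample-index) address, as used by multi-epoch samplers (sketch)."""
    epoch: int
    idx: int


def mask_seed(base_seed: int, index: EpochIndex) -> int:
    """Derive a per-(epoch, sample) seed deterministically.

    The mixing constants here are arbitrary; the real code uses its own
    RNG utilities. The point is that the same (base_seed, epoch, idx)
    always yields the same seed, while different epochs yield new masks.
    """
    return (base_seed * 2_654_435_761 + index.epoch * 97 + index.idx) % (2**32)


def masked_positions(n_tokens: int, mask_prob: float, seed: int) -> List[int]:
    """Choose token positions to mask with an RNG seeded per (epoch, sample)."""
    rng = random.Random(seed)
    return [i for i in range(n_tokens) if rng.random() < mask_prob]
```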
Tests for these changes can be run via:
```shell
pytest -v sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_dataset.py
```
Note that we have also updated the following test files to use the
MemMap dataset format and to set `bypass_tokenizer_vocab=True`, because
the cellxsmall dataset has a few genes that are not in the HuggingFace
tokenizer vocabulary and the tests would error otherwise:
`sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py`
`scripts/singlecell/geneformer/test_train.py`

`sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_stop_and_go.py`

---------

Signed-off-by: savitha-eng <[email protected]>
Signed-off-by: polinabinder1 <[email protected]>
Co-authored-by: Savitha Srinivasan <[email protected]>
Co-authored-by: polinabinder1 <[email protected]>
3 people authored Dec 20, 2024
1 parent e9ed8cf commit 30527b1
Showing 22 changed files with 1,073 additions and 729 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -2,7 +2,6 @@
docs/site/
*.nemo
protein/
singlecell/
results/

# Local configs
16 changes: 8 additions & 8 deletions README.md
@@ -279,10 +279,10 @@ type, and then pass in the config type to the training recipe.
Similar to ESM-2, you can download the dataset and checkpoint through our utility function.

```bash
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20241203 --source $MY_DATA_SOURCE); \
GENEFORMER_10M_CKPT=$(download_bionemo_data geneformer/10M_240530:2.0 --source $MY_DATA_SOURCE); \
train_geneformer \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small_processed_scdl \
--result-dir ./results \
--restore-from-checkpoint-path ${GENEFORMER_10M_CKPT} \
--experiment-name test_experiment \
@@ -305,9 +305,9 @@ copy the `sub-projects/bionemo-geneformer/geneformer/scripts/train_geneformer.py`
Simple fine-tuning example (**NOTE**: please change `--restore-from-checkpoint-path` to be the checkpoint directory path that was output last
by the previous train run)
```bash
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20241203 --source $MY_DATA_SOURCE); \
train_geneformer \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
--data-dir ${TEST_DATA_DIR}/cellxgene_2023-12-15_small_processed_scdl \
--result-dir ./results \
--experiment-name test_finettune_experiment \
--num-gpus 1 \
@@ -331,11 +331,11 @@ customizations for your task.


```bash
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20240506 --source $MY_DATA_SOURCE); \
TEST_DATA_DIR=$(download_bionemo_data single_cell/testdata-20241203 --source $MY_DATA_SOURCE); \
bionemo-geneformer-recipe \
--recipe geneformer_10m_pretrain_recipe \
--dest my_config.yaml \
--data-path ${TEST_DATA_DIR}/cellxgene_2023-12-15_small/processed_data \
--recipe 10m-pretrain \
--dest my_config.json \
--data-path ${TEST_DATA_DIR}/cellxgene_2023-12-15_small_processed_scdl \
--result-dir ./results
```
> ⚠️ **IMPORTANT:** Inspect and edit the contents of the outputted my_config.yaml as you see fit


@@ -5,3 +5,11 @@
sha256: 7a4237537bf535dfa00301ce8cc7073e0a23d5bc8aa902ad65db9f51b57a6df9 # pragma: allowlist secret
owner: Polina Binder <[email protected]>
description: Sample test data for SCDL.

- tag: sample_scdl_feature_ids
ngc: nvidia/clara/scdl_sample_test_feature_ids:1.0
ngc_registry: resource
pbss: s3://bionemo-ci/test-data/scdl_sample_test_feat_ids.tar.gz
sha256: 9020ba336dbfe33bddadba26ca0cde49958cbd73c5ad44f0960a5a4837c9db26 # pragma: allowlist secret
owner: Savitha Srinivasan <[email protected]>
description: Sample test data for SCDL with feature IDs appended.
@@ -21,3 +21,11 @@
sha256: ab038b184de52e53ff7bcea5e01d97d55944c507db88c0495bdf9e5e9e0303a4 # pragma: allowlist secret
owner: John St John <[email protected]>
description: Golden values for geneformer QA model.

- tag: testdata-20241203
ngc: nvidia/clara/singlecell-testdata:2.0
ngc_registry: resource
pbss: "s3://bionemo-ci/test-data/singlecell/singlecell-scdltestdata-20241203.tar.gz"
sha256: d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3 # pragma: allowlist secret
owner: Savitha Srinivasan <[email protected]>
description: Test data for single cell models in SCDL Memmap format.
2 changes: 1 addition & 1 deletion sub-packages/bionemo-geneformer/README.md
@@ -16,7 +16,7 @@ pytest -v .


## Acquiring Data
Datasets are expected to be in the form of AnnData (.h5ad) objects such as those downloaded from [Cell x Gene | CZI](https://chanzuckerberg.github.io/cellxgene-census/). They are then pre-processed with either `bionemo-geneformer/src/bionemo/geneformer/data/singlecell/sc_memmap.py` or with sc-DL.
Datasets are expected to be in the form of AnnData (.h5ad) objects such as those downloaded from [Cell x Gene | CZI](https://chanzuckerberg.github.io/cellxgene-census/). They are then pre-processed with `sub-packages/bionemo-scdl/src/bionemo/scdl/scripts/convert_h5ad_to_scdl.py`.

## Geneformer-nv 10M and 106M
Refer to the Dataset cards and Model cards to learn more about the pre-trained checkpoints provided for both 10M and 106M of Geneformer-nv.
1 change: 0 additions & 1 deletion sub-packages/bionemo-geneformer/pyproject.toml
@@ -21,7 +21,6 @@ dependencies = [
[project.scripts]
bionemo-geneformer-train= "bionemo.geneformer.run.main:main"
bionemo-geneformer-recipe= "bionemo.geneformer.run.recipes:main"
sc_memmap = "bionemo.geneformer.scripts.sc_memmap:main_cli"
infer_geneformer = "bionemo.geneformer.scripts.infer_geneformer:geneformer_infer_entrypoint"
train_geneformer = "bionemo.geneformer.scripts.train_geneformer:entrypoint"
geneformer_mlm_loss_eval = "bionemo.geneformer.scripts.geneformer_mlm_loss_eval:entrypoint"
@@ -128,6 +128,7 @@ def main(
seq_len_nv: int = 2048,
seq_len_hf: int = 2048,
seed: int = 513,
include_unrecognized_vocab_in_dataset: bool = False,
):
"""Inference function (requires DDP and only training data that fits in memory)."""
# This is just used to get the tokenizer :(
@@ -185,6 +186,7 @@
max_len=seq_len_nv,
mask_prob=mask_prob,
seed=seed,
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
ds_hf_nvfilt = SingleCellDataset(
dataset_path,
@@ -194,6 +196,7 @@
mask_prob=mask_prob,
eos_token=hf_tokenizer.token_to_id(hf_tokenizer.sep_token), # Stored in the special token
seed=seed,
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
print(f"Loaded dataset of length (NV): {len(ds_nv)}, (HF): {len(ds_hf_nvfilt)}")

@@ -299,6 +302,11 @@ def entrypoint():
)
parser.add_argument("--hf-model-path", type=str, default="ctheodoris/Geneformer", help="HF model path")
parser.add_argument("--dataset-path", type=Path, help="Path to dataset directory", required=True)
parser.add_argument(
"--include-unrecognized-vocab-in-dataset",
action="store_true",
help="If set to true, a hard-check is performed to verify all gene identifiers are in the user supplied tokenizer vocab. Defaults to false which means any gene identifier not in the user supplied tokenizer vocab will be excluded.",
)

args = parser.parse_args()
main(
@@ -307,6 +315,7 @@ def entrypoint():
args.dataset_path,
args.hf_token_dictionary_path,
args.hf_medians_dictionary_path,
args.include_unrecognized_vocab_in_dataset,
)


@@ -51,6 +51,7 @@ class SingleCellDataModule(MegatronDataModule):
num_mask_per_sample (int): Number of masked versions of a single sample to be returned by each worker
train_batch_size (int): Batch size for training
val_batch_size (int): Batch size for validation
include_unrecognized_vocab_in_dataset (bool, optional): If set to True, a hard-check is performed to verify all gene identifiers are in the user supplied tokenizer vocab. Defaults to False which means any gene identifier not in the user supplied tokenizer vocab will be excluded.
Attributes:
cfg (Config): Configuration object
@@ -82,6 +83,7 @@ def __init__( # noqa: D107
num_workers: int = 10, # TODO can this be automatically set?
persistent_workers: bool = True,
pin_memory: bool = True,
include_unrecognized_vocab_in_dataset: bool = False,
) -> None:
super().__init__()
if predict_dataset_path is None:
@@ -122,6 +124,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._val_dataset_ori = SingleCellDataset(
self.data_path_val,
@@ -132,6 +135,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._test_dataset_ori = SingleCellDataset(
self.data_path_test,
@@ -142,6 +146,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._predict_dataset_ori = None
else:
@@ -155,6 +160,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
include_unrecognized_vocab_in_dataset=include_unrecognized_vocab_in_dataset,
)
self._train_dataset_ori = None
self._val_dataset_ori = None
