From 4d105ec52d563df4ee7b8b3c87210507261a87d3 Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 4 Nov 2024 17:58:38 +0000 Subject: [PATCH 1/9] Add singlecell dataset profiling function and some small performance improvements to process --- .../geneformer/data/singlecell/dataset.py | 70 +++++++++++++++---- 1 file changed, 58 insertions(+), 12 deletions(-) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index 2a0cff74c6..38d667947c 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -213,6 +213,25 @@ def __getitem__(self, index: EpochIndex) -> types.BertSample: ) +def _gather_medians( + gene_names: np.ndarray, + gene_data: np.ndarray, + normalize: bool, + vocab: dict[str, int], + gene_median: dict[str, float], +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Filter out genes that are not in the provided tokenizer vocab, and tokenize the gene names.""" + genes, tokens, medians = [], [], [] + for tok, gene in zip(gene_names, gene_data): + if tok in vocab: + tokens.append(vocab[tok]) + genes.append(gene) + if normalize: + med = gene_median[tok] # NOTE: raises KeyError if tok has no median entry; there is no fallback to 1 (no normalization) here + medians.append(med) + return np.asarray(genes), np.asarray(tokens), np.asarray(medians) + + def process_item( # noqa: D417 gene_data: np.ndarray, gene_idxs: np.ndarray, @@ -264,19 +283,9 @@ def process_item( # noqa: D417 max_len = max_len - 1 # - minus 1 for [CLS] token - gene_names = [feature_ids[idx] for idx in gene_idxs] - genes, tokens, medians = [], [], [] - for tok, gene in zip(gene_names, gene_data): - if tok in tokenizer.vocab: - tokens.append(tokenizer.token_to_id(tok)) - genes.append(gene) - if normalize: - med = gene_median.get(tok, 1) # If not in the dictionary we default to no 
normalization (1) - medians.append(med) + gene_names = feature_ids[gene_idxs] - genes = np.asarray(genes) - token_ids = np.asarray(tokens) - medians = np.asarray(medians) + genes, token_ids, medians = _gather_medians(gene_names, gene_data, normalize, tokenizer.vocab, gene_median) if normalize: # re-order according to expression median normalized rank. descending order. @@ -320,3 +329,40 @@ def process_item( # noqa: D417 "loss_mask": loss_mask, "is_random": torch.zeros_like(masked_tokens, dtype=torch.int64), } + + +def _profile_sc_dataset(): + import random + import time + + from tqdm import tqdm + + from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess + from bionemo.testing.data.load import load + + data_path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" / "train" + preprocessor = GeneformerPreprocess( + download_directory=data_path, + medians_file_path=data_path / "medians.json", + tokenizer_vocab_path=data_path / "geneformer.vocab", + ) + match preprocessor.preprocess(): + case {"tokenizer": tokenizer, "median_dict": median_dict}: + logging.info("*************** Preprocessing Finished ************") + case _: + logging.error("Preprocessing failed.") + scd = SingleCellDataset(data_path=data_path, tokenizer=tokenizer, median_dict=median_dict, max_len=2048, seed=321) + n_epochs = 1 + start = time.time() + idxs = list(range(len(scd) * n_epochs)) + random.seed(315) + random.shuffle(idxs) + for i in tqdm(idxs): + _ = scd[EpochIndex(idx=i % n_epochs, epoch=i // len(scd))] + stop = time.time() + print(f"Processed {len(scd)} rows in {stop - start} seconds") + + +if __name__ == "__main__": + # python -m bionemo.geneformer.data.singlecell.dataset will run this profile. 
+ _profile_sc_dataset() From 53f1a1fcada2610beec90896cde3e147998bd831 Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 4 Nov 2024 20:07:40 +0000 Subject: [PATCH 2/9] Fix issue where we just re-sample the first element every time --- .../src/bionemo/geneformer/data/singlecell/dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index 38d667947c..b5e832c774 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -353,14 +353,15 @@ def _profile_sc_dataset(): logging.error("Preprocessing failed.") scd = SingleCellDataset(data_path=data_path, tokenizer=tokenizer, median_dict=median_dict, max_len=2048, seed=321) n_epochs = 1 - start = time.time() - idxs = list(range(len(scd) * n_epochs)) + len_dataset: int = len(scd) + idxs = list(range(len_dataset * n_epochs)) random.seed(315) random.shuffle(idxs) + start = time.time() for i in tqdm(idxs): - _ = scd[EpochIndex(idx=i % n_epochs, epoch=i // len(scd))] + _ = scd[EpochIndex(idx=i % len_dataset, epoch=i // len_dataset)] stop = time.time() - print(f"Processed {len(scd)} rows in {stop - start} seconds") + print(f"Processed {len_dataset * n_epochs} rows in {stop - start} seconds") if __name__ == "__main__": From 39b3841be338ffe6094b918a3f5ca54ceb127ac4 Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 4 Nov 2024 20:33:21 +0000 Subject: [PATCH 3/9] rename genes -> gene_expression --- .../bionemo/geneformer/data/singlecell/dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index 
b5e832c774..b4f633ba13 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -285,15 +285,17 @@ def process_item( # noqa: D417 gene_names = feature_ids[gene_idxs] - genes, token_ids, medians = _gather_medians(gene_names, gene_data, normalize, tokenizer.vocab, gene_median) + gene_expression, token_ids, medians = _gather_medians( + gene_names, gene_data, normalize, tokenizer.vocab, gene_median + ) if normalize: # re-order according to expression median normalized rank. descending order. - genes = genes / genes.sum() * target_sum - genes = genes / medians.astype(float) - idxs = np.argsort(-genes) # sort in descending order so that the 0th position is the highest value. - genes = genes[idxs] + gene_expression = gene_expression / gene_expression.sum() * target_sum + gene_expression = gene_expression / medians.astype(float) + idxs = np.argsort(-gene_expression) # sort in descending order so that the 0th position is the highest value. + gene_expression = gene_expression[idxs] token_ids = token_ids[idxs] # - select max_len subset, set sample to false so it doesnt permute the already rank ordered expression values. 
From 3e12be6a1524b3f0e417d304c4fce76c68771473 Mon Sep 17 00:00:00 2001 From: John St John Date: Mon, 4 Nov 2024 20:35:27 +0000 Subject: [PATCH 4/9] More gene expression,median vector renaming --- .../bionemo/geneformer/data/singlecell/dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index b4f633ba13..21e0fb855d 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -285,17 +285,19 @@ def process_item( # noqa: D417 gene_names = feature_ids[gene_idxs] - gene_expression, token_ids, medians = _gather_medians( + gene_expression_cell, token_ids, gene_expression_medians = _gather_medians( gene_names, gene_data, normalize, tokenizer.vocab, gene_median ) if normalize: # re-order according to expression median normalized rank. descending order. - gene_expression = gene_expression / gene_expression.sum() * target_sum - gene_expression = gene_expression / medians.astype(float) - idxs = np.argsort(-gene_expression) # sort in descending order so that the 0th position is the highest value. - gene_expression = gene_expression[idxs] + gene_expression_cell = gene_expression_cell / gene_expression_cell.sum() * target_sum + gene_expression_cell = gene_expression_cell / gene_expression_medians.astype(float) + idxs = np.argsort( + -gene_expression_cell + ) # sort in descending order so that the 0th position is the highest value. + gene_expression_cell = gene_expression_cell[idxs] token_ids = token_ids[idxs] # - select max_len subset, set sample to false so it doesnt permute the already rank ordered expression values. 
From b2fd9f8490a2d8d1ecc70c83e40aa10eb60d72f5 Mon Sep 17 00:00:00 2001 From: John St John Date: Tue, 5 Nov 2024 00:23:25 +0000 Subject: [PATCH 5/9] Address PR comments --- .../geneformer/data/singlecell/dataset.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index 21e0fb855d..70e70758fa 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -15,6 +15,8 @@ import json +import random +import time from pathlib import Path from typing import Any, Dict, Optional, Sequence, Tuple @@ -22,12 +24,15 @@ import torch from nemo.utils import logging from torch.utils.data import Dataset +from tqdm import tqdm from bionemo.core.data.multi_epoch_dataset import EpochIndex from bionemo.core.utils import random_utils +from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess from bionemo.geneformer.data.singlecell.utils import sample_or_truncate from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer from bionemo.llm.data import masking, types +from bionemo.testing.data.load import load __all__: Sequence[str] = ( @@ -336,14 +341,6 @@ def process_item( # noqa: D417 def _profile_sc_dataset(): - import random - import time - - from tqdm import tqdm - - from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess - from bionemo.testing.data.load import load - data_path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" / "train" preprocessor = GeneformerPreprocess( download_directory=data_path, @@ -361,10 +358,10 @@ def _profile_sc_dataset(): idxs = list(range(len_dataset * n_epochs)) random.seed(315) random.shuffle(idxs) - start = time.time() + start = 
time.monotonic() # Like time.time(), but reads a monotonic clock that cannot go backwards, so the elapsed time is reliable. for i in tqdm(idxs): _ = scd[EpochIndex(idx=i % len_dataset, epoch=i // len_dataset)] - stop = time.time() + stop = time.monotonic() print(f"Processed {len_dataset * n_epochs} rows in {stop - start} seconds") From fdefc178cf28ebcf2ea3de861a726988d7957e4a Mon Sep 17 00:00:00 2001 From: "John St. John" Date: Fri, 8 Nov 2024 10:38:11 -0800 Subject: [PATCH 6/9] Update sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py Co-authored-by: Malcolm Greaves Signed-off-by: John St. John --- .../src/bionemo/geneformer/data/singlecell/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index 70e70758fa..28d8328501 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -32,7 +32,7 @@ from bionemo.geneformer.data.singlecell.utils import sample_or_truncate from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer from bionemo.llm.data import masking, types -from bionemo.testing.data.load import load +from bionemo.core.data.load import load __all__: Sequence[str] = ( From c5c0dc08d4310e8a8c3050612d54a51018768f37 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 8 Nov 2024 18:41:33 +0000 Subject: [PATCH 7/9] fix pre-commit issues --- .../src/bionemo/geneformer/data/singlecell/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index 68993f28b8..2470192cc3 100644 --- 
a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -26,13 +26,13 @@ from torch.utils.data import Dataset from tqdm import tqdm +from bionemo.core.data.load import load from bionemo.core.data.multi_epoch_dataset import EpochIndex from bionemo.core.utils import random_utils from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess from bionemo.geneformer.data.singlecell.utils import sample_or_truncate from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer from bionemo.llm.data import masking, types -from bionemo.core.data.load import load __all__: Sequence[str] = ( From 5793335265173382add49b4a1429607914849d51 Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 8 Nov 2024 18:47:08 +0000 Subject: [PATCH 8/9] Remove testing dependency --- sub-packages/bionemo-geneformer/pyproject.toml | 1 - sub-packages/bionemo-geneformer/scripts/README.md | 3 +++ .../geneformer => }/scripts/geneformer_mlm_loss_eval.py | 0 tach.toml | 1 - 4 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 sub-packages/bionemo-geneformer/scripts/README.md rename sub-packages/bionemo-geneformer/{src/bionemo/geneformer => }/scripts/geneformer_mlm_loss_eval.py (100%) diff --git a/sub-packages/bionemo-geneformer/pyproject.toml b/sub-packages/bionemo-geneformer/pyproject.toml index fb77cb7994..4efd56a76f 100644 --- a/sub-packages/bionemo-geneformer/pyproject.toml +++ b/sub-packages/bionemo-geneformer/pyproject.toml @@ -14,7 +14,6 @@ dependencies = [ # bionemo sub-packages 'bionemo-core', 'bionemo-llm', - 'bionemo-testing', # needed for getting the tokenizer from NGC # external 'cellxgene_census', ] diff --git a/sub-packages/bionemo-geneformer/scripts/README.md b/sub-packages/bionemo-geneformer/scripts/README.md new file mode 100644 index 0000000000..c5bf8c8fec --- /dev/null +++ b/sub-packages/bionemo-geneformer/scripts/README.md @@ 
-0,0 +1,3 @@ +# WARNING +This folder contains one-off eval scripts that may not run and are not actively tested or kept up to date. +Also these scripts may depend on `bionemo-testing` which is generally not allowed. diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py b/sub-packages/bionemo-geneformer/scripts/geneformer_mlm_loss_eval.py similarity index 100% rename from sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py rename to sub-packages/bionemo-geneformer/scripts/geneformer_mlm_loss_eval.py diff --git a/tach.toml b/tach.toml index 252ea01e63..b28c38ae7f 100644 --- a/tach.toml +++ b/tach.toml @@ -54,7 +54,6 @@ path = "bionemo.geneformer" depends_on = [ { path = "bionemo.core" }, { path = "bionemo.llm" }, - { path = "bionemo.testing" }, # needed for the inference script to get the tokenizer reliably. ] [[modules]] From 3ce2dfa68b517a4d865fac4604d01ed45f6379ce Mon Sep 17 00:00:00 2001 From: John St John Date: Fri, 8 Nov 2024 19:26:02 +0000 Subject: [PATCH 9/9] fix pydantic train test --- .../tests/bionemo/geneformer/scripts/test_pydantic_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_pydantic_train.py b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_pydantic_train.py index f76682fbc0..b7a4280017 100644 --- a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_pydantic_train.py +++ b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_pydantic_train.py @@ -145,7 +145,7 @@ def test_finetune_cli(tmpdir): if result.returncode != 0: raise Exception(f"Pretrain recipe failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}") - cmd_str = f"""bionemo-geneformer-train --conf {config} """.strip() + cmd_str = f"bionemo-geneformer-train --conf {config} --model-config-t ExposedFineTuneSeqLenBioBertConfig" env = dict(**os.environ) # 
a local copy of the environment open_port = find_free_network_port() env["MASTER_PORT"] = str(open_port)