Small performance improvement to process_item #395

Merged: 12 commits, Nov 8, 2024
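In brief, the hot path in `process_item` replaces a Python-level loop and list comprehension with NumPy fancy indexing plus a single helper, `_gather_medians`. A minimal sketch of the core pattern, using the names from the diff below but toy data invented for illustration:

```python
import numpy as np

feature_ids = np.asarray(["ENSG_A", "ENSG_B", "ENSG_C", "ENSG_D"])
gene_idxs = np.asarray([2, 0, 3])

gene_names_old = [feature_ids[idx] for idx in gene_idxs]  # before: one Python-level lookup per element
gene_names_new = feature_ids[gene_idxs]                   # after: one vectorized NumPy gather

assert list(gene_names_new) == gene_names_old
```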
1 change: 0 additions & 1 deletion sub-packages/bionemo-geneformer/pyproject.toml
@@ -14,7 +14,6 @@ dependencies = [
     # bionemo sub-packages
     'bionemo-core',
     'bionemo-llm',
-    'bionemo-testing', # needed for getting the tokenizer from NGC
     # external
     'cellxgene_census',
 ]
3 changes: 3 additions & 0 deletions sub-packages/bionemo-geneformer/scripts/README.md
@@ -0,0 +1,3 @@
+# WARNING
+This folder contains one-off eval scripts that may not run and are not actively tested or kept up to date.
+Also, these scripts may depend on `bionemo-testing`, which is generally not allowed.
sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py
@@ -15,16 +15,21 @@
 
 
 import json
+import random
+import time
 from pathlib import Path
 from typing import Any, Dict, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
 from nemo.utils import logging
 from torch.utils.data import Dataset
+from tqdm import tqdm
 
+from bionemo.core.data.load import load
 from bionemo.core.data.multi_epoch_dataset import EpochIndex
 from bionemo.core.utils import random_utils
+from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess
 from bionemo.geneformer.data.singlecell.utils import sample_or_truncate
 from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
 from bionemo.llm.data import masking, types
@@ -216,6 +221,25 @@ def __getitem__(self, index: EpochIndex) -> types.BertSample:
         )
 
 
+def _gather_medians(
+    gene_names: np.ndarray,
+    gene_data: np.ndarray,
+    normalize: bool,
+    vocab: dict[str, int],
+    gene_median: dict[str, float],
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Filter out genes that are not in the provided tokenizer vocab, and tokenize the gene names."""
+    genes, tokens, medians = [], [], []
+    for tok, gene in zip(gene_names, gene_data):
+        if tok in vocab:
+            tokens.append(vocab[tok])
+            genes.append(gene)
+            if normalize:
+                med = gene_median.get(tok, 1)  # If not in the dictionary we default to no normalization (1)
+                medians.append(med)
+    return np.asarray(genes), np.asarray(tokens), np.asarray(medians)
+
+
 def process_item(  # noqa: D417
     gene_data: np.ndarray,
     gene_idxs: np.ndarray,
@@ -271,27 +295,21 @@ def process_item(  # noqa: D417
     if eos_token is not None:
         max_len = max_len - 1  # - minus 1 for [EOS] token
 
-    gene_names = [feature_ids[idx] for idx in gene_idxs]
-    genes, tokens, medians = [], [], []
-    for tok, gene in zip(gene_names, gene_data):
-        if tok in tokenizer.vocab:
-            tokens.append(tokenizer.token_to_id(tok))
-            genes.append(gene)
-            if normalize:
-                med = gene_median.get(tok, 1)  # If not in the dictionary we default to no normalization (1)
-                medians.append(med)
+    gene_names = feature_ids[gene_idxs]
 
-    genes = np.asarray(genes)
-    token_ids = np.asarray(tokens)
-    medians = np.asarray(medians)
+    gene_expression_cell, token_ids, gene_expression_medians = _gather_medians(
+        gene_names, gene_data, normalize, tokenizer.vocab, gene_median
+    )
 
     if normalize:
         # re-order according to expression median normalized rank. descending order.
 
-        genes = genes / genes.sum() * target_sum
-        genes = genes / medians.astype(float)
-        idxs = np.argsort(-genes)  # sort in descending order so that the 0th position is the highest value.
-        genes = genes[idxs]
+        gene_expression_cell = gene_expression_cell / gene_expression_cell.sum() * target_sum
+        gene_expression_cell = gene_expression_cell / gene_expression_medians.astype(float)
+        idxs = np.argsort(
+            -gene_expression_cell
+        )  # sort in descending order so that the 0th position is the highest value.
+        gene_expression_cell = gene_expression_cell[idxs]
         token_ids = token_ids[idxs]
 
     # - select max_len subset, set sample to False so it doesn't permute the already rank-ordered expression values.
@@ -327,3 +345,33 @@ def process_item(  # noqa: D417
"loss_mask": loss_mask,
"is_random": torch.zeros_like(masked_tokens, dtype=torch.int64),
}


def _profile_sc_dataset():
data_path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" / "train"
preprocessor = GeneformerPreprocess(
download_directory=data_path,
medians_file_path=data_path / "medians.json",
tokenizer_vocab_path=data_path / "geneformer.vocab",
)
match preprocessor.preprocess():
case {"tokenizer": tokenizer, "median_dict": median_dict}:
logging.info("*************** Preprocessing Finished ************")
case _:
logging.error("Preprocessing failed.")
scd = SingleCellDataset(data_path=data_path, tokenizer=tokenizer, median_dict=median_dict, max_len=2048, seed=321)
n_epochs = 1
len_dataset: int = len(scd)
idxs = list(range(len_dataset * n_epochs))
random.seed(315)
random.shuffle(idxs)
start = time.monotonic() # Like time.time() but uses the CPU clock rather so subsequent calls will progress.
for i in tqdm(idxs):
_ = scd[EpochIndex(idx=i % len_dataset, epoch=i // len_dataset)]
stop = time.monotonic()
print(f"Processed {len_dataset * n_epochs} rows in {stop - start} seconds")


if __name__ == "__main__":
# python -m bionemo.geneformer.data.singlecell.dataset will run this profile.
_profile_sc_dataset()
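For readers skimming the diff, here is a minimal standalone sketch of the normalization-and-ranking path that the `if normalize:` branch of `process_item` implements. The variable names come from the diff above; the counts, medians, and token IDs are toy values invented for illustration:

```python
import numpy as np

# Toy counts for three genes already filtered to the tokenizer vocab (values are made up).
gene_expression_cell = np.asarray([4.0, 1.0, 5.0])
gene_expression_medians = np.asarray([10.0, 1.0, 2.0])
token_ids = np.asarray([101, 102, 103])  # hypothetical token IDs
target_sum = 10_000

# Normalize to a fixed library size, then by each gene's dataset-wide median expression.
gene_expression_cell = gene_expression_cell / gene_expression_cell.sum() * target_sum
gene_expression_cell = gene_expression_cell / gene_expression_medians.astype(float)

# Rank genes by median-normalized expression, descending, and reorder the tokens to match.
idxs = np.argsort(-gene_expression_cell)
token_ids = token_ids[idxs]
print(token_ids)  # [103 102 101] -- highest median-normalized expression first
```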
@@ -145,7 +145,7 @@ def test_finetune_cli(tmpdir):
     if result.returncode != 0:
         raise Exception(f"Pretrain recipe failed:\n{cmd_str=}\n{result.stdout=}\n{result.stderr=}")
 
-    cmd_str = f"""bionemo-geneformer-train --conf {config} """.strip()
+    cmd_str = f"bionemo-geneformer-train --conf {config} --model-config-t ExposedFineTuneSeqLenBioBertConfig"
     env = dict(**os.environ)  # a local copy of the environment
     open_port = find_free_network_port()
     env["MASTER_PORT"] = str(open_port)
1 change: 0 additions & 1 deletion tach.toml
@@ -54,7 +54,6 @@ path = "bionemo.geneformer"
 depends_on = [
     { path = "bionemo.core" },
     { path = "bionemo.llm" },
-    { path = "bionemo.testing" }, # needed for the inference script to get the tokenizer reliably.
 ]
 
 [[modules]]