Add fine-tuned scGPT #17

Merged
merged 14 commits on Jan 17, 2025
2 changes: 1 addition & 1 deletion scripts/run_benchmark/run_full_local.sh
@@ -26,7 +26,7 @@ input_states: resources/datasets/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}'
HERE

# run the benchmark
1 change: 1 addition & 0 deletions scripts/run_benchmark/run_full_seqeracloud.sh
@@ -18,6 +18,7 @@ input_states: s3://openproblems-data/resources/task_batch_integration/datasets/*
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["scgpt_finetuned"]}'
HERE

tw launch https://github.com/openproblems-bio/task_batch_integration.git \
2 changes: 1 addition & 1 deletion scripts/run_benchmark/run_test_local.sh
@@ -21,7 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}'
HERE

nextflow run . \
65 changes: 65 additions & 0 deletions src/methods/scgpt_finetuned/config.vsh.yaml
@@ -0,0 +1,65 @@
__merge__: ../../api/base_method.yaml

name: scgpt_finetuned
label: scGPT (fine-tuned)
summary: "A foundation model for single-cell biology (fine-tuned)"
description: |
scGPT is a foundation model for single-cell biology based on a generative
pre-trained transformer and trained on a repository of over 33 million cells.

Here, we fine-tune the pre-trained model for the batch integration task.
references:
doi:
- 10.1038/s41592-024-02201-0
links:
documentation: https://scgpt.readthedocs.io/en/latest/
repository: https://github.com/bowang-lab/scGPT

info:
method_types: [embedding]
preferred_normalization: counts
variants:
scgpt_finetuned_default:

arguments:
- name: --model_name
type: string
description: String giving the name of the scGPT model to use
choices: ["scGPT_human", "scGPT_CP"]
default: "scGPT_human"
- name: --model
type: file
description: |
Path to the directory containing the scGPT model specified by model_name
or a .zip/.tar.gz archive to extract. If not given, the model will be
downloaded.
required: false
- name: --n_hvg
type: integer
default: 3000
description: Number of highly variable genes to use.

resources:
- type: python_script
path: script.py
- path: /src/utils/read_anndata_partial.py
- path: scgpt_functions.py

engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1.0.0
# TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
setup:
- type: python
pypi:
- gdown
- scgpt # Install from PyPI to get dependencies
- type: docker
# Force re-installing from GitHub to get bug fixes
run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git

runners:
- type: executable
- type: nextflow
directives:
label: [midtime, midmem, midcpu, gpu]
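
For reference, below is a minimal sketch of how script.py might consume these arguments through the par dictionary that Viash injects into Python components. The Google Drive folder ID, local directory layout, and extraction logic are illustrative assumptions, not values taken from this PR.

import os
import tarfile
import zipfile

import gdown  # listed as a pypi dependency in the config above

## VIASH START
# Placeholder values; Viash replaces this block with the actual arguments at runtime
par = {
    "model_name": "scGPT_human",
    "model": None,
    "n_hvg": 3000,
}
## VIASH END

model_dir = par["model"]
if model_dir is None:
    # Hypothetical download step; the real Drive folder ID is not part of this diff
    model_dir = os.path.join("models", par["model_name"])
    gdown.download_folder(id="<drive-folder-id>", output=model_dir, quiet=False)
elif model_dir.endswith((".zip", ".tar.gz")):
    # Unpack a packaged model into a working directory
    out_dir = os.path.join("models", par["model_name"])
    if model_dir.endswith(".zip"):
        with zipfile.ZipFile(model_dir) as archive:
            archive.extractall(out_dir)
    else:
        with tarfile.open(model_dir) as archive:
            archive.extractall(out_dir)
    model_dir = out_dir

print(f"Using {par['model_name']} from {model_dir} with n_hvg={par['n_hvg']}")
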
288 changes: 288 additions & 0 deletions src/methods/scgpt_finetuned/scgpt_functions.py
@@ -0,0 +1,288 @@
import time
import warnings

import numpy as np
import scgpt
import torch


def prepare_data(
tokenized_train,
tokenized_valid,
train_batch_labels,
valid_batch_labels,
hyperparameters,
model_settings,
epoch,
):
masked_values_train = scgpt.tokenizer.random_mask_value(
tokenized_train["values"],
mask_ratio=hyperparameters["mask_ratio"],
mask_value=model_settings["mask_value"],
pad_value=model_settings["pad_value"],
)
masked_values_valid = scgpt.tokenizer.random_mask_value(
tokenized_valid["values"],
mask_ratio=hyperparameters["mask_ratio"],
mask_value=model_settings["mask_value"],
pad_value=model_settings["pad_value"],
)
scgpt.logger.info(
f"Random masking at epoch {epoch:3d},"
f"ratio of masked values in train: {(masked_values_train == model_settings['mask_value']).sum() / (masked_values_train - model_settings['pad_value']).count_nonzero():.4f}"
)

input_gene_ids_train, input_gene_ids_valid = (
tokenized_train["genes"],
tokenized_valid["genes"],
)
input_values_train, input_values_valid = masked_values_train, masked_values_valid
target_values_train, target_values_valid = (
tokenized_train["values"],
tokenized_valid["values"],
)

tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

if model_settings["per_seq_batch_sample"]:
train_sort_ids = np.argsort(train_batch_labels)
input_gene_ids_train = input_gene_ids_train[train_sort_ids]
input_values_train = input_values_train[train_sort_ids]
target_values_train = target_values_train[train_sort_ids]
tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]

valid_sort_ids = np.argsort(valid_batch_labels)
input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
input_values_valid = input_values_valid[valid_sort_ids]
target_values_valid = target_values_valid[valid_sort_ids]
tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]

train_data_pt = {
"gene_ids": input_gene_ids_train,
"values": input_values_train,
"target_values": target_values_train,
"batch_labels": tensor_batch_labels_train,
}
valid_data_pt = {
"gene_ids": input_gene_ids_valid,
"values": input_values_valid,
"target_values": target_values_valid,
"batch_labels": tensor_batch_labels_valid,
}

return train_data_pt, valid_data_pt


class SeqDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data

def __len__(self):
return self.data["gene_ids"].shape[0]

def __getitem__(self, idx):
return {k: v[idx] for k, v in self.data.items()}


def prepare_dataloader(
data_pt,
batch_size,
shuffle,
intra_domain_shuffle,
drop_last,
num_workers,
per_seq_batch_sample,
):
dataset = SeqDataset(data_pt)

if per_seq_batch_sample:
# Find the indices of samples in each seq batch
subsets = []
batch_labels_array = data_pt["batch_labels"].numpy()
for batch_label in np.unique(batch_labels_array):
batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
subsets.append(batch_indices)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_sampler=scgpt.SubsetsBatchSampler(
subsets,
batch_size,
intra_subset_shuffle=intra_domain_shuffle,
inter_subset_shuffle=shuffle,
drop_last=drop_last,
),
num_workers=num_workers,
pin_memory=True,
)
return data_loader

data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
num_workers=num_workers,
pin_memory=True,
)
return data_loader


def train(
model,
loader,
scaler,
optimizer,
scheduler,
vocab,
criterion,
criterion_dab,
hyperparameters,
model_settings,
device,
epoch,
):
model.train()

total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
total_error = 0.0
log_interval = hyperparameters["log_interval"]
start_time = time.time()

num_batches = len(loader)
for batch, batch_data in enumerate(loader):
input_gene_ids = batch_data["gene_ids"].to(device)
input_values = batch_data["values"].to(device)
target_values = batch_data["target_values"].to(device)
batch_labels = batch_data["batch_labels"].to(device)

src_key_padding_mask = input_gene_ids.eq(vocab[model_settings["pad_token"]])
with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
output_dict = model(
input_gene_ids,
input_values,
src_key_padding_mask=src_key_padding_mask,
batch_labels=batch_labels if model_settings["DSBN"] else None,
MVC=hyperparameters["GEPC"],
ECS=hyperparameters["ecs_thres"] > 0,
)

masked_positions = input_values.eq(
model_settings["mask_value"]
) # the positions to predict
loss = loss_mse = criterion(
output_dict["mlm_output"], target_values, masked_positions
)
if model_settings["explicit_zero_prob"]:
loss_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
output_dict["mlm_zero_probs"], target_values, masked_positions
)
loss = loss + loss_zero_log_prob
if hyperparameters["GEPC"]:
loss_gepc = criterion(
output_dict["mvc_output"], target_values, masked_positions
)
loss = loss + loss_gepc
if hyperparameters["GEPC"] and model_settings["explicit_zero_prob"]:
loss_gepc_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
output_dict["mvc_zero_probs"], target_values, masked_positions
)
loss = loss + loss_gepc_zero_log_prob
if hyperparameters["ecs_thres"] > 0:
loss_ecs = 10 * output_dict["loss_ecs"]
loss = loss + loss_ecs
loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
loss = loss + hyperparameters["dab_weight"] * loss_dab

model.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings("always")
torch.nn.utils.clip_grad_norm_(
model.parameters(),
1.0,
error_if_nonfinite=False if scaler.is_enabled() else True,
)
if len(w) > 0:
scgpt.logger.warning(
f"Found infinite gradient. This may be caused by the gradient "
f"scaler. The current scale is {scaler.get_scale()}. This warning "
"can be ignored if no longer occurs after autoscaling of the scaler."
)
scaler.step(optimizer)
scaler.update()

with torch.no_grad():
mre = scgpt.loss.masked_relative_error(
output_dict["mlm_output"], target_values, masked_positions
)

total_loss += loss.item()
total_mse += loss_mse.item()
total_gepc += loss_gepc.item() if hyperparameters["GEPC"] else 0.0
total_error += mre.item()
if batch % log_interval == 0 and batch > 0:
lr = scheduler.get_last_lr()[0]
ms_per_batch = (time.time() - start_time) * 1000 / log_interval
cur_loss = total_loss / log_interval
cur_mse = total_mse / log_interval
cur_gepc = total_gepc / log_interval if hyperparameters["GEPC"] else 0.0
cur_error = total_error / log_interval
scgpt.logger.info(
f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} | mre {cur_error:5.2f} |"
+ (f"gepc {cur_gepc:5.2f} |" if hyperparameters["GEPC"] else "")
)
total_loss = 0
total_mse = 0
total_gepc = 0
total_error = 0
start_time = time.time()


def evaluate(
model,
loader,
vocab,
criterion,
criterion_dab,
hyperparameters,
model_settings,
device,
):
model.eval()
total_loss = 0.0
total_error = 0.0
total_dab = 0.0
total_num = 0
with torch.no_grad():
for batch_data in loader:
input_gene_ids = batch_data["gene_ids"].to(device)
input_values = batch_data["values"].to(device)
target_values = batch_data["target_values"].to(device)
batch_labels = batch_data["batch_labels"].to(device)

src_key_padding_mask = input_gene_ids.eq(vocab[model_settings["pad_token"]])
with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
output_dict = model(
input_gene_ids,
input_values,
src_key_padding_mask=src_key_padding_mask,
batch_labels=batch_labels if model_settings["DSBN"] else None,
)
output_values = output_dict["mlm_output"]

masked_positions = input_values.eq(model_settings["mask_value"])
loss = criterion(output_values, target_values, masked_positions)
loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

total_loss += loss.item() * len(input_gene_ids)
total_error += scgpt.loss.masked_relative_error(
output_values, target_values, masked_positions
).item() * len(input_gene_ids)
total_dab += loss_dab.item() * len(input_gene_ids)
total_num += len(input_gene_ids)

return total_loss / total_num, total_error / total_num
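
To show how these helpers fit together, here is a hedged sketch of the per-epoch fine-tuning loop that script.py (not included in this excerpt) would be expected to run. The model, vocabulary, tokenized inputs, and the specific hyperparameter values are illustrative assumptions, not values taken from this PR.

import torch
import scgpt

# scgpt_functions.py is shipped next to script.py (see the resources list in config.vsh.yaml)
from scgpt_functions import prepare_data, prepare_dataloader, train, evaluate

# Assumed inputs produced earlier in script.py (not part of this file):
# model, vocab, tokenized_train, tokenized_valid, train_batch_labels, valid_batch_labels

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative settings only; the real values live in script.py
hyperparameters = {
    "mask_ratio": 0.4, "epochs": 15, "lr": 1e-4, "batch_size": 64,
    "GEPC": True, "ecs_thres": 0.8, "dab_weight": 1.0,
    "amp": True, "log_interval": 100,
}
model_settings = {
    "pad_token": "<pad>", "pad_value": -2, "mask_value": -1,
    "per_seq_batch_sample": True, "DSBN": True, "explicit_zero_prob": True,
}

criterion = scgpt.loss.masked_mse_loss
criterion_dab = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=hyperparameters["amp"])

for epoch in range(1, hyperparameters["epochs"] + 1):
    # Re-mask the input values each epoch and build per-batch dataloaders
    train_data_pt, valid_data_pt = prepare_data(
        tokenized_train, tokenized_valid,
        train_batch_labels, valid_batch_labels,
        hyperparameters, model_settings, epoch,
    )
    train_loader = prepare_dataloader(
        train_data_pt, hyperparameters["batch_size"],
        shuffle=False, intra_domain_shuffle=True, drop_last=False,
        num_workers=0, per_seq_batch_sample=model_settings["per_seq_batch_sample"],
    )
    valid_loader = prepare_dataloader(
        valid_data_pt, hyperparameters["batch_size"],
        shuffle=False, intra_domain_shuffle=False, drop_last=False,
        num_workers=0, per_seq_batch_sample=model_settings["per_seq_batch_sample"],
    )

    train(
        model, train_loader, scaler, optimizer, scheduler, vocab,
        criterion, criterion_dab, hyperparameters, model_settings, device, epoch,
    )
    val_loss, val_mre = evaluate(
        model, valid_loader, vocab, criterion, criterion_dab,
        hyperparameters, model_settings, device,
    )
    scgpt.logger.info(f"epoch {epoch}: valid loss {val_loss:.4f}, valid mre {val_mre:.4f}")
    scheduler.step()
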