fix register_fields bug in tuner and add dependencies #55

Open · wants to merge 1 commit into base: main
28 changes: 16 additions & 12 deletions cpa/_tuner.py
@@ -131,6 +131,7 @@ def __init__(
setup_anndata_kwargs: dict[str, Any] = {},
use_wandb: bool = False,
wandb_name: str = "cpa_tune",
wandb_api_key: str | None = None,
plan_kwargs_keys: list[str] = [],
) -> None:
self.model_cls = model_cls
@@ -153,6 +154,7 @@ def __init__(
self.use_wandb = use_wandb
self.wandb_name = wandb_name
self.plan_kwargs_keys = plan_kwargs_keys
self.wandb_api_key = wandb_api_key

@property
def id(self) -> str:
@@ -513,7 +515,7 @@ def get_tuner(self) -> Tuner:
local_dir=self.logging_dir,
log_to_file=True,
verbose=1,
callbacks=[WandbLoggerCallback(project=self.wandb_name)] if self.use_wandb else None,
callbacks=[WandbLoggerCallback(project=self.wandb_name, api_key=self.wandb_api_key)] if self.use_wandb else None,

)
return Tuner(
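
For reference, a minimal sketch of the callback this hunk now builds when use_wandb=True (hypothetical project and key values; assumes the ray.air.integrations.wandb import path used by recent Ray releases, whose WandbLoggerCallback accepts an api_key argument):

# Sketch only: hypothetical values, mirroring the callback constructed in get_tuner().
from ray.air.integrations.wandb import WandbLoggerCallback

callback = WandbLoggerCallback(
    project="cpa_tune",            # experiment.wandb_name
    api_key="YOUR_WANDB_API_KEY",  # experiment.wandb_api_key; if None, falls back to `wandb login` / WANDB_API_KEY
)
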
@@ -593,17 +595,17 @@ def _trainable(
import gc
gc.collect()
elif isinstance(experiment.data, (AnnData, MuData)):
getattr(experiment.model_cls, experiment.setup_method_name)(
experiment.data,
**experiment.setup_method_args,
)
experiment.model_cls.setup_anndata(adata, **setup_anndata_kwargs)

model = experiment.model_cls(adata, **model_args)
model.train(max_epochs=2000,
use_gpu=True,
early_stopping_patience=10,
check_val_every_n_epoch=5,
# getattr(experiment.model_cls, experiment.setup_method_name)(
# experiment.data,
# # **experiment.setup_method_args,
# )
experiment.model_cls.setup_anndata(experiment.data, **setup_anndata_kwargs)

model = experiment.model_cls(experiment.data, **model_args)
model.train(max_epochs=train_args.pop("max_epochs",2000),
use_gpu=train_args.pop("use_gpu",True),
early_stopping_patience=train_args.pop("early_stopping_patience",10),
check_val_every_n_epoch=train_args.pop("check_val_every_n_epoch",5),
plan_kwargs=plan_kwargs,
**train_args)
Comment on lines +598 to 610
Collaborator
Can we make this section similar to the part above it?
Eventually, something like this, I think:

Suggested change
# getattr(experiment.model_cls, experiment.setup_method_name)(
# experiment.data,
# # **experiment.setup_method_args,
# )
experiment.model_cls.setup_anndata(experiment.data, **setup_anndata_kwargs)
model = experiment.model_cls(experiment.data, **model_args)
model.train(max_epochs=train_args.pop("max_epochs",2000),
use_gpu=train_args.pop("use_gpu",True),
early_stopping_patience=train_args.pop("early_stopping_patience",10),
check_val_every_n_epoch=train_args.pop("check_val_every_n_epoch",5),
plan_kwargs=plan_kwargs,
**train_args)
adata = experiment.data.copy()
if sub_sample is not None:
sc.pp.subsample(adata, fraction=sub_sample)
experiment.model_cls.setup_anndata(adata, **setup_anndata_kwargs)
model = experiment.model_cls(adata, **model_args)
model.train(plan_kwargs=plan_kwargs, **train_args)
del adata
import gc
gc.collect()

else: # NOT TESTED
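
A side note on the train_args.pop(...) calls in the new branch: a small sketch (hypothetical values, not code from this PR) of the pattern, which pulls trainer-level keys out of train_args with defaults so the remainder can still be forwarded via **train_args:

# Hypothetical values; illustrates the pop-with-default split used in _trainable.
train_args = {"max_epochs": 500, "early_stopping_patience": 10, "some_other_kwarg": 1}
max_epochs = train_args.pop("max_epochs", 2000)           # 500: explicit value wins over the default
use_gpu = train_args.pop("use_gpu", True)                 # True: key absent, so the default applies
patience = train_args.pop("early_stopping_patience", 10)  # 10
# Whatever remains is passed through untouched:
# train_args == {"some_other_kwarg": 1}
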
@@ -631,6 +633,7 @@ def run_autotune(
setup_anndata_kwargs: dict[str, Any] = {},
use_wandb: bool = False,
wandb_name: str = "cpa_tune",
wandb_api_key: str | None = None,
plan_kwargs_keys: list[str] = [],
) -> AutotuneExperiment:
"""``BETA`` Run a hyperparameter sweep.
@@ -732,6 +735,7 @@ run_autotune(
setup_anndata_kwargs=setup_anndata_kwargs,
use_wandb=use_wandb,
wandb_name=wandb_name,
wandb_api_key=wandb_api_key,
plan_kwargs_keys=plan_kwargs_keys,
)
logger.info(f"Running autotune experiment {experiment.name}.")
317 changes: 164 additions & 153 deletions examples/tune_script.py
@@ -6,159 +6,170 @@
import numpy as np

import pickle
import gdown

DATA_PATH = '/PATH/TO/DATA.h5ad' # Change this to your desired path
adata = sc.read_h5ad(DATA_PATH)
adata.X = adata.layers['counts'].copy() # Counts should be available in the 'counts' layer
sc.pp.subsample(adata, fraction=0.1)
WANDB_API_KEY = "api_key" #add wandb api key
DATA_PATH = 'data/to/load.h5ad'
OUTPUT_PATH = "result/output/path.pkl"

model_args = {
'n_latent': tune.choice([32, 64, 128, 256]),
'recon_loss': tune.choice(['nb']),
'doser_type': tune.choice(['logsigm']),
def main():

'n_hidden_encoder': tune.choice([128, 256, 512, 1024]),
'n_layers_encoder': tune.choice([1, 2, 3, 4, 5]),

'n_hidden_decoder': tune.choice([128, 256, 512, 1024]),
'n_layers_decoder': tune.choice([1, 2, 3, 4, 5]),

'use_batch_norm_encoder': tune.choice([True, False]),
'use_layer_norm_encoder': tune.sample_from(
lambda spec: False if spec.config.model_args.use_batch_norm_encoder else np.random.choice([True, False])),

'use_batch_norm_decoder': tune.choice([True, False]),
'use_layer_norm_decoder': tune.sample_from(
lambda spec: False if spec.config.model_args.use_batch_norm_decoder else np.random.choice([True, False])),

'dropout_rate_encoder': tune.choice([0.0, 0.1, 0.2, 0.25]),
'dropout_rate_decoder': tune.choice([0.0, 0.1, 0.2, 0.25]),

'variational': tune.choice([False]),

'seed': tune.randint(0, 10000),

'split_key': 'split_1ct_MEC',
'train_split': 'train',
'valid_split': 'valid',
'test_split': 'ood',
}

train_args = {
##################### plan_kwargs #####################
'n_epochs_adv_warmup': tune.choice([0, 1, 3, 5, 10, 50, 70]),
'n_epochs_kl_warmup': tune.choice([None]),
# lambda spec: None if not spec.config.model_args.variational else np.random.choice([0, 1, 3, 5, 10])), # Use this if you're using variational=True as well

'n_epochs_pretrain_ae': tune.choice([0, 1, 3, 5, 10, 30, 50]),

'adv_steps': tune.choice([2, 3, 5, 10, 15, 20, 25, 30]),

'mixup_alpha': tune.choice([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]),
'n_epochs_mixup_warmup': tune.sample_from(
lambda spec: 0 if spec.config.train_args.mixup_alpha == 0.0 else np.random.choice([
0, 1, 3, 5, 10]),
),

'n_layers_adv': tune.choice([1, 2, 3, 4, 5]),
'n_hidden_adv': tune.choice([32, 64, 128, 256]),

'use_batch_norm_adv': tune.choice([True, False]),
'use_layer_norm_adv': tune.sample_from(
lambda spec: False if spec.config.train_args.use_batch_norm_adv else np.random.choice([True, False])),

'dropout_rate_adv': tune.choice([0.0, 0.1, 0.2, 0.25, 0.3]),

'pen_adv': tune.loguniform(1e-2, 1e2),
'reg_adv': tune.loguniform(1e-2, 1e2),

'lr': tune.loguniform(1e-5, 1e-2),
'wd': tune.loguniform(1e-8, 1e-5),

'doser_lr': tune.loguniform(1e-5, 1e-2),
'doser_wd': tune.loguniform(1e-8, 1e-5),

'adv_lr': tune.loguniform(1e-5, 1e-2),
'adv_wd': tune.loguniform(1e-8, 1e-5),

'adv_loss': tune.choice(['cce']),

'do_clip_grad': tune.choice([True, False]),
# 'gradient_clip_value': tune.loguniform(1e-2, 1e2),
'gradient_clip_value': tune.choice([1.0]),

'step_size_lr': tune.choice([10, 25, 45]),
}
plan_kwargs_keys = list(train_args.keys())

trainer_actual_args = {
'max_epochs': 200,
'use_gpu': True,
'early_stopping_patience': 10,
'check_val_every_n_epoch': 5,
}
train_args.update(trainer_actual_args)

search_space = {
'model_args': model_args,
'train_args': train_args,
}

scheduler_kwargs = {
# 'mode': 'max',
# 'metric': 'cpa_metric',
'max_t': 1000,
'grace_period': 5,
'reduction_factor': 4,
}

# searcher_kwargs = {
# 'mode': 'max',
# 'metric': 'cpa_metric',
# }

setup_anndata_kwargs = {
'perturbation_key': 'condition_ID',
'dosage_key': 'log_dose',
'control_group': 'CHEMBL504',
'batch_key': None,
'is_count_data': True,
'categorical_covariate_keys': ['cell_type'],
'deg_uns_key': 'rank_genes_groups_cov',
'deg_uns_cat_key': 'cov_drug_dose',
'max_comb_len': 2,
}
model = cpa.CPA
model.setup_anndata(adata, **setup_anndata_kwargs)

experiment = run_autotune(
model_cls=model,
data=adata,
metrics=["cpa_metric", # The first one (cpa_metric) is the one that will be used for optimization "MAIN ONE"
"disnt_basal",
"disnt_after",
"r2_mean",
"val_r2_mean",
"val_r2_var",
"val_recon",],
mode="max",
search_space=search_space,
num_samples=5000, # Change this to your desired number of samples (Number of runs)
scheduler="asha",
searcher="hyperopt",
seed=1,
resources={"cpu": 40, "gpu": 0.2, "memory": 16 * 1024 * 1024 * 1024}, # Change this to your desired resources
experiment_name="cpa_autotune", # Change this to your desired experiment name
logging_dir='/PATH/TO/LOGS/', # Change this to your desired path
adata_path=DATA_PATH,
sub_sample=0.1,
setup_anndata_kwargs=setup_anndata_kwargs,
use_wandb=False, # If you want to use wandb, set this to True
wandb_name="cpa_tune", # Change this to your desired wandb project name
scheduler_kwargs=scheduler_kwargs,
plan_kwargs_keys=plan_kwargs_keys,
# searcher_kwargs=searcher_kwargs,
)
result_grid = experiment.result_grid
with open('result_grid.pkl', 'wb') as f:
pickle.dump(result_grid, f)
adata = sc.read(DATA_PATH)

adata.X = adata.layers['counts'].copy() # Counts should be available in the 'counts' layer
sc.pp.subsample(adata, fraction=0.1)

model_args = {
'n_latent': tune.choice([32, 64, 128, 256]),
'recon_loss': tune.choice(['nb']),
'doser_type': tune.choice(['logsigm']),

'n_hidden_encoder': tune.choice([128, 256, 512, 1024]),
'n_layers_encoder': tune.choice([1, 2, 3, 4, 5]),

'n_hidden_decoder': tune.choice([128, 256, 512, 1024]),
'n_layers_decoder': tune.choice([1, 2, 3, 4, 5]),

'use_batch_norm_encoder': tune.choice([True, False]),
'use_layer_norm_encoder': tune.sample_from(
lambda spec: False if spec.config.model_args.use_batch_norm_encoder else np.random.choice([True, False])),

'use_batch_norm_decoder': tune.choice([True, False]),
'use_layer_norm_decoder': tune.sample_from(
lambda spec: False if spec.config.model_args.use_batch_norm_decoder else np.random.choice([True, False])),

'dropout_rate_encoder': tune.choice([0.0, 0.1, 0.2, 0.25]),
'dropout_rate_decoder': tune.choice([0.0, 0.1, 0.2, 0.25]),

'variational': tune.choice([False]),

'seed': tune.randint(0, 10000),

'split_key': 'split_1ct_MEC',
'train_split': 'train',
'valid_split': 'valid',
'test_split': 'ood',
}

train_args = {
##################### plan_kwargs #####################
'n_epochs_adv_warmup': tune.choice([0, 1, 3, 5, 10, 50, 70]),
'n_epochs_kl_warmup': tune.choice([None]),
# lambda spec: None if not spec.config.model_args.variational else np.random.choice([0, 1, 3, 5, 10])), # Use this if you're using variational=True as well

'n_epochs_pretrain_ae': tune.choice([0, 1, 3, 5, 10, 30, 50]),

'adv_steps': tune.choice([2, 3, 5, 10, 15, 20, 25, 30]),

'mixup_alpha': tune.choice([0.0, 0.1, 0.2, 0.3, 0.4, 0.5]),
'n_epochs_mixup_warmup': tune.sample_from(
lambda spec: 0 if spec.config.train_args.mixup_alpha == 0.0 else np.random.choice([
0, 1, 3, 5, 10]),
),

'n_layers_adv': tune.choice([1, 2, 3, 4, 5]),
'n_hidden_adv': tune.choice([32, 64, 128, 256]),

'use_batch_norm_adv': tune.choice([True, False]),
'use_layer_norm_adv': tune.sample_from(
lambda spec: False if spec.config.train_args.use_batch_norm_adv else np.random.choice([True, False])),

'dropout_rate_adv': tune.choice([0.0, 0.1, 0.2, 0.25, 0.3]),

'pen_adv': tune.loguniform(1e-2, 1e2),
'reg_adv': tune.loguniform(1e-2, 1e2),

'lr': tune.loguniform(1e-5, 1e-2),
'wd': tune.loguniform(1e-8, 1e-5),

'doser_lr': tune.loguniform(1e-5, 1e-2),
'doser_wd': tune.loguniform(1e-8, 1e-5),

'adv_lr': tune.loguniform(1e-5, 1e-2),
'adv_wd': tune.loguniform(1e-8, 1e-5),

'adv_loss': tune.choice(['cce']),

'do_clip_grad': tune.choice([True, False]),
# 'gradient_clip_value': tune.loguniform(1e-2, 1e2),
'gradient_clip_value': tune.choice([1.0]),

'step_size_lr': tune.choice([10, 25, 45]),
}
plan_kwargs_keys = list(train_args.keys())

trainer_actual_args = {
'max_epochs': 20,
'use_gpu': True,
'early_stopping_patience': 10,
'check_val_every_n_epoch': 5,
}
train_args.update(trainer_actual_args)

search_space = {
'model_args': model_args,
'train_args': train_args,
}

scheduler_kwargs = {
# 'mode': 'max',
# 'metric': 'cpa_metric',
'max_t': 1000,
'grace_period': 5,
'reduction_factor': 4,
}

# searcher_kwargs = {
# 'mode': 'max',
# 'metric': 'cpa_metric',
# }

setup_anndata_kwargs = {
'perturbation_key': 'condition_ID',
'dosage_key': 'log_dose',
'control_group': 'CHEMBL504',
'batch_key': None,
'is_count_data': True,
'categorical_covariate_keys': ['cell_type'],
'deg_uns_key': 'rank_genes_groups_cov',
'deg_uns_cat_key': 'cov_drug_dose',
'max_comb_len': 2,
}
model = cpa.CPA
model.setup_anndata(adata, **setup_anndata_kwargs)

experiment = run_autotune(
model_cls=model,
data=adata,
metrics=["cpa_metric", # The first one (cpa_metric) is the one that will be used for optimization "MAIN ONE"
"disnt_basal",
"disnt_after",
"r2_mean",
"val_r2_mean",
"val_r2_var",
"val_recon",],
mode="max",
search_space=search_space,
num_samples=20, # Change this to your desired number of samples (Number of runs)
scheduler="asha",
searcher="hyperopt",
seed=1,
resources={"cpu": 40, "gpu": 0.2, "memory": 16 * 1024 * 1024 * 1024}, # Change this to your desired resources
experiment_name="cpa_autotune", # Change this to your desired experiment name
# logging_dir="", # Change this to your desired path
adata_path=None,
sub_sample=0.1,
setup_anndata_kwargs=setup_anndata_kwargs,
use_wandb=True, # If you want to use wandb, set this to True
wandb_name="cpa_tune", # Change this to your desired wandb project name
wandb_api_key = WANDB_API_KEY, #add wandb api key
scheduler_kwargs=scheduler_kwargs,
plan_kwargs_keys=plan_kwargs_keys,
# searcher_kwargs=searcher_kwargs,
)
result_grid = experiment.result_grid
with open(OUTPUT_PATH, 'wb') as f: # change to desired output path
pickle.dump(result_grid, f)

if __name__ == "__main__":
main()
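
Once the sweep finishes, the pickled result grid can be reloaded for analysis. A minimal sketch, assuming the Ray ResultGrid unpickles cleanly and using Ray Tune's get_best_result (paths and metric names match the script above):

import pickle

with open("result/output/path.pkl", "rb") as f:  # OUTPUT_PATH from the script
    result_grid = pickle.load(f)

# cpa_metric is the first metric listed above and the one optimized with mode="max"
best = result_grid.get_best_result(metric="cpa_metric", mode="max")
print(best.config)   # winning model_args / train_args sample
print(best.metrics)  # last reported metrics for that trial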