Skip to content

Commit

Permalink
update train-diffusion.json
Browse files Browse the repository at this point in the history
Signed-off-by: Can-Zhao <[email protected]>
  • Loading branch information
Can-Zhao committed Dec 13, 2023
1 parent 94ad31e commit 9d2d2f9
Showing 1 changed file with 28 additions and 45 deletions.
73 changes: 28 additions & 45 deletions models/brats_mri_generative_diffusion/configs/train_diffusion.json
Original file line number Diff line number Diff line change
@@ -1,55 +1,42 @@
{
"ckpt_dir": "$@bundle_root + '/models'",
"train_batch_size": 4,
"lr": 1e-05,
"train_patch_size": [
144,
176,
112
192,
192,
128
],
"latent_shape": [
"@latent_channels",
36,
44,
28
48,
48,
32
],
"load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
"load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
"autoencoder": "$@autoencoder_def.to(@device)",
"network_def": {
"diffusion_def": {
"_target_": "generative.networks.nets.DiffusionModelUNet",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels": [
256,
256,
512
],
"attention_levels": [
false,
true,
true
],
"num_head_channels": [
0,
64,
64
],
"num_res_blocks": 2
"num_channels":[128, 256, 512],
"attention_levels":[false, true, true],
"num_head_channels":[0, 32, 32],
"num_res_blocks": 2,
"use_flash_attention": true
},
"diffusion": "$@network_def.to(@device)",
"diffusion": "$@diffusion_def.to(@device)",
"optimizer": {
"_target_": "torch.optim.Adam",
"params": "[email protected]()",
"lr": "@lr"
"lr": 1e-04
},
"lr_scheduler": {
"_target_": "torch.optim.lr_scheduler.MultiStepLR",
"optimizer": "@optimizer",
"milestones": [
100,
1000
5000
],
"gamma": 0.1
},
Expand All @@ -59,20 +46,20 @@
"_requires_": [
"@load_autoencoder"
],
"schedule": "scaled_linear_beta",
"schedule": "linear_beta",
"num_train_timesteps": 1000,
"beta_start": 0.0015,
"beta_end": 0.0195
},
"inferer": {
"_target_": "generative.inferers.LatentDiffusionInferer",
"scheduler": "@noise_scheduler",
"scale_factor": "@scale_factor"
},
"loss": {
"_target_": "torch.nn.MSELoss"
},
"train": {
"inferer": {
"_target_": "generative.inferers.LatentDiffusionInferer",
"scheduler": "@noise_scheduler",
"scale_factor": "@scale_factor"
},
"crop_transforms": [
{
"_target_": "CenterSpatialCropd",
Expand All @@ -85,13 +72,8 @@
"transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
},
"dataset": {
"_target_": "monai.apps.DecathlonDataset",
"root_dir": "@dataset_dir",
"task": "Task01_BrainTumour",
"section": "training",
"cache_rate": 1.0,
"num_workers": 8,
"download": false,
"_target_": "Dataset",
"data": "@train_datalist",
"transform": "@train#preprocessing"
},
"dataloader": {
Expand All @@ -116,7 +98,7 @@
"save_interval": 0,
"save_final": true,
"epoch_level": true,
"final_filename": "model.pt"
"final_filename": "model_ldm.pt"
},
{
"_target_": "StatsHandler",
Expand All @@ -127,20 +109,21 @@
"_target_": "TensorBoardStatsHandler",
"log_dir": "@tf_dir",
"tag_name": "train_diffusion_loss",
"output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
"output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)",
"iteration_log":false
}
],
"trainer": {
"_target_": "scripts.ldm_trainer.LDMTrainer",
"device": "@device",
"max_epochs": 5000,
"max_epochs": 10000,
"train_data_loader": "@train#dataloader",
"network": "@diffusion",
"autoencoder_model": "@autoencoder",
"optimizer": "@optimizer",
"loss_function": "@loss",
"latent_shape": "@latent_shape",
"inferer": "@train#inferer",
"inferer": "@inferer",
"key_train_metric": "$None",
"train_handlers": "@train#handlers"
}
Expand Down

0 comments on commit 9d2d2f9

Please sign in to comment.