diff --git a/models/brats_mri_generative_diffusion/configs/train_diffusion.json b/models/brats_mri_generative_diffusion/configs/train_diffusion.json index 85c8ca8a..441b20ee 100644 --- a/models/brats_mri_generative_diffusion/configs/train_diffusion.json +++ b/models/brats_mri_generative_diffusion/configs/train_diffusion.json @@ -1,55 +1,42 @@ { "ckpt_dir": "$@bundle_root + '/models'", "train_batch_size": 4, - "lr": 1e-05, "train_patch_size": [ - 144, - 176, - 112 + 192, + 192, + 128 ], "latent_shape": [ "@latent_channels", - 36, - 44, - 28 + 48, + 48, + 32 ], "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'", "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))", "autoencoder": "$@autoencoder_def.to(@device)", - "network_def": { + "diffusion_def": { "_target_": "generative.networks.nets.DiffusionModelUNet", "spatial_dims": "@spatial_dims", "in_channels": "@latent_channels", "out_channels": "@latent_channels", - "num_channels": [ - 256, - 256, - 512 - ], - "attention_levels": [ - false, - true, - true - ], - "num_head_channels": [ - 0, - 64, - 64 - ], - "num_res_blocks": 2 + "num_channels":[128, 256, 512], + "attention_levels":[false, true, true], + "num_head_channels":[0, 32, 32], + "num_res_blocks": 2, + "use_flash_attention": true }, - "diffusion": "$@network_def.to(@device)", + "diffusion": "$@diffusion_def.to(@device)", "optimizer": { "_target_": "torch.optim.Adam", "params": "$@diffusion.parameters()", - "lr": "@lr" + "lr": 1e-04 }, "lr_scheduler": { "_target_": "torch.optim.lr_scheduler.MultiStepLR", "optimizer": "@optimizer", "milestones": [ - 100, - 1000 + 5000 ], "gamma": 0.1 }, @@ -59,20 +46,20 @@ "_requires_": [ "@load_autoencoder" ], - "schedule": "scaled_linear_beta", + "schedule": "linear_beta", "num_train_timesteps": 1000, "beta_start": 0.0015, "beta_end": 0.0195 }, + "inferer": { + "_target_": "generative.inferers.LatentDiffusionInferer", + "scheduler": "@noise_scheduler", + "scale_factor": "@scale_factor" + }, "loss": { "_target_": "torch.nn.MSELoss" }, "train": { - "inferer": { - "_target_": "generative.inferers.LatentDiffusionInferer", - "scheduler": "@noise_scheduler", - "scale_factor": "@scale_factor" - }, "crop_transforms": [ { "_target_": "CenterSpatialCropd", @@ -85,13 +72,8 @@ "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms" }, "dataset": { - "_target_": "monai.apps.DecathlonDataset", - "root_dir": "@dataset_dir", - "task": "Task01_BrainTumour", - "section": "training", - "cache_rate": 1.0, - "num_workers": 8, - "download": false, + "_target_": "Dataset", + "data": "@train_datalist", "transform": "@train#preprocessing" }, "dataloader": { @@ -116,7 +98,7 @@ "save_interval": 0, "save_final": true, "epoch_level": true, - "final_filename": "model.pt" + "final_filename": "model_ldm.pt" }, { "_target_": "StatsHandler", @@ -127,20 +109,21 @@ "_target_": "TensorBoardStatsHandler", "log_dir": "@tf_dir", "tag_name": "train_diffusion_loss", - "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)" + "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)", + "iteration_log":false } ], "trainer": { "_target_": "scripts.ldm_trainer.LDMTrainer", "device": "@device", - "max_epochs": 5000, + "max_epochs": 10000, "train_data_loader": "@train#dataloader", "network": "@diffusion", "autoencoder_model": "@autoencoder", "optimizer": "@optimizer", "loss_function": "@loss", "latent_shape": "@latent_shape", - "inferer": "@train#inferer", + "inferer": "@inferer", "key_train_metric": "$None", "train_handlers": "@train#handlers" }