From 8ba1d08dd636dc97b2a7060abbd35e19ee05e3b1 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Wed, 11 Sep 2024 20:38:44 +0000 Subject: [PATCH] update readme and input check Signed-off-by: Pengfei Guo --- generation/maisi/README.md | 8 ++++++++ generation/maisi/scripts/infer_controlnet.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/generation/maisi/README.md b/generation/maisi/README.md index 789dfb41e..280a5103d 100644 --- a/generation/maisi/README.md +++ b/generation/maisi/README.md @@ -68,6 +68,14 @@ The information for the inference input, like body region and anatomy to generat To generate images with substantial dimensions, such as 512 × 512 × 512 or larger, using GPUs with 80GB of memory, it is advisable to configure the `"num_splits"` parameter in [the auto-encoder configuration](./configs/config_maisi.json#L11-L37) to 16. This adjustment is crucial to avoid out-of-memory issues during inference. +#### Recommended spacing for different output sizes: + +|`output_size`| Recommended `"spacing"`| +|:-----:|:-----:| +[256, 256, 256] | [1.5, 1.5, 1.5] | +[512, 512, 128] | [0.8, 0.8, 2.5] | +[512, 512, 512] | [1.0, 1.0, 1.0] | + #### Execute Inference: To run the inference script, please run: ```bash diff --git a/generation/maisi/scripts/infer_controlnet.py b/generation/maisi/scripts/infer_controlnet.py index cb4d3c9fc..5568eb922 100644 --- a/generation/maisi/scripts/infer_controlnet.py +++ b/generation/maisi/scripts/infer_controlnet.py @@ -23,7 +23,7 @@ from monai.transforms import SaveImage from monai.utils import RankFilter -from .sample import ldm_conditional_sample_one_image +from .sample import check_input, ldm_conditional_sample_one_image from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp @@ -150,10 +150,13 @@ def main(): top_region_index_tensor = batch["top_region_index"].to(device) bottom_region_index_tensor = batch["bottom_region_index"].to(device) spacing_tensor = batch["spacing"].to(device) + out_spacing = tuple((batch["spacing"].squeeze().numpy()/100).tolist()) # get target dimension dim = batch["dim"] output_size = (dim[0].item(), dim[1].item(), dim[2].item()) latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4) + # check if output_size and out_spacing are valid. + check_input(None, None, None, output_size, out_spacing, None) # generate a single synthetic image using a latent diffusion model with controlnet. synthetic_images, _ = ldm_conditional_sample_one_image( autoencoder,