Skip to content

Commit

Permalink
update readme and input check
Browse files Browse the repository at this point in the history
Signed-off-by: Pengfei Guo <[email protected]>
  • Loading branch information
guopengf committed Sep 11, 2024
1 parent 08ccc1e commit 8ba1d08
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
8 changes: 8 additions & 0 deletions generation/maisi/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ The information for the inference input, like body region and anatomy to generat

To generate images with substantial dimensions, such as 512 &times; 512 &times; 512 or larger, using GPUs with 80GB of memory, it is advisable to configure the `"num_splits"` parameter in [the auto-encoder configuration](./configs/config_maisi.json#L11-L37) to 16. This adjustment is crucial to avoid out-of-memory issues during inference.

#### Recommended spacing for different output sizes:

|`output_size`| Recommended `"spacing"`|
|:-----:|:-----:|
[256, 256, 256] | [1.5, 1.5, 1.5] |
[512, 512, 128] | [0.8, 0.8, 2.5] |
[512, 512, 512] | [1.0, 1.0, 1.0] |

#### Execute Inference:
To run the inference script, please run:
```bash
Expand Down
5 changes: 4 additions & 1 deletion generation/maisi/scripts/infer_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monai.transforms import SaveImage
from monai.utils import RankFilter

from .sample import ldm_conditional_sample_one_image
from .sample import check_input, ldm_conditional_sample_one_image
from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp


Expand Down Expand Up @@ -150,10 +150,13 @@ def main():
top_region_index_tensor = batch["top_region_index"].to(device)
bottom_region_index_tensor = batch["bottom_region_index"].to(device)
spacing_tensor = batch["spacing"].to(device)
out_spacing = tuple((batch["spacing"].squeeze().numpy()/100).tolist())
# get target dimension
dim = batch["dim"]
output_size = (dim[0].item(), dim[1].item(), dim[2].item())
latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4)
# check if output_size and out_spacing are valid.
check_input(None, None, None, output_size, out_spacing, None)
# generate a single synthetic image using a latent diffusion model with controlnet.
synthetic_images, _ = ldm_conditional_sample_one_image(
autoencoder,
Expand Down

0 comments on commit 8ba1d08

Please sign in to comment.