Commit

Merge branch 'main' into anomaly_detection_porting
ericspod authored Sep 16, 2024
2 parents 7a1055a + 236a883 commit 9853f31
Showing 3 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion detection/generate_transforms.py
@@ -148,7 +148,7 @@ def generate_detection_train_transform(
spatial_axes=(0, 1),
),
# apply the same affine matrix which already applied on the images to the points
-ApplyTransformToPointsd(keys=[point_key], refer_key=image_key, affine_lps_to_ras=affine_lps_to_ras),
+ApplyTransformToPointsd(keys=[point_key], refer_keys=image_key, affine_lps_to_ras=affine_lps_to_ras),
# convert points back to boxes
ConvertPointsToBoxesd(keys=[point_key]),
ClipBoxToImaged(
8 changes: 8 additions & 0 deletions generation/maisi/README.md
@@ -68,6 +68,14 @@ The information for the inference input, like body region and anatomy to generat

To generate images with substantial dimensions, such as 512 × 512 × 512 or larger, using GPUs with 80GB of memory, it is advisable to configure the `"num_splits"` parameter in [the auto-encoder configuration](./configs/config_maisi.json#L11-L37) to 16. This adjustment is crucial to avoid out-of-memory issues during inference.

+#### Recommended spacing for different output sizes:
+
+|`output_size`| Recommended `"spacing"`|
+|:-----:|:-----:|
+[256, 256, 256] | [1.5, 1.5, 1.5] |
+[512, 512, 128] | [0.8, 0.8, 2.5] |
+[512, 512, 512] | [1.0, 1.0, 1.0] |
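
The table added above can be read as a lookup from `output_size` to `"spacing"`. A minimal sketch, assuming a hypothetical helper `recommended_spacing` that is not part of the MAISI scripts; only the three rows are taken from the README change, and sizes not listed there return `None` rather than a guessed value:

```python
from typing import Optional, Sequence, Tuple

# Rows copied from the README table in this commit.
RECOMMENDED_SPACING = {
    (256, 256, 256): (1.5, 1.5, 1.5),
    (512, 512, 128): (0.8, 0.8, 2.5),
    (512, 512, 512): (1.0, 1.0, 1.0),
}


def recommended_spacing(output_size: Sequence[int]) -> Optional[Tuple[float, float, float]]:
    """Hypothetical helper: look up the README's recommended spacing.

    Returns None when the size is not one of the listed rows.
    """
    return RECOMMENDED_SPACING.get(tuple(int(v) for v in output_size))
```

For example, `recommended_spacing([512, 512, 128])` returns `(0.8, 0.8, 2.5)`.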

#### Execute Inference:
To run the inference script, please run:
```bash
5 changes: 4 additions & 1 deletion generation/maisi/scripts/infer_controlnet.py
@@ -23,7 +23,7 @@
from monai.transforms import SaveImage
from monai.utils import RankFilter

-from .sample import ldm_conditional_sample_one_image
+from .sample import check_input, ldm_conditional_sample_one_image
from .utils import define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp


@@ -150,10 +150,13 @@ def main():
top_region_index_tensor = batch["top_region_index"].to(device)
bottom_region_index_tensor = batch["bottom_region_index"].to(device)
spacing_tensor = batch["spacing"].to(device)
+out_spacing = tuple((batch["spacing"].squeeze().numpy() / 100).tolist())
# get target dimension
dim = batch["dim"]
output_size = (dim[0].item(), dim[1].item(), dim[2].item())
latent_shape = (args.latent_channels, output_size[0] // 4, output_size[1] // 4, output_size[2] // 4)
+# check if output_size and out_spacing are valid.
+check_input(None, None, None, output_size, out_spacing, None)
# generate a single synthetic image using a latent diffusion model with controlnet.
synthetic_images, _ = ldm_conditional_sample_one_image(
autoencoder,
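
The arithmetic in this hunk can be sketched in isolation. The example below assumes a hypothetical helper `derive_sampling_args` and plain Python sequences in place of the tensors in `batch`; the division by 100 and the integer division by 4 are taken from the diff, while the value `4` for `latent_channels` is only an illustrative assumption. The body of `check_input` is not shown in this commit, so it is not reproduced here.

```python
from typing import Sequence, Tuple


def derive_sampling_args(
    scaled_spacing: Sequence[float],
    dim: Sequence[int],
    latent_channels: int,
) -> Tuple[Tuple[float, ...], Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical helper mirroring the arithmetic in the hunk above."""
    # The diff divides batch["spacing"] by 100 to obtain out_spacing.
    out_spacing = tuple(float(s) / 100 for s in scaled_spacing)
    # The target dimension is unpacked into a plain tuple of ints.
    output_size = tuple(int(d) for d in dim)
    # Each spatial axis is integer-divided by 4, as in the diff.
    latent_shape = (latent_channels,) + tuple(d // 4 for d in output_size)
    return out_spacing, output_size, latent_shape


# latent_channels=4 is an assumed value for illustration only.
out_spacing, output_size, latent_shape = derive_sampling_args(
    [150.0, 150.0, 150.0], [256, 256, 256], latent_channels=4
)
```

With these inputs, `out_spacing` and `output_size` match the first row of the README table, which is the pair the new `check_input` call receives.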