Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1428 update to the latest usage SaveImage #1429

Merged
merged 1 commit into from
Jun 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions 3d_segmentation/challenge_baseline/run_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import shutil
import sys

import numpy as np
import torch
import torch.nn as nn
from ignite.contrib.handlers import ProgressBar
Expand All @@ -42,7 +41,7 @@ def get_xforms(mode="train", keys=("image", "label")):
"""returns a composed transform for train/val/infer."""

xforms = [
LoadImaged(keys, ensure_channel_first=True),
LoadImaged(keys, ensure_channel_first=True, image_only=True),
Orientationd(keys, axcodes="LPS"),
Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
Expand Down Expand Up @@ -239,7 +238,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
)

inferer = get_inferer()
saver = monai.data.NiftiSaver(output_dir=prediction_folder, mode="nearest")
saver = monai.transforms.SaveImage(output_dir=prediction_folder, mode="nearest", resample=True)
with torch.no_grad():
for infer_data in infer_loader:
logging.info(f"segmenting {infer_data['image'].meta['filename_or_obj']}")
Expand All @@ -258,7 +257,8 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
n = n + 1.0
preds = preds / n
preds = (preds.argmax(dim=1, keepdims=True)).float()
saver.save_batch(preds, infer_data["image"].meta)
for p in preds: # save each image+metadata in the batch respectively
saver(p)

# copy the saved segmentations into the required folder structure for submission
submission_dir = os.path.join(prediction_folder, "to_submit")
Expand Down