Skip to content

Commit

Permalink
Merge pull request #922 from vahluw/vahluw_patch_size_fix
Browse files Browse the repository at this point in the history
Fixed bug in inference
  • Loading branch information
sarthakpati authored Aug 19, 2024
2 parents 0635d8b + 258cd90 commit e9d92ae
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,16 @@ def validate_network(
if ext in [".jpg", ".jpeg", ".png"]:
pred_mask = pred_mask.astype(np.uint8)

## special case for 2D
if image.shape[-1] > 1:
result_image = sitk.GetImageFromArray(pred_mask)
else:
result_image = sitk.GetImageFromArray(pred_mask.squeeze(0))
pred_mask = (
pred_mask.squeeze(0)
if pred_mask.shape[0] == 1
else (
pred_mask.squeeze(-1)
if pred_mask.shape[-1] == 1
else pred_mask
)
)
result_image = sitk.GetImageFromArray(pred_mask)
result_image.CopyInformation(img_for_metadata)

# this handles cases that need resampling/resizing
Expand Down

0 comments on commit e9d92ae

Please sign in to comment.