Skip to content

Commit

Permalink
Improve metatensor to nifti dimension parsing, other minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 25, 2024
1 parent 18deddf commit 4534cb1
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions spleenseg/core/neuralnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,14 @@ def evalAndCorrect(self) -> float:
return f_loss.item()

def metaTensor_toNIfTI(self, metaTensor: MetaTensor, savefile: Path):
singleVolume: np.ndarray = metaTensor[0, 0].cpu().numpy()
singleVolume: np.ndarray
if metaTensor.dim() == 5:
singleVolume = metaTensor[0, 0].cpu().numpy()
if metaTensor.dim() == 3:
singleVolume = metaTensor.cpu().numpy()
affine: np.ndarray = np.eye(4)
niftiVolume: Nifti1Image = Nifti1Image(singleVolume, affine)
nib.save(niftiVolume, savefile)
pass

def sample_showInfo(
self,
Expand Down Expand Up @@ -571,15 +574,17 @@ def inference_post(
telemetry: data.NIfTItelemetry | None = None,
) -> float:
index: int = int(sample["index"])
print(f"[{index}]--------> novel inference")
sample["pred"] = result
sample = [self.f_outputPost(i) for i in decollate_batch(sample)]
prediction = from_engine(["pred"])(sample)
fi = transforms.f_LoadImage()
input = fi(prediction[0].meta["filename_or_obj"])
Ti = torch.as_tensor(input)
Ti: torch.Tensor = torch.as_tensor(input)
Pi: torch.Tensor = prediction[0]
plotting.plot_infer(
Ti,
prediction,
Pi,
f"{index}",
Path(
Path(self.network.options.outputdir)
Expand Down Expand Up @@ -617,7 +622,7 @@ def infer_usingModel(self, model: Path):
),
]
)
niftiTelemetry.info = ["novel inference input", "novel inference output"]
niftiTelemetry.info = ["novel inference input ", "novel inference output"]
niftiTelemetry.savePath = [
self.trainingParams.novelInference / "input.nii.gz",
self.trainingParams.novelInference / "output.nii.gz",
Expand Down

0 comments on commit 4534cb1

Please sign in to comment.