From 9db26a4d83e01c8cb9644cd96c1a5374ec818a2b Mon Sep 17 00:00:00 2001
From: eigenvivek
Date: Sat, 2 Dec 2023 16:54:37 -0500
Subject: [PATCH] Bug fixes

---
 diffpose/_modidx.py                 |  24 ++-
 diffpose/deepfluoro.py              |  25 +++-
 diffpose/ljubljana.py               | 114 ++++++++++++++-
 diffpose/registration.py            |  47 +++---
 notebooks/api/00_deepfluoro.ipynb   |  42 +++++-
 notebooks/api/01_ljubljana.ipynb    | 219 +++++++++++++++++++++++++++-
 notebooks/api/03_registration.ipynb |  51 ++-----
 7 files changed, 445 insertions(+), 77 deletions(-)

diff --git a/diffpose/_modidx.py b/diffpose/_modidx.py
index f6fde5a..cf5148d 100644
--- a/diffpose/_modidx.py
+++ b/diffpose/_modidx.py
@@ -54,6 +54,8 @@
                                               'diffpose/deepfluoro.py'),
     'diffpose.deepfluoro.get_3d_fiducials': ( 'api/deepfluoro.html#get_3d_fiducials',
                                               'diffpose/deepfluoro.py'),
+    'diffpose.deepfluoro.get_random_offset': ( 'api/deepfluoro.html#get_random_offset',
+                                               'diffpose/deepfluoro.py'),
     'diffpose.deepfluoro.load_deepfluoro_dataset': ( 'api/deepfluoro.html#load_deepfluoro_dataset',
                                                      'diffpose/deepfluoro.py'),
     'diffpose.deepfluoro.parse_proj_params': ( 'api/deepfluoro.html#parse_proj_params',
                                                'diffpose/deepfluoro.py'),
@@ -72,13 +74,29 @@
                                             'diffpose/jacobians.py'),
     'diffpose.jacobians.plot_img_jacobian': ( 'api/jacobians.html#plot_img_jacobian',
                                               'diffpose/jacobians.py')},
-    'diffpose.ljubljana': { 'diffpose.ljubljana.LjubljanaDataset': ('api/ljubljana.html#ljubljanadataset', 'diffpose/ljubljana.py'),
+    'diffpose.ljubljana': { 'diffpose.ljubljana.Evaluator': ('api/ljubljana.html#evaluator', 'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.Evaluator.__call__': ( 'api/ljubljana.html#evaluator.__call__',
+                                                                       'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.Evaluator.__init__': ( 'api/ljubljana.html#evaluator.__init__',
+                                                                       'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.Evaluator.project': ( 'api/ljubljana.html#evaluator.project',
+                                                                      'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.LjubljanaDataset': ('api/ljubljana.html#ljubljanadataset', 'diffpose/ljubljana.py'),
                             'diffpose.ljubljana.LjubljanaDataset.__getitem__': ( 'api/ljubljana.html#ljubljanadataset.__getitem__',
                                                                                  'diffpose/ljubljana.py'),
                             'diffpose.ljubljana.LjubljanaDataset.__init__': ( 'api/ljubljana.html#ljubljanadataset.__init__',
                                                                               'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.LjubljanaDataset.__iter__': ( 'api/ljubljana.html#ljubljanadataset.__iter__',
+                                                                              'diffpose/ljubljana.py'),
                             'diffpose.ljubljana.LjubljanaDataset.__len__': ( 'api/ljubljana.html#ljubljanadataset.__len__',
-                                                                             'diffpose/ljubljana.py')},
+                                                                             'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.Transforms': ('api/ljubljana.html#transforms', 'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.Transforms.__call__': ( 'api/ljubljana.html#transforms.__call__',
+                                                                        'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.Transforms.__init__': ( 'api/ljubljana.html#transforms.__init__',
+                                                                        'diffpose/ljubljana.py'),
+                            'diffpose.ljubljana.get_random_offset': ( 'api/ljubljana.html#get_random_offset',
+                                                                      'diffpose/ljubljana.py')},
     'diffpose.metrics': { 'diffpose.metrics.CustomMetric': ('api/metrics.html#custommetric', 'diffpose/metrics.py'),
                           'diffpose.metrics.CustomMetric.__init__': ( 'api/metrics.html#custommetric.__init__',
                                                                       'diffpose/metrics.py'),
@@ -136,8 +154,6 @@
                                                                            'diffpose/registration.py'),
     'diffpose.registration.VectorizedNormalizedCrossCorrelation2d.norm': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d.norm',
                                                                            'diffpose/registration.py'),
-    'diffpose.registration.get_random_offset': ( 'api/registration.html#get_random_offset',
-                                                 'diffpose/registration.py'),
     'diffpose.registration.img_to_patches': ( 'api/registration.html#img_to_patches',
                                               'diffpose/registration.py'),
     'diffpose.registration.mask_to_img': ( 'api/registration.html#mask_to_img',
diff --git a/diffpose/deepfluoro.py b/diffpose/deepfluoro.py
index c2a5224..10ccbb1 100644
--- a/diffpose/deepfluoro.py
+++ b/diffpose/deepfluoro.py
@@ -2,7 +2,7 @@

 # %% auto 0
 __all__ = ['DeepFluoroDataset', 'convert_deepfluoro_to_diffdrr', 'convert_diffdrr_to_deepfluoro', 'Evaluator', 'preprocess',
-           'Transforms']
+           'get_random_offset', 'Transforms']

 # %% ../notebooks/api/00_deepfluoro.ipynb 3
 from pathlib import Path
@@ -305,7 +305,28 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):
     img = (img - img.min()) / (img.max() - img.min())
     return img

-# %% ../notebooks/api/00_deepfluoro.ipynb 30
+# %% ../notebooks/api/00_deepfluoro.ipynb 26
+from beartype import beartype
+from pytorch3d.transforms import se3_exp_map
+
+from .calibration import RigidTransform
+
+
+@beartype
+def get_random_offset(batch_size: int, device) -> RigidTransform:
+    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
+    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
+    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
+    r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))
+    r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
+    r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))
+    logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)
+    T = se3_exp_map(logmap)
+    R = T[..., :3, :3].transpose(-1, -2)
+    t = T[..., 3, :3]
+    return RigidTransform(R, t)
+
+# %% ../notebooks/api/00_deepfluoro.ipynb 32
 from torchvision.transforms import Compose, Lambda, Normalize, Resize
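
For context, the offsets sampled by `get_random_offset` are composed with a reference pose to produce perturbed training poses. A minimal sketch of that usage, under stated assumptions: `isocenter_pose` as a `DeepFluoroDataset` attribute is not shown in this patch, and batched `compose` broadcasting is assumed; `compose` itself appears on `RigidTransform` in the `Evaluator` code below.

import torch

from diffpose.deepfluoro import DeepFluoroDataset, get_random_offset

specimen = DeepFluoroDataset(1)

# Sample a batch of SE(3) offsets and compose them with a reference pose
offset = get_random_offset(batch_size=4, device="cpu")
poses = specimen.isocenter_pose.compose(offset)  # assumed attribute; broadcasting assumed
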
diff --git a/diffpose/ljubljana.py b/diffpose/ljubljana.py
index e70a953..68c2b26 100644
--- a/diffpose/ljubljana.py
+++ b/diffpose/ljubljana.py
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/01_ljubljana.ipynb.

 # %% auto 0
-__all__ = ['LjubljanaDataset']
+__all__ = ['LjubljanaDataset', 'get_random_offset', 'Evaluator', 'Transforms']

 # %% ../notebooks/api/01_ljubljana.ipynb 3
 from pathlib import Path
@@ -53,6 +53,9 @@ def __init__(
     def __len__(self):
         return 10

+    def __iter__(self):
+        return iter(self[idx] for idx in range(len(self)))
+
     def __getitem__(self, idx):
         idx += 1
         extrinsic = self.f[f"subject{idx:02d}/proj-{self.view}/extrinsic"][:]
@@ -69,8 +72,9 @@ def __getitem__(self, idx):
         if self.preprocess:
             img += 1
             img = img.max().log() - img.log()
-
         height, width = img.shape
+        img = img.unsqueeze(0).unsqueeze(0)
+
         focal_len, x0, y0 = parse_intrinsic_matrix(
             intrinsic,
             height,
@@ -88,6 +92,13 @@ def __getitem__(self, idx):
         volume = self.f[f"subject{idx:02d}/volume/pixels"][:]
         spacing = self.f[f"subject{idx:02d}/volume/spacing"][:]

+        isocenter_rot = torch.tensor([[torch.pi / 2, 0.0, -torch.pi / 2]])
+        isocenter_xyz = torch.tensor(volume.shape) * spacing / 2
+        isocenter_xyz = isocenter_xyz.unsqueeze(0)
+        isocenter_pose = RigidTransform(
+            isocenter_rot, isocenter_xyz, "euler_angles", "ZYX"
+        )
+
         return (
             volume,
             spacing,
@@ -100,4 +111,103 @@ def __getitem__(self, idx):
             y0,
             img,
             pose,
+            isocenter_pose,
         )
+
+# %% ../notebooks/api/01_ljubljana.ipynb 10
+from beartype import beartype
+from pytorch3d.transforms import se3_exp_map
+
+from .calibration import RigidTransform
+
+
+@beartype
+def get_random_offset(view, batch_size: int, device) -> RigidTransform:
+    if view == "ap":
+        t1 = torch.distributions.Normal(-10, 20).sample((batch_size,))
+        t2 = torch.distributions.Normal(175, 30).sample((batch_size,))
+        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))
+        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))
+    elif view == "lat":
+        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))
+        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))
+        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))
+        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))
+    else:
+        raise ValueError(f"view must be 'ap' or 'lat', not {view}")
+
+    logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)
+    T = se3_exp_map(logmap)
+    R = T[..., :3, :3].transpose(-1, -2)
+    t = T[..., 3, :3]
+    return RigidTransform(R, t)
+
+# %% ../notebooks/api/01_ljubljana.ipynb 12
+from torch.nn.functional import pad
+
+from .calibration import perspective_projection
+from .deepfluoro import convert_diffdrr_to_deepfluoro
+
+
+class Evaluator:
+    def __init__(self, specimen, idx):
+        # Save matrices to device
+        self.translate = specimen.translate
+        self.flip_xz = specimen.flip_xz
+        self.intrinsic = specimen.intrinsic
+        self.intrinsic_inv = specimen.intrinsic.inverse()
+
+        # Get gt fiducial locations
+        self.specimen = specimen
+        self.fiducials = specimen.fiducials
+        gt_pose = specimen[idx][-2]  # pose is the second-to-last item in the returned tuple
+        self.true_projected_fiducials = self.project(gt_pose)
+
+    def project(self, pose):
+        extrinsic = convert_diffdrr_to_deepfluoro(self.specimen, pose)
+        x = perspective_projection(extrinsic, self.intrinsic, self.fiducials)
+        x = -self.specimen.focal_len * torch.einsum(
+            "ij, bnj -> bni",
+            self.intrinsic_inv,
+            pad(x, (0, 1), value=1),  # Convert to homogeneous coordinates
+        )
+        extrinsic = (
+            self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)
+        )
+        return extrinsic.transform_points(x)
+
+    def __call__(self, pose):
+        pred_projected_fiducials = self.project(pose)
+        registration_error = (
+            (self.true_projected_fiducials - pred_projected_fiducials)
+            .norm(dim=-1)
+            .mean()
+        )
+        registration_error *= 0.154  # Pixel spacing is 0.154 mm / pixel isotropic
+        return registration_error
+
+# %% ../notebooks/api/01_ljubljana.ipynb 15
+from torchvision.transforms import Compose, Lambda, Normalize, Resize
+
+
+class Transforms:
+    def __init__(
+        self,
+        height: int,
+        width: int,
+        eps: float = 1e-6,
+    ):
+        """Transform X-rays and DRRs before inputting to CNN."""
+        self.transforms = Compose(
+            [
+                Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + eps)),
+                Resize((height, width), antialias=True),
+                Normalize(mean=0.0774, std=0.0569),
+            ]
+        )
+
+    def __call__(self, x):
+        return self.transforms(x)
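
A minimal sketch of how the new `Evaluator` is meant to be driven, assuming the relative HDF5 path used in the notebooks and that `LjubljanaDataset` exposes the calibration attributes (`translate`, `flip_xz`, `intrinsic`, `fiducials`, `focal_len`) that `Evaluator` reads; the composition order of the offset is also an assumption.

from diffpose.ljubljana import Evaluator, LjubljanaDataset, get_random_offset

subject = LjubljanaDataset("ap", filename="../../data/ljubljana.h5")
evaluator = Evaluator(subject, idx=0)

# Score a pose sampled near the AP view
*_, pose, isocenter_pose = subject[0]
offset = get_random_offset("ap", batch_size=1, device="cpu")
print(evaluator(isocenter_pose.compose(offset)))  # mean fiducial registration error
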
diff --git a/diffpose/registration.py b/diffpose/registration.py
index 12136e1..7d87a7e 100644
--- a/diffpose/registration.py
+++ b/diffpose/registration.py
@@ -1,13 +1,16 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/03_registration.ipynb.

 # %% auto 0
-__all__ = ['PoseRegressor', 'get_random_offset', 'SparseRegistration', 'VectorizedNormalizedCrossCorrelation2d']
+__all__ = ['PoseRegressor', 'SparseRegistration', 'VectorizedNormalizedCrossCorrelation2d']

 # %% ../notebooks/api/03_registration.ipynb 3
 import timm
 import torch

 # %% ../notebooks/api/03_registration.ipynb 5
+from .calibration import RigidTransform, convert
+
+
 class PoseRegressor(torch.nn.Module):
     """
     A PoseRegressor is comprised of a pretrained backbone model that extracts features
@@ -45,46 +48,31 @@ def forward(self, x):
         x = self.backbone(x)
         rot = self.rot_regression(x)
         xyz = self.xyz_regression(x)
-        return RigidTransform(rot, xyz, self.parameterization, self.convention)
+        return convert(
+            [rot, xyz],
+            input_parameterization=self.parameterization,
+            output_parameterization="se3_exp_map",
+            input_convention=self.convention,
+        )

 # %% ../notebooks/api/03_registration.ipynb 6
 N_ANGULAR_COMPONENTS = {
     "axis_angle": 3,
     "euler_angles": 3,
-    "se3": 3,
+    "se3_log_map": 3,
     "quaternion": 4,
     "rotation_6d": 6,
     "rotation_10d": 10,
     "quaternion_adjugate": 10,
 }

-# %% ../notebooks/api/03_registration.ipynb 8
-from beartype import beartype
-from pytorch3d.transforms import se3_exp_map
-
-from .calibration import RigidTransform
-
-
-@beartype
-def get_random_offset(batch_size: int, device) -> RigidTransform:
-    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
-    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
-    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
-    r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))
-    r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
-    r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))
-    logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)
-    T = se3_exp_map(logmap)
-    R = T[..., :3, :3].transpose(-1, -2)
-    t = T[..., 3, :3]
-    return RigidTransform(R, t)
-
-# %% ../notebooks/api/03_registration.ipynb 12
+# %% ../notebooks/api/03_registration.ipynb 11
 from diffdrr.detector import make_xrays
 from diffdrr.drr import DRR
 from diffdrr.siddon import siddon_raycast
+from pytorch3d.transforms import se3_exp_map

-from .calibration import RigidTransform, convert
+from .calibration import RigidTransform


 class SparseRegistration(torch.nn.Module):
@@ -210,12 +198,11 @@ def get_current_pose(self):
         return convert(
             [self.rotation, self.translation],
             input_parameterization=self.parameterization,
-            output_parameterization="euler_angles",
+            output_parameterization="se3_exp_map",
             input_convention=self.convention,
-            output_convention="ZYX",
         )

-# %% ../notebooks/api/03_registration.ipynb 14
+# %% ../notebooks/api/03_registration.ipynb 13
 def preprocess(x, eps=1e-4):
     x = (x - x.min()) / (x.max() - x.min() + eps)
     return (x - 0.3080) / 0.1494
@@ -245,7 +232,7 @@ def vector_to_img(pred_img, mask):
         patches.append(patch)
     return filled

-# %% ../notebooks/api/03_registration.ipynb 15
+# %% ../notebooks/api/03_registration.ipynb 14
 class VectorizedNormalizedCrossCorrelation2d(torch.nn.Module):
     def __init__(self, eps=1e-4):
         super().__init__()
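
After this change, `PoseRegressor.forward` always returns a `RigidTransform` converted to the `se3_exp_map` parameterization, regardless of which parameterization the network regresses. A minimal sketch of a forward pass; the constructor arguments and input shape here are assumptions for illustration, since the patch only shows `forward`.

import torch

from diffpose.registration import PoseRegressor

# Hypothetical constructor arguments; only forward() appears in this patch
model = PoseRegressor(model_name="resnet18", parameterization="se3_log_map")
x = torch.randn(2, 1, 256, 256)  # illustrative batch of single-channel DRRs
pose = model(x)  # a RigidTransform in the se3_exp_map parameterization
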
diff --git a/notebooks/api/00_deepfluoro.ipynb b/notebooks/api/00_deepfluoro.ipynb
index 6a54020..99a1289 100644
--- a/notebooks/api/00_deepfluoro.ipynb
+++ b/notebooks/api/00_deepfluoro.ipynb
@@ -694,6 +694,45 @@
     "plt.show()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "dece6bbf-d973-41a4-8b38-078ab64e6f79",
+   "metadata": {},
+   "source": [
+    "## Distribution over camera poses\n",
+    "\n",
+    "We sample the three rotational and three translational parameters of $\\mathfrak{se}(3)$ from independent normal distributions defined with sufficient variance to capture wide perturbations from the isocenter."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "32d6434d-41c3-4bf6-b12f-ea3d2721c753",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "from beartype import beartype\n",
+    "from pytorch3d.transforms import se3_exp_map\n",
+    "\n",
+    "from diffpose.calibration import RigidTransform\n",
+    "\n",
+    "\n",
+    "@beartype\n",
+    "def get_random_offset(batch_size: int, device) -> RigidTransform:\n",
+    "    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))\n",
+    "    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))\n",
+    "    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))\n",
+    "    r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))\n",
+    "    r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
+    "    r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))\n",
+    "    logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)\n",
+    "    T = se3_exp_map(logmap)\n",
+    "    R = T[..., :3, :3].transpose(-1, -2)\n",
+    "    t = T[..., 3, :3]\n",
+    "    return RigidTransform(R, t)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "76efedd3-103c-43dd-944b-fe0c06e2c87b",
@@ -701,7 +740,7 @@
    "metadata": {},
    "source": [
     "## Fiducial markers\n",
     "\n",
-    "The `DeepFluoroDataset` class also contains a method for evaluating the registration error for a predicted pose. Fiducial markers were implanted in the original cadavers. Projecting them with predicted pose parameters can be used to measure their distance from the true fiducials."
+    "The `DeepFluoroDataset` class also contains a method for evaluating the registration error for a predicted pose. Fiducial markers were digitally placed on the preoperative CT. Projecting them with predicted pose parameters can be used to measure their distance from the true fiducials."
@@ -720,7 +759,6 @@
    ],
    "source": [
     "#| eval: false\n",
-    "# from pytorch3d.transforms import euler_angles_to_matrix, matrix_to_euler_angles\n",
     "from diffdrr.utils import convert\n",
     "\n",
     "# Perturb the ground truth rotations by 0.05 degrees and 2 mm\n",
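
A note on conventions for readers of the cells above: `se3_exp_map` expects the translational components in the first three entries of the 6-vector and returns 4 x 4 matrices in PyTorch3D's row-vector layout, which is why the code transposes the rotation block before building a `RigidTransform`. A quick sanity check of that convention, using only PyTorch3D:

import torch
from pytorch3d.transforms import se3_exp_map

# A pure translation: the rotational components (last three entries) are zero
logmap = torch.zeros(1, 6)
logmap[0, :3] = torch.tensor([10.0, 250.0, 5.0])

T = se3_exp_map(logmap)  # (1, 4, 4); the last row holds the translation
R = T[..., :3, :3].transpose(-1, -2)  # transpose into column-vector convention
t = T[..., 3, :3]

assert torch.allclose(R, torch.eye(3).expand(1, 3, 3), atol=1e-6)
assert torch.allclose(t, logmap[:, :3], atol=1e-4)  # exp map is the identity on pure translations
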
diff --git a/notebooks/api/01_ljubljana.ipynb b/notebooks/api/01_ljubljana.ipynb
index cfce57d..292f495 100644
--- a/notebooks/api/01_ljubljana.ipynb
+++ b/notebooks/api/01_ljubljana.ipynb
@@ -106,6 +106,9 @@
     "    def __len__(self):\n",
     "        return 10\n",
     "\n",
+    "    def __iter__(self):\n",
+    "        return iter(self[idx] for idx in range(len(self)))\n",
+    "\n",
     "    def __getitem__(self, idx):\n",
     "        idx += 1\n",
     "        extrinsic = self.f[f\"subject{idx:02d}/proj-{self.view}/extrinsic\"][:]\n",
@@ -122,8 +125,9 @@
     "        if self.preprocess:\n",
     "            img += 1\n",
     "            img = img.max().log() - img.log()\n",
-    "\n",
     "        height, width = img.shape\n",
+    "        img = img.unsqueeze(0).unsqueeze(0)\n",
+    "\n",
     "        focal_len, x0, y0 = parse_intrinsic_matrix(\n",
     "            intrinsic,\n",
     "            height,\n",
@@ -141,6 +145,13 @@
     "        volume = self.f[f\"subject{idx:02d}/volume/pixels\"][:]\n",
     "        spacing = self.f[f\"subject{idx:02d}/volume/spacing\"][:]\n",
     "\n",
+    "        isocenter_rot = torch.tensor([[torch.pi / 2, 0.0, -torch.pi / 2]])\n",
+    "        isocenter_xyz = torch.tensor(volume.shape) * spacing / 2\n",
+    "        isocenter_xyz = isocenter_xyz.unsqueeze(0)\n",
+    "        isocenter_pose = RigidTransform(\n",
+    "            isocenter_rot, isocenter_xyz, \"euler_angles\", \"ZYX\"\n",
+    "        )\n",
+    "\n",
     "        return (\n",
     "            volume,\n",
     "            spacing,\n",
@@ -153,6 +164,7 @@
     "            y0,\n",
     "            img,\n",
     "            pose,\n",
+    "            isocenter_pose,\n",
     "        )"
@@ -293,6 +305,7 @@
     "    y0,\n",
     "    img,\n",
     "    pose,\n",
+    "    isocenter_pose,\n",
     ") = subject[idx]\n",
     "volume[volume < 500] = 0.0\n",
     "if idx == 5:\n",
@@ -449,6 +462,7 @@
     "    y0,\n",
     "    img,\n",
     "    pose,\n",
+    "    isocenter_pose,\n",
     ") = subject[idx]\n",
     "volume[volume < 500] = 0.0\n",
     "if idx == 5:\n",
@@ -481,6 +495,209 @@
     "    plt.show()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "4d1df029-59b8-4a0f-88d7-9de82e1f8419",
+   "metadata": {},
+   "source": [
+    "## Distribution over camera poses\n",
+    "\n",
+    "We sample the three rotational and three translational parameters of $\\mathfrak{se}(3)$ from independent normal distributions defined with sufficient variance to capture wide perturbations from the isocenter."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "af00428c-a6fa-49e3-a2bc-3a822d90aa97",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "from beartype import beartype\n",
+    "from pytorch3d.transforms import se3_exp_map\n",
+    "\n",
+    "from diffpose.calibration import RigidTransform\n",
+    "\n",
+    "\n",
+    "@beartype\n",
+    "def get_random_offset(view, batch_size: int, device) -> RigidTransform:\n",
+    "    if view == \"ap\":\n",
+    "        t1 = torch.distributions.Normal(-10, 20).sample((batch_size,))\n",
+    "        t2 = torch.distributions.Normal(175, 30).sample((batch_size,))\n",
+    "        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))\n",
+    "        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+    "        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+    "        r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))\n",
+    "    elif view == \"lat\":\n",
+    "        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))\n",
+    "        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))\n",
+    "        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))\n",
+    "        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+    "        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+    "        r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))\n",
+    "    else:\n",
+    "        raise ValueError(f\"view must be 'ap' or 'lat', not {view}\")\n",
+    "\n",
+    "    logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)\n",
+    "    T = se3_exp_map(logmap)\n",
+    "    R = T[..., :3, :3].transpose(-1, -2)\n",
+    "    t = T[..., 3, :3]\n",
+    "    return RigidTransform(R, t)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6cabac02-6771-453a-897b-3977144b5ba5",
+   "metadata": {},
+   "source": [
+    "## Fiducial markers\n",
+    "\n",
+    "The `LjubljanaDataset` class also contains a method for evaluating the registration error for a predicted pose. Digital fiducial markers were placed along the centerlines of the vessels. Projecting them with predicted pose parameters can be used to measure their distance from the true fiducials."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5da282b3-4b16-43cd-b81e-3251074e9415",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "from torch.nn.functional import pad\n",
+    "\n",
+    "from diffpose.calibration import perspective_projection\n",
+    "from diffpose.deepfluoro import convert_diffdrr_to_deepfluoro\n",
+    "\n",
+    "\n",
+    "class Evaluator:\n",
+    "    def __init__(self, specimen, idx):\n",
+    "        # Save matrices to device\n",
+    "        self.translate = specimen.translate\n",
+    "        self.flip_xz = specimen.flip_xz\n",
+    "        self.intrinsic = specimen.intrinsic\n",
+    "        self.intrinsic_inv = specimen.intrinsic.inverse()\n",
+    "\n",
+    "        # Get gt fiducial locations\n",
+    "        self.specimen = specimen\n",
+    "        self.fiducials = specimen.fiducials\n",
+    "        gt_pose = specimen[idx][-2]  # pose is the second-to-last item in the returned tuple\n",
+    "        self.true_projected_fiducials = self.project(gt_pose)\n",
+    "\n",
+    "    def project(self, pose):\n",
+    "        extrinsic = convert_diffdrr_to_deepfluoro(self.specimen, pose)\n",
+    "        x = perspective_projection(extrinsic, self.intrinsic, self.fiducials)\n",
+    "        x = -self.specimen.focal_len * torch.einsum(\n",
+    "            \"ij, bnj -> bni\",\n",
+    "            self.intrinsic_inv,\n",
+    "            pad(x, (0, 1), value=1),  # Convert to homogeneous coordinates\n",
+    "        )\n",
+    "        extrinsic = (\n",
+    "            self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)\n",
+    "        )\n",
+    "        return extrinsic.transform_points(x)\n",
+    "\n",
+    "    def __call__(self, pose):\n",
+    "        pred_projected_fiducials = self.project(pose)\n",
+    "        registration_error = (\n",
+    "            (self.true_projected_fiducials - pred_projected_fiducials)\n",
+    "            .norm(dim=-1)\n",
+    "            .mean()\n",
+    "        )\n",
+    "        registration_error *= 0.154  # Pixel spacing is 0.154 mm / pixel isotropic\n",
+    "        return registration_error"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7612861a-3917-4ea3-b368-bb880b6d65e6",
+   "metadata": {},
+   "source": [
+    "## Deep learning transforms\n",
+    "\n",
+    "We transform X-rays and DRRs before inputting them to a deep learning model by\n",
+    "\n",
+    "- Rescaling pixels to [0, 1]\n",
+    "- Resizing the images to a specified size\n",
+    "- Normalizing pixels by the mean and std dev"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "84c120f0-0249-4b78-81d9-b7e24423d602",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|█████████████| 10/10 [00:07<00:00, 1.38it/s]\n",
+      "100%|█████████████| 10/10 [00:03<00:00, 2.87it/s]"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Pixel mean : tensor(0.0774)\n",
+      "Pixel std dev : tensor(0.0569)\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "#| eval: false\n",
+    "#| code-fold: true\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "mean, vars = [], []\n",
+    "for view in [\"ap\", \"lat\"]:\n",
+    "    specimen = LjubljanaDataset(view, filename=\"../../data/ljubljana.h5\")\n",
+    "    for _, _, _, _, _, _, _, _, _, img, _, _ in tqdm(specimen, ncols=50):\n",
+    "        img = (img - img.min()) / (img.max() - img.min())\n",
+    "        mean.append(img.mean())\n",
+    "        vars.append(img.var())\n",
+    "\n",
+    "print(\"Pixel mean :\", sum(mean) / len(mean))\n",
+    "print(\"Pixel std dev :\", (sum(vars) / len(vars)).sqrt())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "824075a4-4d9d-4cdf-b2e0-9dd0d625ead4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "from torchvision.transforms import Compose, Lambda, Normalize, Resize\n",
+    "\n",
+    "\n",
+    "class Transforms:\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        height: int,\n",
+    "        width: int,\n",
+    "        eps: float = 1e-6,\n",
+    "    ):\n",
+    "        \"\"\"Transform X-rays and DRRs before inputting to CNN.\"\"\"\n",
+    "        self.transforms = Compose(\n",
+    "            [\n",
+    "                Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + eps)),\n",
+    "                Resize((height, width), antialias=True),\n",
+    "                Normalize(mean=0.0774, std=0.0569),\n",
+    "            ]\n",
+    "        )\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        return self.transforms(x)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
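
The computed statistics feed straight into `Transforms` (`Normalize(mean=0.0774, std=0.0569)`). A minimal sketch of normalizing an image before it reaches the pose network; the tensor shapes are chosen for illustration.

import torch

from diffpose.ljubljana import Transforms

transforms = Transforms(height=256, width=256)  # illustrative output size
img = torch.rand(1, 1, 1024, 1024)  # stand-in for a (B, C, H, W) X-ray or DRR
img = transforms(img)  # rescaled to [0, 1], resized, then normalized
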
diff --git a/notebooks/api/03_registration.ipynb b/notebooks/api/03_registration.ipynb
index 6286520..8832121 100644
--- a/notebooks/api/03_registration.ipynb
+++ b/notebooks/api/03_registration.ipynb
@@ -68,6 +68,9 @@
    "outputs": [],
    "source": [
     "#| export\n",
+    "from diffpose.calibration import RigidTransform, convert\n",
+    "\n",
+    "\n",
     "class PoseRegressor(torch.nn.Module):\n",
     "    \"\"\"\n",
     "    A PoseRegressor is comprised of a pretrained backbone model that extracts features\n",
@@ -105,7 +108,12 @@
     "        x = self.backbone(x)\n",
     "        rot = self.rot_regression(x)\n",
     "        xyz = self.xyz_regression(x)\n",
-    "        return RigidTransform(rot, xyz, self.parameterization, self.convention)"
+    "        return convert(\n",
+    "            [rot, xyz],\n",
+    "            input_parameterization=self.parameterization,\n",
+    "            output_parameterization=\"se3_exp_map\",\n",
+    "            input_convention=self.convention,\n",
+    "        )"
    ]
   },
   {
@@ -119,7 +127,7 @@
     "N_ANGULAR_COMPONENTS = {\n",
     "    \"axis_angle\": 3,\n",
     "    \"euler_angles\": 3,\n",
-    "    \"se3\": 3,\n",
+    "    \"se3_log_map\": 3,\n",
     "    \"quaternion\": 4,\n",
     "    \"rotation_6d\": 6,\n",
     "    \"rotation_10d\": 10,\n",
@@ -134,36 +142,7 @@
    "source": [
     "### Sampling random camera poses\n",
     "\n",
-    "We sample random camera poses from the tangent space of SE(3), which is linear."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "41eccad5-a135-4cfb-b350-1a4a3df9ab7b",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "#| export\n",
-    "from beartype import beartype\n",
-    "from pytorch3d.transforms import se3_exp_map\n",
-    "\n",
-    "from diffpose.calibration import RigidTransform\n",
-    "\n",
-    "\n",
-    "@beartype\n",
-    "def get_random_offset(batch_size: int, device) -> RigidTransform:\n",
-    "    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))\n",
-    "    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))\n",
-    "    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))\n",
-    "    r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))\n",
-    "    r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
-    "    r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))\n",
-    "    logmap = torch.stack([t1, t2, t3, r1, r2, r3], dim=1).to(device)\n",
-    "    T = se3_exp_map(logmap)\n",
-    "    R = T[..., :3, :3].transpose(-1, -2)\n",
-    "    t = T[..., 3, :3]\n",
-    "    return RigidTransform(R, t)"
+    "We sample random camera poses from the tangent space of SE(3), which is Euclidean."
    ]
   },
   {
@@ -191,7 +170,7 @@
     "from diffdrr.drr import DRR\n",
     "from torchvision.utils import make_grid\n",
     "\n",
-    "from diffpose.deepfluoro import DeepFluoroDataset\n",
+    "from diffpose.deepfluoro import DeepFluoroDataset, get_random_offset\n",
     "\n",
     "specimen = DeepFluoroDataset(1)\n",
     "device = torch.device(\"cuda\")\n",
@@ -269,8 +248,9 @@
     "from diffdrr.detector import make_xrays\n",
     "from diffdrr.drr import DRR\n",
     "from diffdrr.siddon import siddon_raycast\n",
+    "from pytorch3d.transforms import se3_exp_map\n",
     "\n",
-    "from diffpose.calibration import RigidTransform, convert\n",
+    "from diffpose.calibration import RigidTransform\n",
     "\n",
     "\n",
     "class SparseRegistration(torch.nn.Module):\n",
@@ -396,9 +376,8 @@
     "        return convert(\n",
     "            [self.rotation, self.translation],\n",
     "            input_parameterization=self.parameterization,\n",
-    "            output_parameterization=\"euler_angles\",\n",
+    "            output_parameterization=\"se3_exp_map\",\n",
    "            input_convention=self.convention,\n",
-    "            output_convention=\"ZYX\",\n",
     "        )"
    ]
   },