Skip to content

Commit

Permalink
Merge pull request #5 from eigenvivek/pytorch3d
Browse files Browse the repository at this point in the history
Remove dependency on `pytorch3d`
  • Loading branch information
eigenvivek authored Jan 24, 2024
2 parents e272ba0 + 66e5bfd commit 062aac4
Show file tree
Hide file tree
Showing 13 changed files with 129 additions and 79 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
checkpoints/
evaluations/
data/
logs/
runs/
Expand Down
19 changes: 9 additions & 10 deletions diffpose/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
from typing import Optional

from beartype import beartype
from diffdrr.utils import Transform3d
from diffdrr.utils import convert as convert_so3
from diffdrr.utils import se3_exp_map, se3_log_map
from jaxtyping import Float, jaxtyped
from pytorch3d.transforms import Transform3d
from pytorchse3.se3 import se3_exp_map, se3_log_map

# %% ../notebooks/api/02_calibration.ipynb 7
@beartype
class RigidTransform(Transform3d):
"""Wrapper of pytorch3d.transforms.Transform3d with extra functionalities."""

@jaxtyped
@jaxtyped(typechecker=beartype)
def __init__(
self,
R: Float[torch.Tensor, "..."],
Expand Down Expand Up @@ -74,7 +74,7 @@ def clone(self):
return RigidTransform(R, t, device=self.device, dtype=self.dtype)

def get_se3_log(self):
return se3_log_map(self.get_matrix().transpose(-1, -2))
return se3_log_map(self.get_matrix())

# %% ../notebooks/api/02_calibration.ipynb 8
def convert(
Expand All @@ -88,8 +88,8 @@ def convert(

# Convert any input parameterization to a RigidTransform
if input_parameterization == "se3_log_map":
transform = torch.concat([*transform], axis=-1)
matrix = se3_exp_map(transform)
transform = torch.concat([transform[1], transform[0]], axis=-1)
matrix = se3_exp_map(transform).transpose(-1, -2)
transform = RigidTransform(
R=matrix[..., :3, :3],
t=matrix[..., :3, 3],
Expand All @@ -111,8 +111,8 @@ def convert(
return transform
elif output_parameterization == "se3_log_map":
se3_log = transform.get_se3_log()
log_R_vee = se3_log[..., :3]
log_t_vee = se3_log[..., 3:]
log_t_vee = se3_log[..., :3]
log_R_vee = se3_log[..., 3:]
return log_R_vee, log_t_vee
else:
return (
Expand All @@ -121,8 +121,7 @@ def convert(
)

# %% ../notebooks/api/02_calibration.ipynb 10
@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def perspective_projection(
extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera)
intrinsic: Float[torch.Tensor, "3 3"], # Intrinsic camera matrix (camera to image)
Expand Down
2 changes: 0 additions & 2 deletions diffpose/deepfluoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,6 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):
return img

# %% ../notebooks/api/00_deepfluoro.ipynb 26
from beartype import beartype

from .calibration import RigidTransform, convert


Expand Down
2 changes: 0 additions & 2 deletions diffpose/ljubljana.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ def __getitem__(self, idx):
)

# %% ../notebooks/api/01_ljubljana.ipynb 7
from beartype import beartype

from .calibration import RigidTransform, convert


Expand Down
24 changes: 9 additions & 15 deletions diffpose/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,25 @@ def __init__(self, patch_size=None):
# %% ../notebooks/api/04_metrics.ipynb 9
import torch
from beartype import beartype
from diffdrr.utils import convert
from jaxtyping import Float, jaxtyped
from pytorch3d.transforms import (
so3_rotation_angle,
from diffdrr.utils import (
convert,
so3_log_map,
so3_relative_angle,
so3_rotation_angle,
standardize_quaternion,
)
from jaxtyping import Float, jaxtyped

from .calibration import RigidTransform

# %% ../notebooks/api/04_metrics.ipynb 10
from pytorchse3.so3 import so3_log_map


class GeodesicSO3(torch.nn.Module):
"""Calculate the angular distance between two rotations in SO(3)."""

def __init__(self):
super().__init__()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(
self,
pose_1: RigidTransform,
Expand All @@ -102,8 +99,7 @@ class GeodesicTranslation(torch.nn.Module):
def __init__(self):
super().__init__()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(
self,
pose_1: RigidTransform,
Expand All @@ -120,8 +116,7 @@ class GeodesicSE3(torch.nn.Module):
def __init__(self):
super().__init__()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(
self,
pose_1: RigidTransform,
Expand All @@ -146,8 +141,7 @@ def __init__(
self.rotation = GeodesicSO3()
self.translation = GeodesicTranslation()

@beartype
@jaxtyped
@jaxtyped(typechecker=beartype)
def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):
angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)
translation_geodesic = self.translation(pose_1, pose_2)
Expand Down
5 changes: 1 addition & 4 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
name: preop
name: diffpose
channels:
- conda-forge
- pytorch
- pytorch3d
- nvidia
dependencies:
- pip
- pytorch
- torchvision
- pytorch3d
- pip:
- diffdrr>=0.3.8
- h5py
- scikit-image
- seaborn
- pytorch-transformers
- pytorchse3
- timm
- torchmetrics
- tqdm
Expand Down
82 changes: 82 additions & 0 deletions experiments/deepfluoro/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from pathlib import Path

import pandas as pd
import submitit
import torch
from tqdm import tqdm

from diffpose.deepfluoro import DeepFluoroDataset, Evaluator, Transforms
from diffpose.registration import PoseRegressor


def load_specimen(id_number, device):
specimen = DeepFluoroDataset(id_number)
isocenter_pose = specimen.isocenter_pose.to(device)
return specimen, isocenter_pose


def load_model(model_name, device):
ckpt = torch.load(model_name)
model = PoseRegressor(
ckpt["model_name"],
ckpt["parameterization"],
ckpt["convention"],
norm_layer=ckpt["norm_layer"],
).to(device)
model.load_state_dict(ckpt["model_state_dict"])
transforms = Transforms(ckpt["height"])
return model, transforms


def evaluate(specimen, isocenter_pose, model, transforms, device):
error = []
model.eval()
for idx in tqdm(range(len(specimen)), ncols=100):
target_registration_error = Evaluator(specimen, idx)
img, _ = specimen[idx]
img = img.to(device)
img = transforms(img)
with torch.no_grad():
offset = model(img)
pred_pose = isocenter_pose.compose(offset)
mtre = target_registration_error(pred_pose.cpu()).item()
error.append(mtre)
return error


def main(id_number):
device = torch.device("cuda")
specimen, isocenter_pose = load_specimen(id_number, device)
models = sorted(Path("checkpoints/").glob(f"specimen_{id_number:02d}_epoch*.ckpt"))

errors = []
for model_name in models:
model, transforms = load_model(model_name, device)
error = evaluate(specimen, isocenter_pose, model, transforms, device)
errors.append([model_name.stem] + error)

df = pd.DataFrame(errors)
df.to_csv(f"evaluations/subject{id_number}.csv", index=False)


if __name__ == "__main__":
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Path("evaluations").mkdir(exist_ok=True)
id_numbers = [1, 2, 3, 4, 5, 6]

executor = submitit.AutoExecutor(folder="logs")
executor.update_parameters(
name="eval",
gpus_per_node=1,
mem_gb=10.0,
slurm_array_parallelism=len(id_numbers),
slurm_exclude="curcum",
slurm_partition="2080ti",
timeout_min=10_000,
)
jobs = executor.map_array(main, id_numbers)
4 changes: 2 additions & 2 deletions experiments/deepfluoro/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def train(
best_loss = torch.inf

model.train()
for epoch in range(n_epochs):
for epoch in range(n_epochs + 1):
losses = []
for _ in (itr := tqdm(range(n_batches_per_epoch), leave=False)):
contrast = contrast_distribution.sample().item()
Expand Down Expand Up @@ -144,7 +144,7 @@ def train(
f"checkpoints/specimen_{id_number:02d}_best.ckpt",
)

if epoch % 25 == 0 and epoch != 0:
if epoch % 50 == 0:
torch.save(
{
"model_state_dict": model.state_dict(),
Expand Down
2 changes: 0 additions & 2 deletions notebooks/api/00_deepfluoro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,6 @@
"outputs": [],
"source": [
"#| export\n",
"from beartype import beartype\n",
"\n",
"from diffpose.calibration import RigidTransform, convert\n",
"\n",
"\n",
Expand Down
2 changes: 0 additions & 2 deletions notebooks/api/01_ljubljana.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,6 @@
"outputs": [],
"source": [
"#| export\n",
"from beartype import beartype\n",
"\n",
"from diffpose.calibration import RigidTransform, convert\n",
"\n",
"\n",
Expand Down
21 changes: 10 additions & 11 deletions notebooks/api/02_calibration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@
"from typing import Optional\n",
"\n",
"from beartype import beartype\n",
"from diffdrr.utils import Transform3d\n",
"from diffdrr.utils import convert as convert_so3\n",
"from jaxtyping import Float, jaxtyped\n",
"from pytorch3d.transforms import Transform3d\n",
"from pytorchse3.se3 import se3_exp_map, se3_log_map"
"from diffdrr.utils import se3_exp_map, se3_log_map\n",
"from jaxtyping import Float, jaxtyped"
]
},
{
Expand All @@ -115,7 +115,7 @@
"class RigidTransform(Transform3d):\n",
" \"\"\"Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.\"\"\"\n",
"\n",
" @jaxtyped\n",
" @jaxtyped(typechecker=beartype)\n",
" def __init__(\n",
" self,\n",
" R: Float[torch.Tensor, \"...\"],\n",
Expand Down Expand Up @@ -169,7 +169,7 @@
" return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n",
"\n",
" def get_se3_log(self):\n",
" return se3_log_map(self.get_matrix().transpose(-1, -2))"
" return se3_log_map(self.get_matrix())"
]
},
{
Expand All @@ -191,8 +191,8 @@
"\n",
" # Convert any input parameterization to a RigidTransform\n",
" if input_parameterization == \"se3_log_map\":\n",
" transform = torch.concat([*transform], axis=-1)\n",
" matrix = se3_exp_map(transform)\n",
" transform = torch.concat([transform[1], transform[0]], axis=-1)\n",
" matrix = se3_exp_map(transform).transpose(-1, -2)\n",
" transform = RigidTransform(\n",
" R=matrix[..., :3, :3],\n",
" t=matrix[..., :3, 3],\n",
Expand All @@ -214,8 +214,8 @@
" return transform\n",
" elif output_parameterization == \"se3_log_map\":\n",
" se3_log = transform.get_se3_log()\n",
" log_R_vee = se3_log[..., :3]\n",
" log_t_vee = se3_log[..., 3:]\n",
" log_t_vee = se3_log[..., :3]\n",
" log_R_vee = se3_log[..., 3:]\n",
" return log_R_vee, log_t_vee\n",
" else:\n",
" return (\n",
Expand Down Expand Up @@ -243,8 +243,7 @@
"outputs": [],
"source": [
"#| export\n",
"@beartype\n",
"@jaxtyped\n",
"@jaxtyped(typechecker=beartype)\n",
"def perspective_projection(\n",
" extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera)\n",
" intrinsic: Float[torch.Tensor, \"3 3\"], # Intrinsic camera matrix (camera to image)\n",
Expand Down
Loading

0 comments on commit 062aac4

Please sign in to comment.