diff --git a/experiments/deepfluoro/train.py b/experiments/deepfluoro/train.py index 5ff74e5..9d4239e 100644 --- a/experiments/deepfluoro/train.py +++ b/experiments/deepfluoro/train.py @@ -63,12 +63,12 @@ def train( contrast = contrast_distribution.sample().item() offset = get_random_offset(batch_size, device) pose = isocenter_pose.compose(offset) - img = drr(None, None, None, pose=pose, bone_attenuation_multiplier=contrast) + img = drr(pose, bone_attenuation_multiplier=contrast) img = transforms(img) pred_offset = model(img) pred_pose = isocenter_pose.compose(pred_offset) - pred_img = drr(None, None, None, pose=pred_pose) + pred_img = drr(pred_pose) pred_img = transforms(pred_img) ncc = metric(pred_img, img)