Commit af4138c: Add files via upload
Authored Apr 27, 2022
1 parent a979a2b

4 files changed, +278 -0 lines changed
LICENSE

+21

MIT License

Copyright (c) 2021 Jonathan Frawley

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.markdown

+87

# MedNeRF: Medical Neural Radiance Fields for Reconstructing 3D-aware CT-Projections from a Single X-ray

Repository copied from: https://github.com/abrilcf/mednerf

## Get the Data

You can find all DRRs at the following [link](https://drive.google.com/file/d/1_EJX3LnRMG5uXEhZ63C2eYoY4hjwmipP/view?usp=sharing). Here is a description of the folders:

An *instance* comprises 72 DRRs (one every 5 degrees) from a 360-degree rotation of a real CT scan.

- `chest_xrays`: all images of the 20 chest instances (.png, 128x128 resolution).
- `knee_xrays`: all images of the 5 knee instances (.png, 128x128 resolution).
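As a quick sanity check, here is a minimal sketch (not repo code) for loading one instance's DRRs in view order; the folder layout, file naming, and helper name are assumptions, while the 5-degree step and 128x128 size come from the description above.

```python
from pathlib import Path

import numpy as np
from PIL import Image

def load_instance_drrs(instance_dir):
    """Load one instance's 72 DRRs as a (72, 128, 128) uint8 array.

    Assumes the PNGs sort in rotation order, so view i corresponds to a
    rotation of 5 * i degrees.
    """
    paths = sorted(Path(instance_dir).glob("*.png"))
    return np.stack([np.asarray(Image.open(p).convert("L")) for p in paths])

# e.g. views = load_instance_drrs("chest_xrays/instance_01")  # hypothetical path
```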
## Train a model

Refer to the `graf-main` folder and execute the following, replacing `CONFIG.yaml` with `knee.yaml` or `chest.yaml`:
```
python train.py configs/CONFIG.yaml
```
## Reconstruction given an X-ray

After training a model, you can test its capacity to reconstruct 3D-aware CT projections given a single X-ray.

Install Ray Tune for hyperparameter tuning with:
```
pip install "ray[tune]"
```

To execute the reconstruction, refer to the `graf-main` folder and execute:
```
python finetune_xray.py configs/config-file.yaml --xray_img_path path_to_xray --save_dir path_to_save_dir --model path_to_trained_model configs/knee.yaml
```
(To use Ray for finetuning, change the runtime environment to match your configuration.) In my case:
```
ray.init(runtime_env={"conda": "/home/anya/anaconda3/envs/graf",
                      "py_modules": ["/home/anya/Programs/mednerf/graf-main/submodules",
                                     "/home/anya/Programs/mednerf/graf-main/graf",
                                     "/home/anya/Programs/mednerf/graf-main/submodules/GAN_stability/",
                                     "/home/anya/Programs/mednerf/graf-main/configs"]})
```
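For context, a generic Ray Tune search loop looks like the sketch below. This is the library's classic `tune.run` pattern, not the specific search space `finetune_xray.py` uses; the trainable, metric name, and ranges are assumptions.

```python
from ray import tune

def trainable(config):
    # Placeholder objective: report a fake "loss" derived from the sampled
    # hyperparameters; a real trainable would run finetuning steps here.
    tune.report(loss=config["lr"] * config["batch_size"])

analysis = tune.run(
    trainable,
    config={
        "lr": tune.loguniform(1e-5, 1e-2),  # sample learning rate on a log scale
        "batch_size": tune.choice([1, 2, 4]),
    },
    num_samples=8,  # number of trials to launch
)
print(analysis.get_best_config(metric="loss", mode="min"))
```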
## PixelNeRF instructions

To use the pixelNeRF model, use the following configuration files:

```
pixel-nerf/conf/exp/ct_single.conf
pixel-nerf/conf/exp/drr.conf
```
## Generate DRR images from CT scans

To generate x-ray images (.png) at different angles from CT scans, use the script `generate_drr.py` under the `data/` folder. To run it you need to install a [Plastimatch build](http://plastimatch.org/); version 1.9.3 was used.

An updated version of the script has been added (`generate_drr_multiple_dirs.py`). Use it to automatically generate the set of DRRs when you have a global folder containing multiple folders of CT scans; the files do not need to sit in an immediate subfolder (see the directory-walk sketch below). You only need to assign the path of the global folder and the global folder in which to save the set of DRRs.
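The walk that implies can be sketched as follows; the helper name and paths are assumptions for illustration, not code from the script.

```python
from pathlib import Path

def find_ct_dirs(global_root):
    """Collect every folder under global_root that holds .dcm files,
    at any depth, so each can be rendered into its own DRR set."""
    return sorted({p.parent for p in Path(global_root).rglob("*.dcm")})

# e.g. for ct_dir in find_ct_dirs("/data/ct_scans"): run the DRR generation
# for ct_dir, saving into a mirrored folder under the global save directory.
```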
### Overview of input arguments

Replace the following variables within the file (a filled-in example follows this list):

- `input_path`: path to the .dcm files or .mha file of the CT.
- `save_root_path`: path where you want the x-ray images to be saved.
- `plasti_path`: path of the Plastimatch build.
- `multiple_view_mode <True | False>`: generate single x-rays from lateral or frontal views, or multiple images from a circular rotation around the z axis.
  If False, you need to specify the view with the argument `frontal_dir <True | False>` (False for lateral view).
  If True, you need to specify `num_xrays` for the number of equally spaced views and `angles` for the difference between neighboring angles (in degrees).
- `preprocessing <True | False>`: set this to True if the files are .dcm, for Hounsfield Units conversion. Set it to False if the given file is raw (.mha), in which case you need to provide its path under the variable `raw_input_file`.
- `detector_size`: detector size as a pair of values in mm.
- `bg_color`: choose either a black or white background.
- `resolution`: size of the output x-ray images.
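For concreteness, a hypothetical settings block using these variable names might look like this; the values are illustrative only, not defaults from `generate_drr.py`.

```python
# Illustrative values for the variables listed above (assumptions).
input_path = "/data/ct/case_001"        # folder of .dcm files for this CT
save_root_path = "/data/drr/case_001"   # where the output x-rays go
plasti_path = "/opt/plastimatch/build"  # Plastimatch build directory
multiple_view_mode = True               # circular rotation around the z axis
num_xrays = 72                          # number of equally spaced views
angles = 5                              # degrees between neighboring views
preprocessing = True                    # .dcm input -> Hounsfield Units conversion
detector_size = (300, 300)              # detector size in mm
bg_color = "black"
resolution = 128                        # output image size in pixels
```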
## Acknowledgments

This codebase is heavily based on the [GRAF](https://github.com/autonomousvision/graf) codebase. We also use code from [pixel-nerf](https://github.com/sxyu/pixel-nerf) for baseline experiments.

We thank all authors for the wonderful code!

## Citation

If you use our model for your research, please cite the following work.

```bibtex
@misc{coronafigueroa2022mednerf,
      title={MedNeRF: Medical Neural Radiance Fields for Reconstructing 3D-aware CT-Projections from a Single X-ray},
      author={Abril Corona-Figueroa and Jonathan Frawley and Sam Bond-Taylor and Sarath Bethapudi and Hubert P. H. Shum and Chris G. Willcocks},
      year={2022},
      eprint={2202.01020},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}
```

gan_training.py

+164

```python
import torch
import numpy as np
import os
from tqdm import tqdm

from submodules.GAN_stability.gan_training.train import toggle_grad, Trainer as TrainerBase
from submodules.GAN_stability.gan_training.eval import Evaluator as EvaluatorBase
from submodules.GAN_stability.gan_training.metrics import FIDEvaluator, KIDEvaluator

from .utils import save_video, color_depth_map


class Trainer(TrainerBase):
    def __init__(self, *args, use_amp=False, **kwargs):
        super(Trainer, self).__init__(*args, **kwargs)
        self.use_amp = use_amp
        if self.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()

    def generator_trainstep(self, y, z):
        if not self.use_amp:
            return super(Trainer, self).generator_trainstep(y, z)
        assert (y.size(0) == z.size(0))
        toggle_grad(self.generator, True)
        toggle_grad(self.discriminator, False)
        self.generator.train()
        self.discriminator.train()
        self.g_optimizer.zero_grad()

        # forward pass in mixed precision; backward and step go through the scaler
        with torch.cuda.amp.autocast():
            x_fake = self.generator(z, y)
            d_fake = self.discriminator(x_fake, y)
            gloss = self.compute_loss(d_fake, 1)
        self.scaler.scale(gloss).backward()

        self.scaler.step(self.g_optimizer)
        self.scaler.update()

        return gloss.item()

    def discriminator_trainstep(self, x_real, y, z, data_aug):
        # spectral norm raises an error when using amp, so the discriminator
        # step always runs in full precision
        return super(Trainer, self).discriminator_trainstep(x_real, y, z, data_aug)


class Evaluator(EvaluatorBase):
    def __init__(self, eval_fid_kid, *args, **kwargs):
        super(Evaluator, self).__init__(*args, **kwargs)
        if eval_fid_kid:
            self.inception_eval = FIDEvaluator(
                device=self.device,
                batch_size=self.batch_size,
                resize=True,
                n_samples=20000,
                n_samples_fake=1000,
            )

    def get_rays(self, pose):
        return self.generator.val_ray_sampler(self.generator.H, self.generator.W,
                                              self.generator.focal, pose)[0]

    def create_samples(self, z, poses=None):
        self.generator.eval()
        N_samples = len(z)
        device = self.generator.device
        if self.batch_size > 1:
            z = z.to(device).split(self.batch_size)
        if poses is None:
            rays = [None] * len(z)
        else:
            rays = torch.stack([self.get_rays(poses[i].to(device)) for i in range(N_samples)])
            rays = rays.split(self.batch_size)

        rgb, disp, acc = [], [], []
        with torch.no_grad():
            if self.batch_size > 1:
                for z_i, rays_i in tqdm(zip(z, rays), total=len(z), desc='Create samples...'):
                    bs = len(z_i)
                    if rays_i is not None:
                        rays_i = rays_i.permute(1, 0, 2, 3).flatten(1, 2)  # Bx2x(HxW)xC -> 2x(BxHxW)x3
                    rgb_i, disp_i, acc_i, _ = self.generator(z_i, rays=rays_i)

                    reshape = lambda x: x.view(bs, self.generator.H, self.generator.W, x.shape[1]).permute(0, 3, 1, 2)  # (NxHxW)xC -> NxCxHxW
                    rgb.append(reshape(rgb_i).cpu())
                    disp.append(reshape(disp_i).cpu())
                    acc.append(reshape(acc_i).cpu())
            else:
                for rays_i in rays:
                    bs = len(z)
                    if rays_i is not None:
                        rays_i = rays_i.permute(1, 0, 2, 3).flatten(1, 2)  # Bx2x(HxW)xC -> 2x(BxHxW)x3
                    rgb_i, disp_i, acc_i, _ = self.generator(z, rays=rays_i)

                    reshape = lambda x: x.view(bs, self.generator.H, self.generator.W, x.shape[1]).permute(0, 3, 1, 2)  # (NxHxW)xC -> NxCxHxW
                    rgb.append(reshape(rgb_i).cpu())
                    disp.append(reshape(disp_i).cpu())
                    acc.append(reshape(acc_i).cpu())

        rgb = torch.cat(rgb)
        disp = torch.cat(disp)
        acc = torch.cat(acc)

        depth = self.disp_to_cdepth(disp)

        return rgb, depth, acc

    def make_video(self, basename, z, poses, as_gif=True):
        """Generate images and save them as video.

        z (N_samples, zdim): latent codes
        poses (N_frames, 3 x 4): camera poses for all frames of video
        """
        N_samples, N_frames = len(z), len(poses)

        # reshape inputs
        z = z.unsqueeze(1).expand(-1, N_frames, -1).flatten(0, 1)  # (N_samples x N_frames) x z_dim
        poses = poses.unsqueeze(0) \
            .expand(N_samples, -1, -1, -1).flatten(0, 1)  # (N_samples x N_frames) x 3 x 4

        rgbs, depths, accs = self.create_samples(z, poses=poses)

        reshape = lambda x: x.view(N_samples, N_frames, *x.shape[1:])
        rgbs = reshape(rgbs)
        depths = reshape(depths)
        print('Done, saving', rgbs.shape)

        fps = min(int(N_frames / 2.), 25)  # aim for a video of at least 2 seconds
        for i in range(N_samples):
            save_video(rgbs[i], basename + '{:04d}_rgb.mp4'.format(i), as_gif=as_gif, fps=fps)
            save_video(depths[i], basename + '{:04d}_depth.mp4'.format(i), as_gif=as_gif, fps=fps)

    def disp_to_cdepth(self, disps):
        """Convert disparity to color depth values"""
        if (disps == 2e10).all():  # no values predicted
            return torch.ones_like(disps)

        near, far = self.generator.render_kwargs_test['near'], self.generator.render_kwargs_test['far']

        disps = disps / 2 + 0.5  # [-1, 1] -> [0, 1]

        depth = 1. / torch.max(1e-10 * torch.ones_like(disps), disps)  # disparity -> depth
        depth[disps == 1e10] = far  # set undefined values to far plane

        # scale between near, far plane for better visualization
        depth = (depth - near) / (far - near)

        depth = np.stack([color_depth_map(d) for d in depth[:, 0].detach().cpu().numpy()])  # convert to color
        depth = (torch.from_numpy(depth).permute(0, 3, 1, 2) / 255.) * 2 - 1  # [0, 255] -> [-1, 1]

        return depth

    def compute_fid_kid(self, sample_generator=None):
        if sample_generator is None:
            def sample():
                while True:
                    z = self.zdist.sample((self.batch_size,))
                    rgb, _, _ = self.create_samples(z)
                    # convert to uint8 and back to get correct binning
                    rgb = (rgb / 2 + 0.5).mul_(255).clamp_(0, 255).to(torch.uint8).to(torch.float) / 255. * 2 - 1
                    yield rgb.cpu()

            sample_generator = sample()

        fid, (kids, vars) = self.inception_eval.get_fid_kid(sample_generator)
        kid = np.mean(kids)
        return fid, kid
```

requirements.txt

+6

```
matplotlib
numpy
scikit-image
pydicom
gdcm
pylibjpeg-libjpeg
```
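pydicom, gdcm, and pylibjpeg-libjpeg are presumably there for reading CT DICOMs. A minimal pydicom sketch of the Hounsfield Units conversion that `preprocessing=True` in `generate_drr.py` refers to follows; the file path is illustrative, and the rescale tags are standard CT DICOM attributes.

```python
import numpy as np
import pydicom

ds = pydicom.dcmread("slice_0001.dcm")  # hypothetical CT slice
# Raw stored values -> Hounsfield Units via the DICOM rescale tags
hu = ds.pixel_array.astype(np.float32) * float(ds.RescaleSlope) + float(ds.RescaleIntercept)
print(hu.min(), hu.max())
```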

0 commit comments