
Commit e386513

Merge pull request #37 from HiLab-git/dev
Dev
2 parents: 78c1460 + ea8d405


57 files changed: +936 -333 lines

README.md

+5-5
@@ -15,15 +15,15 @@ BibTeX entry:
   author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang},
   title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}},
   year = {2023},
-  url = {http://arxiv.org/abs/2208.09350},
+  url = {https://doi.org/10.1016/j.cmpb.2023.107398},
   journal = {Computer Methods and Programs in Biomedicine},
-  volume = {February},
+  volume = {231},
   pages = {107398},
 }

 # Features
 PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
-* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
+* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning.
 * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC.
 * Easy-to-use I/O interface to read and write different 2D and 3D images.
 * Various data pre-processing/transformation methods before sending a tensor into a network.
@@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC:
 ```bash
 pip install PYMIC
 ```
-To install a specific version of PYMIC such as 0.3.1, run:
+To install a specific version of PYMIC such as 0.4.0, run:

 ```bash
-pip install PYMIC==0.3.1
+pip install PYMIC==0.4.0
 ```
 Alternatively, you can download the source code for the latest version. Run the following command to compile and install:

pymic/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+__version__ = "0.4.0"
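Since `pymic/__init__.py` now exposes `__version__`, a quick post-upgrade sanity check is possible (a minimal sketch; assumes the package installs under the import name `pymic`, as the paths in this diff suggest):

```python
# Minimal post-install check; "0.4.0" is the version string set in this commit.
import pymic
print(pymic.__version__)  # expected: 0.4.0
```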

pymic/io/__init__.py

+4
@@ -0,0 +1,4 @@
+from __future__ import absolute_import
+from pymic.io.image_read_write import *
+from pymic.io.nifty_dataset import *
+from pymic.io.h5_dataset import *

pymic/layer/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/loss/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/loss/cls/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/loss/seg/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/net/cls/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/net/net2d/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/net/net2d/unet2d.py

+6-5
@@ -187,7 +187,7 @@ class UNet2D(nn.Module):
     :param class_num: (int) The class number for segmentation task.
     :param bilinear: (bool) Using bilinear for up-sampling or not.
         If False, deconvolution will be used for up-sampling.
-    :param deep_supervise: (bool) Using deep supervision for training or not.
+    :param multiscale_pred: (bool) Get multiscale prediction.
     """
     def __init__(self, params):
         super(UNet2D, self).__init__()
@@ -197,7 +197,7 @@ def __init__(self, params):
         self.dropout  = self.params['dropout']
         self.n_class  = self.params['class_num']
         self.bilinear = self.params['bilinear']
-        self.deep_sup = self.params['deep_supervise']
+        self.mul_pred = self.params['multiscale_pred']

         assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

@@ -213,7 +213,7 @@ def __init__(self, params):
         self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)

         self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
-        if(self.deep_sup):
+        if(self.mul_pred):
             self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1)
             self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1)
             self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1)
@@ -239,7 +239,7 @@ def forward(self, x):
         x_d1 = self.up3(x_d2, x1)
         x_d0 = self.up4(x_d1, x0)
         output = self.out_conv(x_d0)
-        if(self.deep_sup):
+        if(self.mul_pred):
             output1 = self.out_conv1(x_d1)
             output2 = self.out_conv2(x_d2)
             output3 = self.out_conv3(x_d3)
@@ -261,7 +261,8 @@ def forward(self, x):
               'feature_chns':[2, 8, 32, 48, 64],
               'dropout': [0, 0, 0.3, 0.4, 0.5],
               'class_num': 2,
-              'bilinear': True}
+              'bilinear': True,
+              'multiscale_pred': False}
     Net = UNet2D(params)
     Net = Net.double()
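The rename from `deep_supervise` to `multiscale_pred` also shows up in the config dict a caller passes in. A hedged usage sketch follows (not part of the diff; the `in_chns` key comes from the surrounding file rather than these hunks, the input shape is illustrative, and whether the extra heads are resized to full resolution is decided outside the lines shown):

```python
# Sketch: instantiating UNet2D with the flag renamed in this commit.
import torch
from pymic.net.net2d.unet2d import UNet2D

params = {'in_chns'        : 3,
          'feature_chns'   : [2, 8, 32, 48, 64],
          'dropout'        : [0, 0, 0.3, 0.4, 0.5],
          'class_num'      : 2,
          'bilinear'       : True,
          'multiscale_pred': True}   # was 'deep_supervise' before this commit
net = UNet2D(params)
x   = torch.rand(2, 3, 96, 96)       # N, C, H, W (shape chosen for illustration)
out = net(x)
# With multiscale_pred=True, forward returns a list: the main output plus
# predictions taken from three intermediate decoder stages.
print([o.shape for o in out])
```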

pymic/net/net3d/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/net/net3d/unet3d.py

+22-14
@@ -131,6 +131,7 @@ class Decoder(nn.Module):
     :param class_num: (int) The class number for segmentation task.
     :param trilinear: (bool) Using bilinear for up-sampling or not.
         If False, deconvolution will be used for up-sampling.
+    :param multiscale_pred: (bool) Get multi-scale prediction.
     """
     def __init__(self, params):
         super(Decoder, self).__init__()
@@ -139,16 +140,21 @@ def __init__(self, params):
         self.ft_chns = self.params['feature_chns']
         self.dropout = self.params['dropout']
         self.n_class = self.params['class_num']
-        self.trilinear = self.params['trilinear']
+        self.trilinear = self.params.get('trilinear', True)
+        self.mul_pred  = self.params.get('multiscale_pred', False)

         assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

         if(len(self.ft_chns) == 5):
-            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
-            self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
-            self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
-            self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
+            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear)
+            self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear)
+            self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear)
+            self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear)
         self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
+        if(self.mul_pred):
+            self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1)
+            self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1)
+            self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1)

     def forward(self, x):
         if(len(self.ft_chns) == 5):
@@ -163,6 +169,11 @@ def forward(self, x):
         x_d1 = self.up3(x_d2, x1)
         x_d0 = self.up4(x_d1, x0)
         output = self.out_conv(x_d0)
+        if(self.mul_pred):
+            output1 = self.out_conv1(x_d1)
+            output2 = self.out_conv2(x_d2)
+            output3 = self.out_conv3(x_d3)
+            output = [output, output1, output2, output3]
         return output

 class UNet3D(nn.Module):
@@ -187,7 +198,7 @@ class UNet3D(nn.Module):
     :param class_num: (int) The class number for segmentation task.
     :param trilinear: (bool) Using trilinear for up-sampling or not.
         If False, deconvolution will be used for up-sampling.
-    :param deep_supervise: (bool) Using deep supervision for training or not.
+    :param multiscale_pred: (bool) Get multi-scale prediction.
     """
     def __init__(self, params):
         super(UNet3D, self).__init__()
@@ -197,7 +208,7 @@ def __init__(self, params):
         self.dropout = self.params['dropout']
         self.n_class = self.params['class_num']
         self.trilinear = self.params['trilinear']
-        self.deep_sup = self.params['deep_supervise']
+        self.mul_pred = self.params['multiscale_pred']
         assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)

         self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
@@ -216,7 +227,7 @@ def __init__(self, params):
             dropout_p = self.dropout[0], trilinear=self.trilinear)

         self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
-        if(self.deep_sup):
+        if(self.mul_pred):
             self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1)
             self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1)
             self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1)
@@ -235,14 +246,10 @@ def forward(self, x):
         x_d1 = self.up3(x_d2, x1)
         x_d0 = self.up4(x_d1, x0)
         output = self.out_conv(x_d0)
-        if(self.deep_sup):
-            out_shape = list(output.shape)[2:]
+        if(self.mul_pred):
             output1 = self.out_conv1(x_d1)
-            output1 = interpolate(output1, out_shape, mode = 'trilinear')
             output2 = self.out_conv2(x_d2)
-            output2 = interpolate(output2, out_shape, mode = 'trilinear')
             output3 = self.out_conv3(x_d3)
-            output3 = interpolate(output3, out_shape, mode = 'trilinear')
             output = [output, output1, output2, output3]
         return output

@@ -251,7 +258,8 @@ def forward(self, x):
               'class_num': 2,
               'feature_chns':[2, 8, 32, 64],
               'dropout' : [0, 0, 0, 0.5],
-              'trilinear': True}
+              'trilinear': True,
+              'multiscale_pred': False}
     Net = UNet3D(params)
     Net = Net.double()
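Two behavioral notes on this file: `Decoder` now reads `trilinear` and `multiscale_pred` via `params.get(...)` with defaults, so older configs that omit those keys keep working, and `UNet3D` no longer interpolates the side outputs to full resolution before returning them. A training loop that consumes the output list therefore has to reconcile scales itself. Below is a hedged sketch of one way to do that; this is not PyMIC's own loss code, and the scale weights and resizing strategy are assumptions for illustration:

```python
# Sketch: a deep-supervision style loss over the multi-scale output list,
# resizing each side output back to label resolution (as the removed
# deep_supervise branch used to do inside the network).
import torch
import torch.nn.functional as F

def multiscale_ce_loss(preds, labels, weights=(1.0, 0.5, 0.25, 0.125)):
    """preds: logits [N, C, D, H, W], or a list of them at decreasing scales;
    labels: integer class map of shape [N, D, H, W] (dtype long)."""
    if not isinstance(preds, (list, tuple)):   # multiscale_pred == False
        preds = [preds]
    total = 0.0
    for pred, w in zip(preds, weights):
        if pred.shape[2:] != labels.shape[1:]:
            pred = F.interpolate(pred, size=labels.shape[1:], mode='trilinear')
        total = total + w * F.cross_entropy(pred, labels)
    return total
```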

pymic/net_run/__init__.py

+2
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *

pymic/net_run/agent_abstract.py

+4-4
@@ -276,7 +276,7 @@ def worker_init_fn(worker_id):
         self.test_loader = torch.utils.data.DataLoader(self.test_set,
             batch_size = bn_test, shuffle=False, num_workers= bn_test)

-    def create_optimizer(self, params):
+    def create_optimizer(self, params, checkpoint = None):
         """
         Create optimizer based on configuration.
@@ -288,9 +288,9 @@ def create_optimizer(self, params):
         self.optimizer = get_optimizer(opt_params['optimizer'],
             params, opt_params)
         last_iter = -1
-        if(self.checkpoint is not None):
-            self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
-            last_iter = self.checkpoint['iteration'] - 1
+        if(checkpoint is not None):
+            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            last_iter = checkpoint['iteration'] - 1
         if(self.scheduler is None):
             opt_params["last_iter"] = last_iter
             self.scheduler = get_lr_scheduler(self.optimizer, opt_params)
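`create_optimizer` now takes the checkpoint as an explicit argument instead of reading `self.checkpoint`, which makes resuming an optimizer an ordinary call. A hedged sketch of how a caller might use it (the `agent` object and checkpoint path are hypothetical; only the `(params, checkpoint=None)` signature comes from this diff):

```python
import torch

def resume_optimizer(agent, ckpt_path):
    # Hand the loaded checkpoint to the new signature so the optimizer state
    # is restored and last_iter is set, letting the LR scheduler continue
    # from the saved iteration rather than restarting.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    agent.create_optimizer(agent.net.parameters(), checkpoint)

def fresh_optimizer(agent):
    # Omitting the argument leaves last_iter at -1, i.e. training from scratch.
    agent.create_optimizer(agent.net.parameters())
```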

pymic/net_run/agent_cls.py

+23-14
@@ -157,9 +157,6 @@ def training(self):
             loss = self.get_loss_value(data, outputs, labels)
             loss.backward()
             self.optimizer.step()
-            if(self.scheduler is not None and \
-                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
-                self.scheduler.step()

             # statistics
             sample_num += labels.size(0)
@@ -183,7 +180,7 @@ def validation(self):
             inputs = self.convert_tensor_type(data['image'])
             labels = self.convert_tensor_type(data['label_prob'])
             inputs, labels = inputs.to(self.device), labels.to(self.device)
-            self.optimizer.zero_grad()
+            # self.optimizer.zero_grad()
             # forward + backward + optimize
             outputs = self.net(inputs)
             loss = self.get_loss_value(data, outputs, labels)
@@ -196,20 +193,17 @@ def validation(self):
         avg_loss = running_loss / sample_num
         avg_score= running_score.double() / sample_num
         metrics = self.config['training'].get("evaluation_metric", "accuracy")
-        if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
-            self.scheduler.step(avg_score)
         valid_scalers = {'loss': avg_loss, metrics: avg_score}
         return valid_scalers

     def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
-        metrics =self.config['training'].get("evaluation_metric", "accuracy")
+        metrics = self.config['training'].get("evaluation_metric", "accuracy")
         loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']}
         acc_scalar ={'train':train_scalars[metrics],'valid':valid_scalars[metrics]}
         self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
         self.summ_writer.add_scalars(metrics, acc_scalar, glob_it)
         self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)

-        logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it))
         logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format(
             train_scalars['loss'], metrics, train_scalars[metrics]))
         logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format(
@@ -251,7 +245,10 @@ def train_valid(self):
             checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
             self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
             assert(self.checkpoint['iteration'] == iter_start)
-            self.net.load_state_dict(self.checkpoint['model_state_dict'])
+            if(len(device_ids) > 1):
+                self.net.module.load_state_dict(self.checkpoint['model_state_dict'])
+            else:
+                self.net.load_state_dict(self.checkpoint['model_state_dict'])
             self.max_val_score = self.checkpoint.get('valid_pred', 0)
             self.max_val_it = self.checkpoint['iteration']
             self.best_model_wts = self.checkpoint['model_state_dict']
@@ -266,15 +263,28 @@ def train_valid(self):
         self.glob_it = iter_start
         for it in range(iter_start, iter_max, iter_valid):
             lr_value = self.optimizer.param_groups[0]['lr']
+            t0 = time.time()
             train_scalars = self.training()
+            t1 = time.time()
             valid_scalars = self.validation()
+            t2 = time.time()
+            if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
+                self.scheduler.step(valid_scalars[metrics])
+            else:
+                self.scheduler.step()
+
             self.glob_it = it + iter_valid
+            logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it))
+            logging.info('learning rate {0:}'.format(lr_value))
+            logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
             self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
-
             if(valid_scalars[metrics] > self.max_val_score):
                 self.max_val_score = valid_scalars[metrics]
                 self.max_val_it = self.glob_it
-                self.best_model_wts = copy.deepcopy(self.net.state_dict())
+                if(len(device_ids) > 1):
+                    self.best_model_wts = copy.deepcopy(self.net.module.state_dict())
+                else:
+                    self.best_model_wts = copy.deepcopy(self.net.state_dict())

             stop_now = True if(early_stop_it is not None and \
                 self.glob_it - self.max_val_it > early_stop_it) else False
@@ -306,7 +316,6 @@ def train_valid(self):
             self.max_val_it, metrics, self.max_val_score))
         self.summ_writer.close()

-
     def infer(self):
         device_ids = self.config['testing']['gpus']
         device = torch.device("cuda:{0:}".format(device_ids[0]))
@@ -318,8 +327,8 @@ def infer(self):

         if(self.config['testing'].get('evaluation_mode', True)):
             self.net.eval()
-
-        output_csv = self.config['testing']['output_csv']
+
+        output_csv = self.config['testing']['output_dir'] + '/' + self.config['testing']['output_csv']
         class_num = self.config['network']['class_num']
         save_probability = self.config['testing'].get('save_probability', False)
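Taken together, these hunks move all scheduler stepping out of `training()`/`validation()` and into the `train_valid()` loop, stepping once per validation cycle, and they unwrap `nn.DataParallel` before loading or snapshotting a `state_dict` so checkpoints stay portable to single-GPU runs. A hedged sketch of the two patterns in isolation (helper names are illustrative, not PyMIC API beyond what the hunks above show):

```python
# Sketch of the two generic patterns this commit converges on.
import copy
import torch
from torch.optim import lr_scheduler

def step_scheduler(scheduler, valid_metric):
    # ReduceLROnPlateau must be fed the monitored value; other schedulers
    # are stepped unconditionally, once per validation cycle.
    if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
        scheduler.step(valid_metric)
    else:
        scheduler.step()

def unwrapped_state_dict(net):
    # nn.DataParallel stores the real model under .module; saving that
    # keeps checkpoint keys free of the 'module.' prefix, so they also
    # load on single-GPU setups.
    raw = net.module if isinstance(net, torch.nn.DataParallel) else net
    return copy.deepcopy(raw.state_dict())
```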
