-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
FMVPMac
authored and
FMVPMac
committed
Jul 13, 2023
0 parents
commit 37f2b08
Showing
25 changed files
with
4,018 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
## TAGNet: A transformer-based axial guided network for bile duct segmentation | ||
|
||
Implementation of [TAGNet](https://www.sciencedirect.com/science/article/abs/pii/S1746809423006778) in Pytorch, a bile duct segmentation model based on CNN and Transformer hybrid architecture network. The paper has been accepted on https://www.sciencedirect.com/science/article/abs/pii/S1746809423006778 | ||
|
||
### Data Prepocessing | ||
|
||
Considering the privacy issues involved in the data, the bile duct dataset is not made public. You can train and test the model on some publicly available liver vascular datasets (CT images). | ||
|
||
#### Data directory format | ||
|
||
---code | ||
|
||
---TAGNet source code | ||
|
||
---data | ||
|
||
---imagesTr | ||
|
||
---BileDuct_xxx_volume.nii.gz | ||
|
||
---imagesTs | ||
|
||
---BileDuct_xxx_volume.nii.gz | ||
|
||
---labelsTr | ||
|
||
---BileDuct_xxx_label.nii.gz | ||
|
||
---labelsTs | ||
|
||
---BileDuct_xxx_label.nii.gz | ||
|
||
---mask | ||
|
||
---BileDuct_xxx_liver.nii.gz | ||
|
||
---dataset.json | ||
|
||
The `xxx` represents the serial number of the data. The file `dataset.json` stores path information for each data and will be generated while preprocessing raw data (CT images). | ||
|
||
#### Data preprocess | ||
|
||
You can preprocess your own data via: | ||
|
||
```shell | ||
python preprocessor.py --root_dir your_data_root_dir --dataset your_dataset_name | ||
``` | ||
|
||
### Model Training | ||
|
||
You can train the model with the following statement: | ||
|
||
```shell | ||
python -W ignore train.py --model TAGNet --batch_size 8 --lr 0.01 --epoch 300 --patience 50 --loss_func sat --model_remark dice --gpu 6 7 | ||
``` | ||
|
||
**Parameter Description** | ||
|
||
`--model` Model name. | ||
|
||
`--batch_size` Batch size. | ||
|
||
`--lr` Learning rate. | ||
|
||
`--epoch` Maximum number of epochs for training. | ||
|
||
`--patience` Maximum number of steps for training early stopping. If there is no lower loss update for periods exceeding the set value during training, stop training to prevent overfitting. | ||
|
||
`--loss_func` Loss function. | ||
|
||
`--model_remark` This is a free option. When training the same model, you may choose different loss functions or hyperparameters to verify the model performance. This parameter could store different models by adding a remark to the model. | ||
|
||
`--gpu` Support for multi-GPU training. | ||
|
||
### Inference | ||
|
||
You can test the model with the following statement: | ||
|
||
```shell | ||
python -W ignore test.py --model TAGNet --best --gpu 7 --model_remark dice --postprocess [Optional] --save_infer [Optional] --save_csv [Optional] | ||
``` | ||
|
||
**Optional Parameters Description** | ||
|
||
`--postprocess` Post-processing the results of inference. | ||
|
||
`--save_infer` Store inference results in the form of nii.gz. | ||
|
||
`save_csv` Export the quantitative evaluation results of inference to a file. | ||
|
||
|
||
|
||
Also, you can generate more detailed quantitative evaluation results (non-post-processing and post-processing) through `summary.py`, and generate a json file to store the results via: | ||
|
||
```shell | ||
python summary.py | ||
``` | ||
|
||
Before running the py file, you need to manually specify the output path of the file (`predict_path`), the path of the ground truth (`gt_path`), and the model name (`model_name`) in the code. | ||
|
||
### Citation | ||
|
||
``` | ||
@article{ZHOU2023105244, | ||
title = {TAGNet: A transformer-based axial guided network for bile duct segmentation}, | ||
journal = {Biomedical Signal Processing and Control}, | ||
volume = {86}, | ||
pages = {105244}, | ||
year = {2023}, | ||
issn = {1746-8094}, | ||
doi = {https://doi.org/10.1016/j.bspc.2023.105244}, | ||
url = {https://www.sciencedirect.com/science/article/pii/S1746809423006778}, | ||
author = {Guang-Quan Zhou and Fuxing Zhao and Qing-Han Yang and Kai-Ni Wang and Shengxiao Li and Shoujun Zhou and Jian Lu and Yang Chen}, | ||
keywords = {Bile duct, Medical image segmentation, Axial attention, Transformer, Deep learning}, | ||
} | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import argparse | ||
|
||
|
||
parser = argparse.ArgumentParser(description='set hyper parameters.') | ||
|
||
# hyper parameters. | ||
parser.add_argument('--n_classes', type=int, default=1, help="classes number") | ||
parser.add_argument('--epoch', type=int, default=500, help='epoch: default = 100') | ||
parser.add_argument('--lr', type=float, default=0.005, help='learning rate: default = 0.001') | ||
parser.add_argument('--batch_size', type=int, default=8, help='Unet:32; Unet++:16; AttUnet:?') | ||
parser.add_argument('--sample_frequency', type=int, default=8, help='sample frequency') | ||
parser.add_argument('--loss_func', type=str, default='dice', help='Defalut: dice; dice/diceBCE/focalBCE/') | ||
parser.add_argument('--model', type=str, default='Unet', help='Defalut: Unet; Unet/AttUnet/NestedUnet/Unet_PAM/') | ||
|
||
# parser.add_argument('--model_path_saved', type=str, default='') | ||
|
||
# Preprocess parameters | ||
parser.add_argument('--crop_slices', type=int, default=48) | ||
parser.add_argument('--crop_image_size', type=int, default=128) | ||
parser.add_argument('--val_crop_max_size', type=int, default=96) | ||
parser.add_argument('--test_rate', type=float, default=0.2, help='') | ||
parser.add_argument('--valid_rate', type=float, default=0.2, help='') | ||
parser.add_argument('--xy_down_scale', type=float, default=0.5, help='') | ||
parser.add_argument('--norm_factor', type=float, default=200.0, help='') | ||
parser.add_argument('--slice_down_scale', type=float, default=1.0, help='') | ||
parser.add_argument('--scale', type=int, default=2, help='AATM_V6 scale') | ||
parser.add_argument('--consistency', type=float, default=0.1, help='consistency') | ||
parser.add_argument('--edge_weight', type=float, default=0.1, help='edge_weight') | ||
|
||
parser.add_argument('--save_path',default = 'runs',help='tensorboard saved path') | ||
parser.add_argument('--data_root_dir', type=str, default='/data1/zfx/data/', help='data_root_dir: default = /home/sophia/zfx/data/') | ||
|
||
parser.add_argument('--kfold', type=int, default=0, help='Cross-validation: default = 0') | ||
parser.add_argument('--model_remark', type=str, default="", help='This is a remark for model') | ||
parser.add_argument('--dataset', type=str, default='BileDuct', help='Defalut: ZDYY; 3Dircadb/ZDYY/bileDuct') | ||
parser.add_argument('--patience', type=int, default=50, help='Early stopping number: default = 12') | ||
# parser.add_argument('--device_ids', type=int, nargs='+', default=[4,5,6,7], help='Multiple GPUs number: defalut = [5,6,7]') | ||
|
||
parser.add_argument('--gated_scale', type=float, default=0.125, help='gated attention scale: default = 0.125') | ||
parser.add_argument('--best_model', default=False, action='store_true', help='choose the best model') | ||
parser.add_argument('--continue_training', default=False, action='store_true', help='whether to continue training') | ||
parser.add_argument('--begin_epoch', type=int, default=1, help='epoch of begin training') | ||
parser.add_argument('--final_model', default=False, action='store_true', help='choose final model') | ||
parser.add_argument('--dsv', default=False, action='store_true', help='deeep supervision') | ||
parser.add_argument('--save_infer', default=False, action='store_true', help='save predicted map or not') | ||
parser.add_argument('--use_cpu', default=False, action='store_true', help='whether to use gpu for testing') | ||
parser.add_argument('--mulGPU', default=False, action='store_true', help='whether to use multiple GPUSs') | ||
parser.add_argument('-gpu', '--gpu', nargs='+', default='None', help="GPU ids to use") | ||
parser.add_argument('--postprocess', default=False, action='store_true', help='whether to run postprocessing') | ||
parser.add_argument('--save_csv', default=False, action='store_true', help='save predicted map or not') | ||
|
||
|
||
# for cropping and preprocessing data | ||
parser.add_argument('--sample_slices', type=int, default=3, help='consecutive slices.') | ||
parser.add_argument('--process_type', type=str, default='2D', help='preprocessing type: 2D: interpolation; 2.5D: z-axis interpolation; 3D: all-axis interpolation') | ||
parser.add_argument('--process_mode', type=str, default='train', help='crop and preprocess mode: train or val') | ||
|
||
|
||
args = parser.parse_args() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import os | ||
import sys | ||
sys.path.append('../') | ||
|
||
from os.path import join | ||
|
||
import cv2 | ||
import torch | ||
import numpy as np | ||
|
||
from torchvision import transforms | ||
from datasets.utils_loader import * | ||
from utils.file_utils import load_json | ||
from torch.utils.data import Dataset, DataLoader | ||
|
||
from imgaug import augmenters as iaa | ||
from imgaug.augmentables.segmaps import SegmentationMapsOnImage | ||
|
||
|
||
class data_loader(Dataset): | ||
def __init__(self, args, filename, mode='train', process_type='2D', sample_slices=3): | ||
|
||
self.args = args | ||
self.mode = mode | ||
self.type = process_type | ||
|
||
self.dataset_dir = join(args.data_root_dir, args.dataset) | ||
if self.mode == 'train': | ||
self.filename_list = load_json(join(self.dataset_dir, 'preprocessed_data', filename))['train'] | ||
elif self.mode == 'val': | ||
self.filename_list = load_json(join(self.dataset_dir, 'preprocessed_data', filename))['val'] | ||
elif self.mode == 'test': | ||
self.filename_list = load_json(join(self.dataset_dir, 'preprocessed_data', filename))['test'] | ||
|
||
self.image, self.label = self._get_image_list(self.filename_list) | ||
|
||
self.extend_slice = sample_slices // 2 | ||
|
||
self.data_aug = iaa.Sequential([ | ||
iaa.Affine( | ||
scale=(0.5, 1.2), | ||
rotate=(-15, 15) | ||
), # rotate the image | ||
iaa.Flipud(0.5), | ||
iaa.PiecewiseAffine(scale=(0.01, 0.05)), | ||
iaa.Sometimes( | ||
0.1, | ||
iaa.GaussianBlur((0.1, 1.5)), | ||
), | ||
iaa.Sometimes( | ||
0.1, | ||
iaa.LinearContrast((0.5, 2.0), per_channel=0.5), | ||
) | ||
]) | ||
self.transforms = transforms.Compose([ToTensor()]) | ||
|
||
def __getitem__(self, index): | ||
|
||
if self.mode == 'train': | ||
|
||
mid_slice = self.image.shape[1] // 2 | ||
|
||
image = self.image[index][mid_slice-self.extend_slice:mid_slice+self.extend_slice+1, ...].copy() | ||
label = self.label[index][mid_slice,...].copy() | ||
|
||
image = image.transpose(1, 2, 0) | ||
segmap = SegmentationMapsOnImage(np.uint8(label), shape=(256, 256)) | ||
|
||
# data augmentation | ||
image, label = self.data_aug(image=image, segmentation_maps=segmap) | ||
|
||
image, label = image.copy(), label.copy() | ||
|
||
image = image.transpose(2, 0, 1) | ||
label = label.get_arr() | ||
edge = self.get_edge(label) | ||
|
||
else: | ||
mid_slice = self.image[index].shape[1] // 2 | ||
image = self.image[index][:, mid_slice-self.extend_slice:mid_slice+self.extend_slice+1, ...].copy() | ||
label = self.label[index][:, mid_slice, ...].copy() | ||
|
||
edge = self.get_edge(label) | ||
image, label, edge = torch.from_numpy(image), torch.from_numpy(label).unsqueeze(1), torch.from_numpy(edge).unsqueeze(1) | ||
|
||
sample = {'image': image, 'label': label, 'edge': edge} | ||
|
||
if self.mode == 'train': | ||
sample = self.transforms(sample) | ||
|
||
return sample | ||
|
||
def __len__(self): | ||
return self.image.shape[0] | ||
|
||
def get_edge(self, label): | ||
|
||
if len(label.shape) == 2: | ||
edge = cv2.Canny(np.uint8(label), 0, 1) | ||
edge[edge > 0] = 1 | ||
elif len(label.shape) == 3: | ||
edge = np.empty_like(label) | ||
for idx in range(label.shape[0]): | ||
temp_edge = cv2.Canny(np.uint8(label[idx]), 0, 1) | ||
temp_edge[temp_edge > 0] = 1 | ||
edge[idx] = temp_edge | ||
|
||
return edge | ||
|
||
def _get_image_list(self, filename_list): | ||
|
||
ct_list, label_list = [], [] | ||
|
||
for dic in filename_list: | ||
|
||
data = np.load(dic['preprocess_npy']) | ||
|
||
if self.mode == 'train': | ||
ct_list.extend(data[0]) | ||
label_list.extend(data[-1]) | ||
else: | ||
ct_list.append(data[0]) | ||
label_list.append(data[-1]) | ||
return np.array(ct_list), np.array(label_list) | ||
|
||
if __name__ == '__main__': | ||
|
||
|
||
from config import args | ||
|
||
filename = 'split_train_val.json' | ||
|
||
data_set = data_loader(args, filename, 'val', sample_slices=5) | ||
print('length of dataset: ', len(data_set)) | ||
|
||
data_load = DataLoader(dataset=data_set, batch_size=1, shuffle=True, num_workers=8) | ||
|
||
for i, sample in enumerate(data_load): | ||
print("ct: {}, seg: {}, edge: {}".format(sample['image'].shape, sample['label'].shape, sample['edge'].shape)) | ||
break |
Oops, something went wrong.