Commit 37f2b08: first commit
FMVPMac authored and committed Jul 13, 2023
Showing 25 changed files with 4,018 additions and 0 deletions.
Binary file added .DS_Store (binary content not shown).
117 changes: 117 additions & 0 deletions Readme.md

## TAGNet: A transformer-based axial guided network for bile duct segmentation

PyTorch implementation of [TAGNet](https://www.sciencedirect.com/science/article/abs/pii/S1746809423006778), a bile duct segmentation model built on a hybrid CNN and Transformer architecture. The paper has been published in *Biomedical Signal Processing and Control*.

### Data Preprocessing

Because of the privacy issues involved in the data, the bile duct dataset is not made public. You can instead train and test the model on publicly available liver vascular datasets (CT images).

#### Data directory format

```
code/
    TAGNet source code
data/
    imagesTr/
        BileDuct_xxx_volume.nii.gz
    imagesTs/
        BileDuct_xxx_volume.nii.gz
    labelsTr/
        BileDuct_xxx_label.nii.gz
    labelsTs/
        BileDuct_xxx_label.nii.gz
    mask/
        BileDuct_xxx_liver.nii.gz
    dataset.json
```

Here `xxx` is the serial number of each case. The file `dataset.json` stores the path information for each case and is generated automatically when the raw CT images are preprocessed.
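
The exact schema of `dataset.json` is produced by `preprocessor.py` and is not shown in this commit. As a rough guide, the split file consumed by `datasets/data_loader.py` is a JSON object with `train`/`val`/`test` keys whose entries record per-case paths (including a `preprocess_npy` field); a minimal reading sketch, with a placeholder path, is:

```python
# Minimal sketch: the keys are inferred from datasets/data_loader.py, which reads
# split_train_val.json and accesses entry['preprocess_npy']; the path below is a placeholder.
import json

with open('data/BileDuct/preprocessed_data/split_train_val.json') as f:
    splits = json.load(f)

for entry in splits['train']:
    print(entry['preprocess_npy'])  # path to the preprocessed .npy volume for one case
```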

#### Data preprocessing

You can preprocess your own data via:

```shell
python preprocessor.py --root_dir your_data_root_dir --dataset your_dataset_name
```

### Model Training

You can train the model with the following statement:

```shell
python -W ignore train.py --model TAGNet --batch_size 8 --lr 0.01 --epoch 300 --patience 50 --loss_func sat --model_remark dice --gpu 6 7
```
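
`config.py` also defines `--continue_training` and `--begin_epoch` flags. Assuming they behave as named (this usage is not documented in the original README), an interrupted run might be resumed like this:

```shell
# Assumed resume usage based on the flags defined in config.py; not documented upstream.
python -W ignore train.py --model TAGNet --batch_size 8 --lr 0.01 --epoch 300 --patience 50 \
    --loss_func sat --model_remark dice --gpu 6 7 --continue_training --begin_epoch 150
```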

**Parameter Description**

`--model` Model name.

`--batch_size` Batch size.

`--lr` Learning rate.

`--epoch` Maximum number of epochs for training.

`--patience` Early-stopping patience. If the loss does not improve for more than this many epochs during training, training stops to prevent overfitting (see the sketch after this parameter list).

`--loss_func` Loss function.

`--model_remark` A free-form tag appended to the saved model name. When training the same model with different loss functions or hyperparameters, the remark keeps the resulting checkpoints distinct.

`--gpu` GPU ids to use; pass multiple ids (e.g. `6 7`) to enable multi-GPU training.
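
As a point of reference, the early-stopping rule described under `--patience` can be pictured as follows. This is an illustrative sketch with a stubbed validation function, not the actual `train.py` loop:

```python
# Illustrative early-stopping sketch; train.py's real loop is not shown in this commit.
from typing import Callable

def train_with_early_stopping(validate: Callable[[], float],
                              max_epochs: int = 300, patience: int = 50) -> float:
    """Stop once the loss has not improved for more than `patience` epochs."""
    best_loss = float('inf')
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        val_loss = validate()            # one epoch of training + validation (stubbed here)
        if val_loss < best_loss:
            best_loss = val_loss         # new best loss: reset the counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                break                    # no improvement for `patience` epochs: stop early
    return best_loss

# Toy run: the stubbed loss stops improving after epoch 10, so training halts early.
losses = iter([1.0 - 0.05 * min(i, 10) for i in range(300)])
print(train_with_early_stopping(lambda: next(losses)))
```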

### Inference

You can test the model with the following statement:

```shell
python -W ignore test.py --model TAGNet --best --gpu 7 --model_remark dice --postprocess [Optional] --save_infer [Optional] --save_csv [Optional]
```

**Optional Parameters Description**

`--postprocess` Apply post-processing to the inference results (see the sketch after this list).

`--save_infer` Save the inference results as `.nii.gz` files.

`--save_csv` Export the quantitative evaluation results of the inference to a CSV file.
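
The commit does not spell out what `--postprocess` does. For tubular-structure segmentation, a common choice is to discard small spurious connected components; a hypothetical sketch along those lines, using `scipy`, is:

```python
# Hypothetical post-processing sketch (an assumption, not the repository's code):
# keep only connected components above a minimum voxel count.
import numpy as np
from scipy import ndimage

def remove_small_components(binary_mask: np.ndarray, min_voxels: int = 100) -> np.ndarray:
    labeled, _ = ndimage.label(binary_mask)   # label 3D connected components
    sizes = np.bincount(labeled.ravel())      # voxel count per component (index 0 = background)
    keep = np.zeros_like(sizes, dtype=bool)
    keep[1:] = sizes[1:] >= min_voxels        # drop background and undersized blobs
    return keep[labeled].astype(binary_mask.dtype)

# Toy usage on a random binary volume.
mask = (np.random.rand(16, 64, 64) > 0.995).astype(np.uint8)
print(remove_small_components(mask, min_voxels=5).sum())
```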



You can also generate more detailed quantitative evaluation results (both with and without post-processing) and store them in a JSON file via `summary.py`:

```shell
python summary.py
```

Before running the script, you need to manually set the prediction output path (`predict_path`), the ground-truth path (`gt_path`), and the model name (`model_name`) in the code.
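
`summary.py` itself is not included in this commit, so as a rough illustration only, a per-case Dice evaluation over `.nii.gz` pairs might look like the following (the paths, file names, and use of `nibabel` are assumptions):

```python
# Illustrative sketch of a quantitative evaluation pass; not the actual summary.py.
import nibabel as nib
import numpy as np

predict_path = 'path/to/predictions'   # set manually, as described above
gt_path = 'path/to/ground_truth'
model_name = 'TAGNet'

def dice(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6) -> float:
    pred, gt = pred > 0, gt > 0
    return (2.0 * np.logical_and(pred, gt).sum() + eps) / (pred.sum() + gt.sum() + eps)

# Placeholder file names; the real naming scheme depends on the inference output.
pred = nib.load(f'{predict_path}/BileDuct_001_pred.nii.gz').get_fdata()
gt = nib.load(f'{gt_path}/BileDuct_001_label.nii.gz').get_fdata()
print(model_name, 'Dice:', dice(pred, gt))
```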

### Citation

```
@article{ZHOU2023105244,
title = {TAGNet: A transformer-based axial guided network for bile duct segmentation},
journal = {Biomedical Signal Processing and Control},
volume = {86},
pages = {105244},
year = {2023},
issn = {1746-8094},
doi = {10.1016/j.bspc.2023.105244},
url = {https://www.sciencedirect.com/science/article/pii/S1746809423006778},
author = {Guang-Quan Zhou and Fuxing Zhao and Qing-Han Yang and Kai-Ni Wang and Shengxiao Li and Shoujun Zhou and Jian Lu and Yang Chen},
keywords = {Bile duct, Medical image segmentation, Axial attention, Transformer, Deep learning},
}
```

59 changes: 59 additions & 0 deletions config.py

import argparse


parser = argparse.ArgumentParser(description='set hyper parameters.')

# hyper parameters.
parser.add_argument('--n_classes', type=int, default=1, help='number of classes')
parser.add_argument('--epoch', type=int, default=500, help='maximum number of epochs: default = 500')
parser.add_argument('--lr', type=float, default=0.005, help='learning rate: default = 0.005')
parser.add_argument('--batch_size', type=int, default=8, help='batch size (Unet: 32; Unet++: 16; AttUnet: ?)')
parser.add_argument('--sample_frequency', type=int, default=8, help='sample frequency')
parser.add_argument('--loss_func', type=str, default='dice', help='default: dice; options: dice/diceBCE/focalBCE')
parser.add_argument('--model', type=str, default='Unet', help='default: Unet; options: Unet/AttUnet/NestedUnet/Unet_PAM')

# parser.add_argument('--model_path_saved', type=str, default='')

# Preprocess parameters
parser.add_argument('--crop_slices', type=int, default=48, help='number of slices per crop')
parser.add_argument('--crop_image_size', type=int, default=128, help='in-plane crop size')
parser.add_argument('--val_crop_max_size', type=int, default=96, help='maximum crop size for validation')
parser.add_argument('--test_rate', type=float, default=0.2, help='test split ratio: default = 0.2')
parser.add_argument('--valid_rate', type=float, default=0.2, help='validation split ratio: default = 0.2')
parser.add_argument('--xy_down_scale', type=float, default=0.5, help='in-plane (x/y) downsampling factor')
parser.add_argument('--norm_factor', type=float, default=200.0, help='intensity normalization factor')
parser.add_argument('--slice_down_scale', type=float, default=1.0, help='z-axis downsampling factor')
parser.add_argument('--scale', type=int, default=2, help='AATM_V6 scale')
parser.add_argument('--consistency', type=float, default=0.1, help='consistency weight')
parser.add_argument('--edge_weight', type=float, default=0.1, help='edge loss weight')

parser.add_argument('--save_path', default='runs', help='tensorboard save path')
parser.add_argument('--data_root_dir', type=str, default='/data1/zfx/data/', help='root directory of the data: default = /data1/zfx/data/')

parser.add_argument('--kfold', type=int, default=0, help='cross-validation fold: default = 0')
parser.add_argument('--model_remark', type=str, default="", help='remark appended to the saved model name')
parser.add_argument('--dataset', type=str, default='BileDuct', help='default: BileDuct; options: 3Dircadb/ZDYY/BileDuct')
parser.add_argument('--patience', type=int, default=50, help='early-stopping patience: default = 50')
# parser.add_argument('--device_ids', type=int, nargs='+', default=[4,5,6,7], help='GPU ids for multi-GPU training: default = [4,5,6,7]')

parser.add_argument('--gated_scale', type=float, default=0.125, help='gated attention scale: default = 0.125')
parser.add_argument('--best_model', default=False, action='store_true', help='choose the best model')
parser.add_argument('--continue_training', default=False, action='store_true', help='whether to continue training')
parser.add_argument('--begin_epoch', type=int, default=1, help='epoch to begin training from')
parser.add_argument('--final_model', default=False, action='store_true', help='choose final model')
parser.add_argument('--dsv', default=False, action='store_true', help='deep supervision')
parser.add_argument('--save_infer', default=False, action='store_true', help='save predicted map or not')
parser.add_argument('--use_cpu', default=False, action='store_true', help='use the CPU instead of the GPU for testing')
parser.add_argument('--mulGPU', default=False, action='store_true', help='whether to use multiple GPUs')
parser.add_argument('-gpu', '--gpu', nargs='+', default='None', help="GPU ids to use")
parser.add_argument('--postprocess', default=False, action='store_true', help='whether to run postprocessing')
parser.add_argument('--save_csv', default=False, action='store_true', help='save the quantitative evaluation results as CSV or not')


# for cropping and preprocessing data
parser.add_argument('--sample_slices', type=int, default=3, help='consecutive slices.')
parser.add_argument('--process_type', type=str, default='2D', help='preprocessing type: 2D: interpolation; 2.5D: z-axis interpolation; 3D: all-axis interpolation')
parser.add_argument('--process_mode', type=str, default='train', help='crop and preprocess mode: train or val')


args = parser.parse_args()
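
Since `args` is parsed at import time, any module can pick it up with a plain import; `datasets/data_loader.py` below does exactly this in its `__main__` block:

```python
# Importing config triggers parse_args(), so scripts share one argument namespace.
from config import args

print(args.model, args.batch_size, args.lr)
```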
140 changes: 140 additions & 0 deletions datasets/data_loader.py

import os
import sys
sys.path.append('../')

from os.path import join

import cv2
import torch
import numpy as np

from torchvision import transforms
from datasets.utils_loader import *
from utils.file_utils import load_json
from torch.utils.data import Dataset, DataLoader

from imgaug import augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage


class data_loader(Dataset):
    def __init__(self, args, filename, mode='train', process_type='2D', sample_slices=3):

        self.args = args
        self.mode = mode
        self.type = process_type

        self.dataset_dir = join(args.data_root_dir, args.dataset)
        if self.mode == 'train':
            self.filename_list = load_json(join(self.dataset_dir, 'preprocessed_data', filename))['train']
        elif self.mode == 'val':
            self.filename_list = load_json(join(self.dataset_dir, 'preprocessed_data', filename))['val']
        elif self.mode == 'test':
            self.filename_list = load_json(join(self.dataset_dir, 'preprocessed_data', filename))['test']

        self.image, self.label = self._get_image_list(self.filename_list)

        self.extend_slice = sample_slices // 2

        self.data_aug = iaa.Sequential([
            iaa.Affine(
                scale=(0.5, 1.2),
                rotate=(-15, 15)
            ),  # random scaling and rotation
            iaa.Flipud(0.5),
            iaa.PiecewiseAffine(scale=(0.01, 0.05)),
            iaa.Sometimes(
                0.1,
                iaa.GaussianBlur((0.1, 1.5)),
            ),
            iaa.Sometimes(
                0.1,
                iaa.LinearContrast((0.5, 2.0), per_channel=0.5),
            )
        ])
        self.transforms = transforms.Compose([ToTensor()])

    def __getitem__(self, index):

        if self.mode == 'train':

            mid_slice = self.image.shape[1] // 2

            image = self.image[index][mid_slice-self.extend_slice:mid_slice+self.extend_slice+1, ...].copy()
            label = self.label[index][mid_slice, ...].copy()

            image = image.transpose(1, 2, 0)
            segmap = SegmentationMapsOnImage(np.uint8(label), shape=(256, 256))

            # data augmentation
            image, label = self.data_aug(image=image, segmentation_maps=segmap)

            image, label = image.copy(), label.copy()

            image = image.transpose(2, 0, 1)
            label = label.get_arr()
            edge = self.get_edge(label)

        else:
            mid_slice = self.image[index].shape[1] // 2
            image = self.image[index][:, mid_slice-self.extend_slice:mid_slice+self.extend_slice+1, ...].copy()
            label = self.label[index][:, mid_slice, ...].copy()

            edge = self.get_edge(label)
            image, label, edge = torch.from_numpy(image), torch.from_numpy(label).unsqueeze(1), torch.from_numpy(edge).unsqueeze(1)

        sample = {'image': image, 'label': label, 'edge': edge}

        if self.mode == 'train':
            sample = self.transforms(sample)

        return sample

    def __len__(self):
        return self.image.shape[0]

    def get_edge(self, label):

        if len(label.shape) == 2:
            edge = cv2.Canny(np.uint8(label), 0, 1)
            edge[edge > 0] = 1
        elif len(label.shape) == 3:
            edge = np.empty_like(label)
            for idx in range(label.shape[0]):
                temp_edge = cv2.Canny(np.uint8(label[idx]), 0, 1)
                temp_edge[temp_edge > 0] = 1
                edge[idx] = temp_edge

        return edge

    def _get_image_list(self, filename_list):

        ct_list, label_list = [], []

        for dic in filename_list:

            data = np.load(dic['preprocess_npy'])

            if self.mode == 'train':
                ct_list.extend(data[0])
                label_list.extend(data[-1])
            else:
                ct_list.append(data[0])
                label_list.append(data[-1])
        return np.array(ct_list), np.array(label_list)


if __name__ == '__main__':

    from config import args

    filename = 'split_train_val.json'

    data_set = data_loader(args, filename, 'val', sample_slices=5)
    print('length of dataset: ', len(data_set))

    data_load = DataLoader(dataset=data_set, batch_size=1, shuffle=True, num_workers=8)

    for i, sample in enumerate(data_load):
        print("ct: {}, seg: {}, edge: {}".format(sample['image'].shape, sample['label'].shape, sample['edge'].shape))
        break