Commit

update the loss and input shape
jaehwan committed Jan 24, 2024
1 parent a3b98ae commit 6d1dab4
Showing 7 changed files with 97 additions and 65 deletions.
10 changes: 7 additions & 3 deletions configs/augmentations.yaml
@@ -1,6 +1,10 @@
RandomAffine:
scales: !!python/tuple [ 0.8, 1.2 ]
degrees: !!python/tuple [ -5, 5 ]
scales: !!python/tuple [ 0.5, 1.5 ]
degrees: !!python/tuple [ -10, 10 ]
isotropic: false
image_interpolation: linear
p: 0.5
p: 0.5

RandomFlip:
axes: 0
flip_probability: 0.5
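Note: these augmentation keys map one-to-one onto TorchIO transform arguments, and the !!python/tuple tags mean the file cannot be parsed with yaml.safe_load. A minimal sketch of building a tio.Compose from it, assuming PyYAML and TorchIO (the repository may construct its augmentations differently):

# Hypothetical sketch: build TorchIO transforms from configs/augmentations.yaml.
import yaml
import torchio as tio

with open("configs/augmentations.yaml") as f:
    aug_cfg = yaml.unsafe_load(f)  # unsafe_load (or a custom loader) is needed for !!python/tuple

# Each top-level key names a TorchIO transform class; its sub-keys become keyword arguments.
transforms = [getattr(tio, name)(**params) for name, params in aug_cfg.items()]
augmentations = tio.Compose(transforms)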
17 changes: 9 additions & 8 deletions configs/config.yaml
@@ -6,19 +6,20 @@ experiment:
name: Anchor # Anchor, MidLine, SmallandHard #Segmentation metadata

data_loader:
dataset: /home/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/data/preprocessed/HaN-Seg/set_1/preprocessing_B3B84D64C6
dataset: /home/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/data/preprocessed/HaN-Seg/set_1/preprocessing_7CB7AA3181
kfold: 1 # 1 disables k-fold validation; any other value enables it
augmentations: configs/augmentations.yaml
batch_size: 1
num_workers: 16
patch_loader: False
patch_shape:
- 128
- 128
- 64
resize_shape:
- 168
- 280
- 360
- 288 # x
- 288 # y
- 64 # z
sampler_type: UniformSampler

model:
@@ -35,9 +36,9 @@ optimizer:
name: AdamW

trainer:
reload: True
checkpoint: '/home/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/src/HanSeg_2023/experiments/Anchor/train/main_1820C18417/checkpoints/best.pth'
do_train: False
reload: False
checkpoint: ''
do_train: True
do_test: False
do_inference: True
do_inference: False
epochs: 1000
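Note: with patch_loader set to False, training switches from patch sampling (patch_shape 128 x 128 x 64) to whole volumes resized to the new resize_shape of 288 x 288 x 64, and the trainer block now starts training from scratch (reload: False, do_train: True). A minimal sketch of reading the flag, assuming plain PyYAML is sufficient for this file (it contains no Python tags):

# Illustrative only: pick the input shape the way the updated config implies.
import yaml

with open("configs/config.yaml") as f:
    cfg = yaml.safe_load(f)

dl = cfg["data_loader"]
if dl["patch_loader"]:
    target_shape = tuple(dl["patch_shape"])   # patch-based training, e.g. (128, 128, 64)
else:
    target_shape = tuple(dl["resize_shape"])  # whole-volume training, e.g. (288, 288, 64)
print(f"training on inputs of shape {target_shape}")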
4 changes: 3 additions & 1 deletion configs/preprocessing.yaml
@@ -12,4 +12,6 @@ preprocessing:
Resample:
target: !!python/tuple [1, 1, 1]
RescaleIntensity:
out_min_max: !!python/tuple [0, 1]
out_min_max: !!python/tuple [0, 1]

check_preprocessing: True
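Note: this block resamples to 1 mm isotropic spacing and rescales intensities to [0, 1]; the new check_preprocessing flag is consumed by preproc.py further down to emit a summary table. A hedged TorchIO equivalent of the two preprocessing steps (the repository's actual factory code is not shown in this diff):

# Sketch, assuming TorchIO: the pipeline described by configs/preprocessing.yaml.
import torchio as tio

preprocessing = tio.Compose([
    tio.Resample(target=(1, 1, 1)),            # resample CT and label to 1 mm isotropic spacing
    tio.RescaleIntensity(out_min_max=(0, 1)),  # rescale intensities into [0, 1]
])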
40 changes: 19 additions & 21 deletions datasets/HaN.py
@@ -30,12 +30,6 @@ def _numpy_reader(self, path):
data = torch.from_numpy(np.load(path)).float()
affine = torch.eye(4, requires_grad=False)
return data, affine

def _nrrd_reader(self, path):
raw_data, _ = nrrd.read(path)
data = torch.from_numpy(raw_data).float()
affine = torch.eye(4, requires_grad=False) # identity matrix
return data, affine

def _split_data(self, data_list):
# train and val data split
@@ -88,27 +82,31 @@ def _get_subjects_list(self, root, splits):
subject_dict = {
'partition': split,
'patient': patient,
'ct': tio.ScalarImage(ct_data_path, reader=self._nrrd_reader),
'ct': tio.ScalarImage(ct_data_path),
# 'mr': tio.ScalarImage(mr_data_path, reader=self._nrrd_reader),
'label': tio.LabelMap(label_path, reader=self._nrrd_reader),
'label': tio.LabelMap(label_path,),
}

subjects.append(tio.Subject(**subject_dict))
print(f"Loaded {len(subjects)} patients for split {split}")
return subjects

def get_loader(self, config):
#todo
sampler = SamplerFactory(config).get()
queue = tio.Queue(
subjects_dataset=self,
max_length=300,
samples_per_volume=10,
sampler=sampler,
num_workers=config.num_workers,
shuffle_subjects=True,
shuffle_patches=True,
start_background=False,
)
loader = DataLoader(queue, batch_size=config.batch_size, num_workers=0, pin_memory=True)
# patch-based training
if config.patch_loader:
sampler = SamplerFactory(config).get()
queue = tio.Queue(
subjects_dataset=self,
max_length=300,
samples_per_volume=10,
sampler=sampler,
num_workers=config.num_workers,
shuffle_subjects=True,
shuffle_patches=True,
start_background=False,
)
loader = DataLoader(queue, batch_size=config.batch_size, num_workers=0, pin_memory=True)
else: # subject-based training
dataset = tio.SubjectsDataset(self._subjects, transform=self._transform)
loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=True)
return loader
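Note: either branch returns a DataLoader whose batches are TorchIO-style dictionaries; with patch_loader enabled the batches hold patches drawn by the sampler, otherwise whole transformed volumes. An illustrative usage fragment, assuming dataset is an instance of this HaN class and config.data_loader is the config block above:

# Illustrative only: inspect one batch from the loader returned by get_loader().
import torchio as tio

loader = dataset.get_loader(config.data_loader)  # assumed call site
batch = next(iter(loader))
images = batch['ct'][tio.DATA]     # shape (B, 1, X, Y, Z)
labels = batch['label'][tio.DATA]  # shape (B, 1, X, Y, Z), integer class IDs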
16 changes: 8 additions & 8 deletions datasets/label_dict.py
@@ -33,12 +33,12 @@
}

Anchor_dict = {
"background": 0,
"Bone_Mandible": 4,
"Brainstem": 5,
"Eye_AL": 12,
"Eye_AR": 13,
"Eye_PL": 14,
"Eye_PR": 15,
"SpinalCord": 30,
"background": 0, # 0
"Bone_Mandible": 4, # 1
"Brainstem": 5, # 2
"Eye_AL": 12, # 3
"Eye_AR": 13, # 4
"Eye_PL": 14, # 5
"Eye_PR": 15, # 6
"SpinalCord": 30, # 7
}
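Note: the added comments document that Anchor_dict maps the original HaN-Seg label IDs (0, 4, 5, 12, 13, 14, 15, 30) to the contiguous class indices 0-7 used for one-hot encoding. A small sketch of such a remapping with TorchIO's RemapLabels; whether the repository applies it this way is an assumption:

# Sketch: remap original HaN-Seg label IDs to contiguous class indices 0..7,
# following the order of Anchor_dict (insertion order is preserved in Python 3.7+).
import torchio as tio

remapping = {original_id: class_idx for class_idx, original_id in enumerate(Anchor_dict.values())}
# remapping == {0: 0, 4: 1, 5: 2, 12: 3, 13: 4, 14: 5, 15: 6, 30: 7}
remap_labels = tio.RemapLabels(remapping=remapping)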
23 changes: 18 additions & 5 deletions preproc.py
@@ -71,6 +71,7 @@ def timehash():
# main
patient_data_list = os.listdir(source_dir)
subjects = []
table = []
for patient in patient_data_list:
# generate labels
ct_data_path = os.path.join(source_dir, patient, patient + '_IMG_CT.nrrd')
@@ -86,9 +87,9 @@ def timehash():
subject_dict = {

'patient': patient,
'ct': tio.ScalarImage(ct_data_path, reader=_nrrd_reader),
# 'mr': tio.ScalarImage(mr_data_path, reader=self._nrrd_reader),
'label': tio.LabelMap(label_path, reader=_nrrd_reader),
'ct': tio.ScalarImage(ct_data_path),
# 'mr': tio.ScalarImage(mr_data_path),
'label': tio.LabelMap(label_path),
}

# preprocessing
@@ -98,11 +99,23 @@ def timehash():
os.makedirs(os.path.join(save_dir, patient), exist_ok=True)
ct_data_path = os.path.join(save_dir, patient, patient + '_IMG_CT.nrrd')
label_path = os.path.join(save_dir, patient, patient + f'_{preproc.experiment.name}.seg.nrrd')
nrrd.write(ct_data_path, transform_subject['ct'][tio.DATA].squeeze(0).numpy())
nrrd.write(label_path, transform_subject['label'][tio.DATA].squeeze(0).numpy())

transformed_ct_data = transform_subject['ct'][tio.DATA].squeeze(0).numpy()
transformed_label_data = transform_subject['label'][tio.DATA].squeeze(0).numpy()

nrrd.write(ct_data_path, transformed_ct_data)
nrrd.write(label_path, transformed_label_data)

print(f"Saved {patient} patients")
subjects.append(tio.Subject(**subject_dict))

if preproc.check_preprocessing:
# check preprocessing
info = [patient, transformed_ct_data.shape, transformed_label_data.shape, np.unique(transformed_label_data)]
table.append(info)

print(f"Completed {len(subjects)} patients")
df = pd.DataFrame(table, columns=['data_name', 'CT_shape', 'label_shape', 'label_unique'])
df.to_csv(os.path.join(save_dir, 'preprocessing.csv'), index=True)
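Note: the new check_preprocessing branch gathers one row per patient and writes preprocessing.csv with pandas, so import pandas as pd (and numpy as np) is assumed near the top of preproc.py. A hedged sanity check one could run on the gathered rows, assuming the preprocessed labels still carry the original HaN-Seg IDs listed in Anchor_dict (illustrative only, not part of the commit):

# Illustrative check of the rows collected in `table` above.
from datasets.label_dict import Anchor_dict  # assumed import path

allowed = set(Anchor_dict.values())  # {0, 4, 5, 12, 13, 14, 15, 30}
for data_name, ct_shape, label_shape, label_unique in table:
    assert ct_shape == label_shape, f"{data_name}: CT and label shapes differ"
    unexpected = set(label_unique.tolist()) - allowed
    assert not unexpected, f"{data_name}: unexpected label values {unexpected}"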


52 changes: 33 additions & 19 deletions utils/TrainFactory.py
@@ -88,23 +88,32 @@ def __init__(self, config, debug=False):

# load evaluator
self.evaluator = Evaluator(classes=num_classes)
if self.config.data_loader.patch_loader:
transform = tio.Compose([
tio.CropOrPad(self.config.data_loader.resize_shape, padding_mode=4),
# self.config.data_loader.preprocessing,
self.config.data_loader.augmentations])
else:
transform = tio.Compose([
tio.Resize(self.config.data_loader.resize_shape),
# self.config.data_loader.preprocessing,
self.config.data_loader.augmentations])

self.train_dataset = HaN(
config = self.config,
splits='train',
transform=tio.Compose([
# tio.CropOrPad(self.config.data_loader.resize_shape, padding_mode=0),
# self.config.data_loader.preprocessing,
self.config.data_loader.augmentations,
]),
transform=transform,
sampler=self.config.data_loader.sampler_type
)
self.val_dataset = HaN(
config = self.config,
transform=transform,
splits='val',
# transform=self.config.data_loader.preprocessing,
)
self.test_dataset = HaN(
config = self.config,
transform=transform,
splits='test',
# transform=self.config.data_loader.preprocessing,
)
@@ -158,6 +167,8 @@ def extract_data_from_feature(self, feature):

def train(self):

# self.num_classes

self.model.train()
self.evaluator.reset_eval()

@@ -170,11 +181,14 @@ def train(self):
with torch.cuda.amp.autocast():
preds = self.model(images) # pred shape B, C(N), H, W, D
preds_soft = F.softmax(preds, dim=1)
# already fails here: thread: [23,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size
gt_onehot = torch.nn.functional.one_hot(gt.squeeze().long(), num_classes=self.num_classes)
gt_onehot = gt_onehot.unsqueeze(0)
gt_onehot = torch.movedim(gt_onehot, -1, 1)
# already fails here: thread: [23,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size -> when using the patch-based loader
# gt shape B, C(N), H, W, D, where C(N) is the one-hot encoded class dimension
gt_onehot = F.one_hot(gt.squeeze(0).long(), num_classes=self.num_classes).permute(0, 4, 1, 2, 3)
assert preds_soft.ndim == gt_onehot.ndim, f'Gt and output dimensions are not the same before loss. {preds_soft.ndim} vs {gt_onehot.ndim}'
# ignore background
preds_soft = preds_soft[:, 1:, ...]
gt_onehot = gt_onehot[:, 1:, ...]


if self.loss.names[0] == 'Dice3DLoss':
loss, dice = self.loss.losses[self.loss.names[0]](preds_soft, gt_onehot)
@@ -232,19 +246,19 @@ def test(self, phase):
for i, d in tqdm(enumerate(data_loader), total=len(data_loader), desc=f'{phase} epoch {str(self.epoch)}'):
images, gt = self.extract_data_from_feature(d)

output = self.model(images)
output_soft = F.softmax(output, dim=1)
# 이미 여기서 thread: [23,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size
gt_onehot = torch.nn.functional.one_hot(gt.squeeze().long(), num_classes=self.num_classes)
gt_onehot = gt_onehot.unsqueeze(0)
gt_onehot = torch.movedim(gt_onehot, -1, 1)
assert output.ndim == gt.ndim, f'Gt and output dimensions are not the same before loss. {output.ndim} vs {gt.ndim}'
preds = self.model(images)
preds_soft = F.softmax(preds, dim=1)
gt_onehot = F.one_hot(gt.squeeze(0).long(), num_classes=self.num_classes).permute(0, 4, 1, 2, 3)
assert preds_soft.ndim == gt_onehot.ndim, f'Gt and output dimensions are not the same before loss. {preds_soft.ndim} vs {gt_onehot.ndim}'
# ignore background
preds_soft = preds_soft[:, 1:, ...]
gt_onehot = gt_onehot[:, 1:, ...]

if self.loss.names[0] == 'Dice3DLoss':
loss, dice = self.loss.losses[self.loss.names[0]](output_soft, gt_onehot)
loss, dice = self.loss.losses[self.loss.names[0]](preds_soft, gt_onehot)
else:
dice = compute_per_channel_dice(output_soft, gt_onehot)
loss = self.loss.losses[self.loss.names[0]](output_soft, gt_onehot).cuda(self.config.device)
dice = compute_per_channel_dice(preds_soft, gt_onehot)
loss = self.loss.losses[self.loss.names[0]](preds_soft, gt_onehot).cuda(self.config.device)
losses.append(loss.item())
# self.evaluator.compute_metrics(output, gt)
self.evaluator.add_dice(dice=dice)
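Note: the heart of the "update the loss and input shape" change is the ground-truth handling: gt of shape (1, 1, X, Y, Z) is squeezed to (1, X, Y, Z), one-hot encoded to (1, X, Y, Z, C), permuted to (1, C, X, Y, Z), and both prediction and target then drop the background channel before the Dice loss. A standalone sketch of that shape bookkeeping with dummy tensors (names and sizes are illustrative, not the repository's code):

# Standalone sketch of the new one-hot / background-drop logic.
import torch
import torch.nn.functional as F

num_classes = 8                       # background + 7 anchor structures
B, X, Y, Z = 1, 32, 32, 8             # small dummy volume; the real input is 288 x 288 x 64

preds = torch.randn(B, num_classes, X, Y, Z)           # model output: B, C, X, Y, Z
gt = torch.randint(0, num_classes, (B, 1, X, Y, Z))    # integer label map: B, 1, X, Y, Z

preds_soft = F.softmax(preds, dim=1)
gt_onehot = F.one_hot(gt.squeeze(0).long(), num_classes=num_classes).permute(0, 4, 1, 2, 3)
assert preds_soft.shape == gt_onehot.shape             # both (1, 8, 32, 32, 8)

# ignore the background channel, as in train() and test()
preds_soft = preds_soft[:, 1:, ...]
gt_onehot = gt_onehot[:, 1:, ...]
assert preds_soft.shape == (B, num_classes - 1, X, Y, Z)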
