Hi! I'm currently working on training a model using MONAI and PyTorch Lightning, and the model doesn't seem to generalize well. I am using a COVID-19 infection dataset with Coronacases and Radiopaedia volumes. You can find it here.

The HRCT transforms are applied to the "coronacases" volumes of the dataset, while the CBCT transforms are applied to the "radiopaedia" volumes. The val transforms are used for validation, while the others are used for training. The difference between them lies in the orientation of the volumes and in the intensity scaling.

```python
SPATIAL_SIZE = (64, 64, 64)
NUM_RAND_PATCHES = 16
LEVEL = -650
WIDTH = 1500
LOWER_BOUND_WINDOW_HRCT = LEVEL - (WIDTH // 2)
UPPER_BOUND_WINDOW_HRCT = LEVEL + (WIDTH // 2)
LOWER_BOUND_WINDOW_CBCT = 0
UPPER_BOUND_WINDOW_CBCT = 255
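# With LEVEL = -650 and WIDTH = 1500, the HRCT window spans [-1400, 100] HU (a lung window).
# The CBCT bounds are passed to ScaleIntensityd as minv/maxv below, so CBCT volumes are
# rescaled into the [0, 255] range rather than windowed in HU.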
def get_hrct_transforms():
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="PLI"),
            monai.transforms.RandCropByPosNegLabeld(keys=('img', 'mask'), label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),
            # monai.transforms.SpatialPadd(keys=('img', 'mask'), spatial_size=SPATIAL_SIZE, method='symmetric'),
            monai.transforms.ScaleIntensityRanged(keys=('img',), a_min=LOWER_BOUND_WINDOW_HRCT,
                                                  a_max=UPPER_BOUND_WINDOW_HRCT, b_min=0.0, b_max=1.0, clip=True),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )

def get_cbct_transforms():
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="ALI"),
            monai.transforms.RandCropByPosNegLabeld(keys=('img', 'mask'), label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),
            # monai.transforms.SpatialPadd(keys=('img', 'mask'), spatial_size=SPATIAL_SIZE, method='symmetric'),
            monai.transforms.ScaleIntensityd(keys='img', minv=LOWER_BOUND_WINDOW_CBCT, maxv=UPPER_BOUND_WINDOW_CBCT),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )

def get_val_hrct_transforms():
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="PLI"),
            monai.transforms.ScaleIntensityRanged(keys=('img',), a_min=LOWER_BOUND_WINDOW_HRCT,
                                                  a_max=UPPER_BOUND_WINDOW_HRCT, b_min=0.0, b_max=1.0,
                                                  clip=True),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )

def get_val_cbct_transforms():
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="ALI"),
            monai.transforms.ScaleIntensityd(keys='img', minv=LOWER_BOUND_WINDOW_CBCT, maxv=UPPER_BOUND_WINDOW_CBCT),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )
```

Additionally, I am using a UNet architecture, and I have tried training with the Dice and Generalized Dice loss functions. The optimizer is AdamW, and I am currently training for over 5000 epochs. The issue I'm facing is that, despite the training loss decreasing and the training Dice score increasing as expected, the validation loss doesn't decrease as much as I would like and the validation Dice score doesn't improve significantly.

```python
class Net(L.pytorch.LightningModule):
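    """LightningModule wrapping a 3D UNet for binary infection-mask segmentation.

    Training uses random positive/negative patches and a Generalized Dice loss;
    validation runs sliding-window inference, and Dice metrics exclude the background.
    """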
    def __init__(self):
        super(Net, self).__init__()
        self.save_hyperparameters()
        self.model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2
        )
        self.dice_metric = DiceMetric(include_background=False, reduction="mean")
        self.train_dice_metric = DiceMetric(include_background=False, reduction="mean")
        self.loss_function = monai.losses.GeneralizedDiceLoss(sigmoid=True, include_background=False)
        self.post_pred = monai.transforms.Compose(
            [monai.transforms.Activations(sigmoid=True), monai.transforms.AsDiscrete(threshold_values=0.5)])
        self.post_label = monai.transforms.Compose([monai.transforms.AsDiscrete(threshold_values=0.5)])
        self.best_val_dice = 0
        self.best_val_epoch = 0
        self.validation_step_outputs = []
        self.train_step_outputs = []
        self.training_ds = None
        self.validation_ds = None
        self.test_ds = None
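
    def forward(self, x):
        # Not shown in the original post; assuming a plain forward pass through the UNet,
        # since training_step and validation_step below call self.forward.
        return self.model(x)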

    def prepare_data(self) -> None:
        # Load images and masks
        logging.info(f"Loading images from {COVID_CASES_PATH}")
        images = load_images_from_path(COVID_CASES_PATH)
        labels = load_images_from_path(INFECTION_MASKS_PATH)

        # Convert images and masks to a list of dictionaries with keys "img" and "mask"
        data_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])
        logging.debug(data_dicts)

        shuffler = np.random.RandomState(SEED)
        shuffler.shuffle(data_dicts)
        data_dicts = list(data_dicts)

        # Split the data into training (70%), validation (10%), and test (20%) sets
        test_split = int(len(data_dicts) * 0.2)
        val_split = int(len(data_dicts) * 0.1)
        train_paths = data_dicts[test_split + val_split:]
        val_paths = data_dicts[test_split:test_split + val_split]
        test_paths = data_dicts[:test_split]

        # Define the CovidDataset instances for training, validation, and test
        self.training_ds = CovidDataset(volumes=train_paths, hrct_transform=get_hrct_transforms(),
                                        cbct_transform=get_cbct_transforms())
        self.validation_ds = CovidDataset(volumes=val_paths, hrct_transform=get_val_hrct_transforms(),
                                          cbct_transform=get_val_cbct_transforms())
        self.test_ds = CovidDataset(volumes=test_paths, hrct_transform=get_val_hrct_transforms(),
                                    cbct_transform=get_val_cbct_transforms())

    def train_dataloader(self):
        train_dataloader = DataLoader(self.training_ds, batch_size=1, num_workers=4)
        return train_dataloader

    def val_dataloader(self):
        val_dataloader = DataLoader(self.validation_ds, batch_size=1, num_workers=4)
        return val_dataloader

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-5)
        return optimizer

    def training_step(self, batch, batch_idx):
        inputs, labels = batch["img"], batch["mask"]
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.train_dice_metric(y_pred=outputs, y=labels)
        train_loss_dictionary = {"loss": loss}
        self.train_step_outputs.append(train_loss_dictionary)
        return train_loss_dictionary

    def on_train_epoch_end(self) -> None:
        train_loss = 0
        for output in self.train_step_outputs:
            train_loss += output["loss"].sum().item()
        mean_train_loss = torch.tensor(train_loss / len(self.train_step_outputs))  # Total loss of batches / number of batches
        mean_train_dice = self.train_dice_metric.aggregate().item()
        self.train_dice_metric.reset()
        self.log_dict({"train_dice": mean_train_dice, "train_loss": train_loss / len(self.train_step_outputs)}, prog_bar=True)
        tensorboard_logs = {
            "train_dice": mean_train_dice,
            "train_loss": mean_train_loss,
        }
        self.logger.experiment.add_scalars("losses", {"train": mean_train_loss}, self.current_epoch)
        self.logger.experiment.add_scalars("dice", {"train": mean_train_dice}, self.current_epoch)
        self.logger.log_metrics(tensorboard_logs, step=self.current_epoch)
        self.train_step_outputs.clear()

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch["img"], batch["mask"]
        roi_size = VALIDATION_INFERENCE_ROI_SIZE
        sw_batch_size = 4
        outputs = sliding_window_inference(inputs, roi_size, sw_batch_size, self.forward)
        loss = self.loss_function(outputs, labels)
        outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
        labels = [self.post_label(i) for i in decollate_batch(labels)]
        self.dice_metric(y_pred=outputs, y=labels)
        validation_loss_dictionary = {"loss": loss}
        self.validation_step_outputs.append(validation_loss_dictionary)
        return validation_loss_dictionary

    def on_validation_epoch_end(self) -> None:
        val_loss = 0
        for output in self.validation_step_outputs:
            val_loss += output["loss"].sum().item()
        mean_val_loss = torch.tensor(val_loss / len(self.validation_step_outputs))
        mean_val_dice = self.dice_metric.aggregate().item()
        self.dice_metric.reset()
        self.log_dict({"val_dice": mean_val_dice, "val_loss": val_loss / len(self.validation_step_outputs)}, prog_bar=True)
        tensorboard_logs = {
            "val_dice": mean_val_dice,
            "val_loss": mean_val_loss,
        }
        self.logger.experiment.add_scalars("losses", {"val_loss": mean_val_loss}, self.current_epoch)
        self.logger.experiment.add_scalars("dice", {"val_dice": mean_val_dice}, self.current_epoch)
        self.logger.log_metrics(tensorboard_logs, step=self.current_epoch)
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        self.validation_step_outputs.clear()
```

Here you can see some logs from the terminal so you can get a better picture of the training process:
Environment
I would appreciate any guidance on how to improve the generalization of my model. Are there specific strategies or best practices that I should consider? Any insights into potential issues with my current approach would also be very helpful. Thank you for your assistance! EDIT:
Hi @SrMateos, thanks for your interest here. Based on the situation described, where the training loss is decreasing and the training Dice score is improving, but the validation loss and Dice score are not showing similar improvement over 5000 epochs with a UNet architecture and Dice-based loss functions, here are a few suggestions:
Hope it helps, thanks.
Hi @KumoLiu,

It appears that the issue was with the `threshold_values` parameter in the `AsDiscrete` transform used in the `post_pred` transforms. That step was not discretizing the values correctly, which caused the Dice coefficient to malfunction and the model to produce inaccurate results. Thank you for your time and attention.
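
For reference, here is a minimal sketch of the corrected post-processing, assuming a recent MONAI release in which `AsDiscrete` takes a `threshold` argument (the older `threshold_values` / `logit_thresh` arguments were deprecated and later removed, so the exact spelling may vary with your MONAI version):

```python
import monai

# Post-processing for predictions: apply a sigmoid, then binarize at 0.5.
# Note: `threshold=0.5` replaces the deprecated `threshold_values` argument.
post_pred = monai.transforms.Compose([
    monai.transforms.Activations(sigmoid=True),
    monai.transforms.AsDiscrete(threshold=0.5),
])

# Post-processing for labels: binarize the mask so DiceMetric receives {0, 1} values.
post_label = monai.transforms.Compose([
    monai.transforms.AsDiscrete(threshold=0.5),
])
```

With this change the thresholding actually takes effect, so `DiceMetric` is fed binary tensors instead of raw sigmoid outputs.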