From 1edb516af5e017c3d203bec478f54600bf9b5ca9 Mon Sep 17 00:00:00 2001 From: Sylwia Majchrowska Date: Fri, 18 Mar 2022 15:44:16 +0100 Subject: [PATCH] initial commit: --- README.md | 60 +++++++ dataset.py | 150 ++++++++++++++++ models.py | 65 +++++++ multi_classification.py | 382 ++++++++++++++++++++++++++++++++++++++++ train.py | 238 +++++++++++++++++++++++++ 5 files changed, 895 insertions(+) create mode 100644 README.md create mode 100644 dataset.py create mode 100644 models.py create mode 100644 multi_classification.py create mode 100644 train.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..90682ea --- /dev/null +++ b/README.md @@ -0,0 +1,60 @@ +# Mutlimodality for skin lesions classification + +Many people worldwide suffer from skin diseases. For diagnosis, physicians often combine multiple information sources. These include, for instance, clinical images, microscopic images and meta-data such as the age and gender of the patient. Deep learning algorithms can support the classification of skin lesions by fusing all the information together and evaluating it. Several such algorithms are already being developed. However, to apply these learning algorithms in the clinic, they need to be further improved to achieve higher diagnostic accuracy. + +## Dataset + +Download the [ISIC 2020 dataset](https://www.kaggle.com/nroman/melanoma-external-malignant-256). +In the directory you will find: +- metadata as `train.csv` and `test.csv`, +- images for train and test subsets. + +## Training multimodal EfficientNet + +In its most basic form, training new networks boils down to: + +```.bash +python train.py --save-name efficientnetb2_256_20ep --data-dir ./melanoma_external_256/ --image-size 256 \ + --n-epochs 20 --enet-type efficientnet-b2 --CUDA_VISIBLE_DEVICES 0 +python train.py --save-name efficientnetb2_256_20ep_meta --data-dir ./melanoma_external_256/ --image-size 256 \ + --n-epochs 20 --enet-type efficientnet-b2 --CUDA_VISIBLE_DEVICES 0 --use-meta +``` + +The first command is uses only images during training; for the second one additional addition of avalilable metadata is done. + +## Training multilabel classifier + +We created a model with multiple binary heads to distinguish between different type of biases, such as ruler and black frame. +To use the model check `multi_classification.py` script. + +```.bash +python multi_classification.py --img_path ./melanoma_external_256/train/train --ann_path gans_biases.csv \ + --mode train --model_path multiclasificator_efficientnet-b2_GAN.pth + +python multi_classification.py --img_path ./melanoma_external_256/train/train --ann_path gans_biases.csv \ + --mode val --model_path multiclasificator_efficientnet-b2_GAN.pth + +python multi_classification.py --img_path ./melanoma_external_256/test/test --mode test \ + --model_path multiclasificator_efficientnet-b2_GAN.pth --save_path annotations.csv +``` + +We can distinguish between 3 modes: +- train: we need here provided annotations of biases for each image, +- val: we need here provided annotations of biases for each image and trained model, +- test: we need trained model to create new annotations for unseen images. + +## Creditentials + +This project based on code produced by [1st place on liderboard for Kaggle ISIC 2020 competition](https://www.kaggle.com/c/siim-isic-melanoma-classification/leaderboard). + +More details can be found here: + +https://github.com/haqishen/SIIM-ISIC-Melanoma-Classification-1st-Place-Solution + +https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/175412 + +http://arxiv.org/abs/2010.05351 + +## Acknowledgements + +The project was developed during the first rotation of the [Eye for AI Program](https://www.ai.se/en/eyeforai) at the AI Competence Center of [Sahlgrenska University Hospital](https://www.sahlgrenska.se/en/). Eye for AI initiative is a global program focused on bringing more international talents into the Swedish AI landscape. diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..37c3f24 --- /dev/null +++ b/dataset.py @@ -0,0 +1,150 @@ +import os +import cv2 +import numpy as np +import pandas as pd +import albumentations +import torch +from torch.utils.data import Dataset + +from tqdm import tqdm + + +class MelanomaDataset(Dataset): + def __init__(self, csv, mode, meta_features, transform=None): + + self.csv = csv.reset_index(drop=True) + self.mode = mode + self.use_meta = meta_features is not None + self.meta_features = meta_features + self.transform = transform + + def __len__(self): + return self.csv.shape[0] + + def __getitem__(self, index): + + row = self.csv.iloc[index] + + image = cv2.imread(row.filepath) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if self.transform is not None: + res = self.transform(image=image) + image = res['image'].astype(np.float32) + else: + image = image.astype(np.float32) + + image = image.transpose(2, 0, 1) + + if self.use_meta: + data = (torch.tensor(image).float(), torch.tensor(self.csv.iloc[index][self.meta_features]).float()) + else: + data = torch.tensor(image).float() + + if self.mode == 'test': + return data + else: + return data, torch.tensor(self.csv.iloc[index].target).long() + + +def get_transforms(image_size): + + transforms_train = albumentations.Compose([ + albumentations.Transpose(p=0.5), + albumentations.VerticalFlip(p=0.5), + albumentations.HorizontalFlip(p=0.5), + albumentations.RandomBrightness(limit=0.2, p=0.75), + albumentations.RandomContrast(limit=0.2, p=0.75), + albumentations.OneOf([ + albumentations.MotionBlur(blur_limit=5), + albumentations.MedianBlur(blur_limit=5), + albumentations.GaussianBlur(blur_limit=5), + albumentations.GaussNoise(var_limit=(5.0, 30.0)), + ], p=0.7), + + albumentations.OneOf([ + albumentations.OpticalDistortion(distort_limit=1.0), + albumentations.GridDistortion(num_steps=5, distort_limit=1.), + albumentations.ElasticTransform(alpha=3), + ], p=0.7), + + albumentations.CLAHE(clip_limit=4.0, p=0.7), + albumentations.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5), + albumentations.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85), + albumentations.Resize(image_size, image_size), + albumentations.Cutout(max_h_size=int(image_size * 0.375), max_w_size=int(image_size * 0.375), num_holes=1, p=0.7), + albumentations.Normalize() + ]) + + transforms_val = albumentations.Compose([ + albumentations.Resize(image_size, image_size), + albumentations.Normalize() + ]) + + return transforms_train, transforms_val + + +def get_meta_data(df_train, df_test): + df_train['sex'].fillna(df_train['sex'].mode()[0], inplace=True) + df_train['age_approx'].fillna(df_train['age_approx'].median(), inplace=True) + df_train['anatom_site_general_challenge'].fillna('unknown', inplace=True) + df_test['anatom_site_general_challenge'].fillna('unknown', inplace=True) + df_test['sex'].fillna(df_test['sex'].mode()[0], inplace=True) + df_test['age_approx'].fillna(df_test['age_approx'].median(), inplace=True) + + # One-hot encoding of anatom_site_general_challenge feature + concat = pd.concat([df_train['anatom_site_general_challenge'], df_test['anatom_site_general_challenge']], ignore_index=True) + dummies = pd.get_dummies(concat, dummy_na=True, dtype=np.uint8, prefix='site') + df_train = pd.concat([df_train, dummies.iloc[:df_train.shape[0]]], axis=1) + df_test = pd.concat([df_test, dummies.iloc[df_train.shape[0]:].reset_index(drop=True)], axis=1) + # Sex features + df_train['sex'] = df_train['sex'].map({'male': 1, 'female': 0}) + df_test['sex'] = df_test['sex'].map({'male': 1, 'female': 0}) + # Age features + df_train['age_approx'] /= 90 + df_test['age_approx'] /= 90 + # patient id + df_train['patient_id'] = df_train['patient_id'].fillna(0) + # n_image per user + df_train['n_images'] = df_train.patient_id.map(df_train.groupby(['patient_id']).image_name.count()) + df_test['n_images'] = df_test.patient_id.map(df_test.groupby(['patient_id']).image_name.count()) + df_train.loc[df_train['patient_id'] == -1, 'n_images'] = 1 + df_train['n_images'] = np.log1p(df_train['n_images'].values) + df_test['n_images'] = np.log1p(df_test['n_images'].values) + # image size + train_images = df_train['filepath'].values + train_sizes = np.zeros(train_images.shape[0]) + for i, img_path in enumerate(tqdm(train_images)): + train_sizes[i] = os.path.getsize(img_path) + df_train['image_size'] = np.log(train_sizes) + test_images = df_test['filepath'].values + test_sizes = np.zeros(test_images.shape[0]) + for i, img_path in enumerate(tqdm(test_images)): + test_sizes[i] = os.path.getsize(img_path) + df_test['image_size'] = np.log(test_sizes) + + meta_features = ['sex', 'age_approx', 'n_images', 'image_size'] + [col for col in df_train.columns if col.startswith('site_')] + n_meta_features = len(meta_features) + return df_train, df_test, meta_features, n_meta_features + + +def get_df(data_dir, use_meta): + + df_train = pd.read_csv(os.path.join(data_dir, 'train.csv')) + df_train['filepath'] = df_train['image_name'].apply(lambda x: os.path.join(data_dir, f'train/train', f'{x}.jpg')) + + df_train['is_ext'] = 0 + + # test data + df_test = pd.read_csv(os.path.join(data_dir, 'test.csv')) + df_test['filepath'] = df_test['image_name'].apply(lambda x: os.path.join(data_dir, f'test/test', f'{x}.jpg')) + + if use_meta: + df_train, df_test, meta_features, n_meta_features = get_meta_data(df_train, df_test) + else: + meta_features = None + n_meta_features = 0 + + # class mapping + mel_idx = 1 + return df_train, df_test, meta_features, n_meta_features, mel_idx diff --git a/models.py b/models.py new file mode 100644 index 0000000..a47edf6 --- /dev/null +++ b/models.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn +from efficientnet_pytorch import EfficientNet + + +sigmoid = nn.Sigmoid() + + +class Swish(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * sigmoid(i) + ctx.save_for_backward(i) + return result + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class Swish_Module(nn.Module): + def forward(self, x): + return Swish.apply(x) + + +class Effnet_Melanoma(nn.Module): + def __init__(self, enet_type, out_dim, n_meta_features=0, n_meta_dim=[512, 128]): + super(Effnet_Melanoma, self).__init__() + self.n_meta_features = n_meta_features + self.enet = EfficientNet.from_pretrained(enet_type) + self.dropouts = nn.ModuleList([ + nn.Dropout(0.5) for _ in range(5) + ]) + in_ch = self.enet._fc.in_features + if n_meta_features > 0: + self.meta = nn.Sequential( + nn.Linear(n_meta_features, n_meta_dim[0]), + nn.BatchNorm1d(n_meta_dim[0]), + Swish_Module(), + nn.Dropout(p=0.3), + nn.Linear(n_meta_dim[0], n_meta_dim[1]), + nn.BatchNorm1d(n_meta_dim[1]), + Swish_Module(), + ) + in_ch += n_meta_dim[1] + self.myfc = nn.Linear(in_ch, out_dim) + self.enet._fc = nn.Identity() + + def extract(self, x): + x = self.enet(x) + return x + + def forward(self, x, x_meta=None): + x = self.extract(x).squeeze(-1).squeeze(-1) + if self.n_meta_features > 0: + x_meta = self.meta(x_meta) + x = torch.cat((x, x_meta), dim=1) + for i, dropout in enumerate(self.dropouts): + if i == 0: + out = self.myfc(dropout(x)) + else: + out += self.myfc(dropout(x)) + out /= len(self.dropouts) + return out diff --git a/multi_classification.py b/multi_classification.py new file mode 100644 index 0000000..b914170 --- /dev/null +++ b/multi_classification.py @@ -0,0 +1,382 @@ +import argparse +import os +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +import torchvision +import numpy as np + +from PIL import * +from PIL import ImageFile +from PIL import Image +from efficientnet_pytorch import EfficientNet + +import wandb + +#System settings +ImageFile.LOAD_TRUNCATED_IMAGES = True +os.environ['WANDB_CONSOLE'] = 'off' +#Coloring for print outputs +class color: + RED = '\033[91m' + BOLD = '\033[1m' + END = '\033[0m' + +class Identity(nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + +class MultilabelClassifier(nn.Module): + def __init__(self): + super().__init__() + #self.resnet = models.resnet34(pretrained=True) + #self.model_wo_fc = nn.Sequential(*(list(self.resnet.children())[:-1])) + self.efficient_net = EfficientNet.from_pretrained(model_name="efficientnet-b2", num_classes=2) + #self.model_wo_fc = nn.Sequential(*(list(self.efficient_net.children())[:-1])) + inch = self.efficient_net._fc.in_features + self.hair_dense = nn.Sequential( + nn.Dropout(p=0.2), + nn.Linear(in_features=inch, out_features=2) + ) + self.hair_short = nn.Sequential( + nn.Dropout(p=0.2), + nn.Linear(in_features=inch, out_features=2) + ) + self.hair_medium = nn.Sequential( + nn.Dropout(p=0.2), + nn.Linear(in_features=inch, out_features=2) + ) + self.black_frame = nn.Sequential( + nn.Dropout(p=0.2), + nn.Linear(in_features=inch, out_features=2) + ) + self.ruler_mark = nn.Sequential( + nn.Dropout(p=0.2), + nn.Linear(in_features=inch, out_features=2) + ) + self.other = nn.Sequential( + nn.Dropout(p=0.2), + nn.Linear(in_features=inch, out_features=2) + ) + self.efficient_net._fc = Identity() + + def forward(self, x): + x = self.efficient_net(x) + + return { + "hair_dense": self.hair_dense(x), + "hair_short": self.hair_short(x), + "hair_medium": self.hair_medium(x), + "black_frame": self.black_frame(x), + "ruler_mark": self.ruler_mark(x), + "other": self.other(x) + } + +class BiasDataset(Dataset): + def __init__(self, root_path: str, annotationfile_path: str, transform=None, train=True): + self.path = root_path + self.train = train + self.transform = transform + if self.train: + self.annotationfile_path = annotationfile_path + self.folder = [ + x.strip().split()[0] for x in open(self.annotationfile_path) + ] + else: + included_extensions = ['jpg','jpeg', 'bmp', 'png', 'gif'] + self.folder = sorted([fn for fn in os.listdir(self.path) + if any(fn.endswith(ext) for ext in included_extensions)]) + + def __len__(self): + if self.train: + return len(self.folder) + else: + return len(os.listdir(self.path)) + + def __getitem__(self,idx): + if self.train: + img_loc = os.path.join(self.path, self.folder[idx].split(',')[0]) + translation_dict = [int(label) for label in self.folder[idx].split(',')[1:]] + + label1 = translation_dict[0] + label2 = translation_dict[1] + label3 = translation_dict[2] + label4 = translation_dict[3] + label5 = translation_dict[4] + label6 = translation_dict[5] + else: + img_loc = os.path.join(self.path, self.folder[idx]) + image = Image.open(img_loc).convert('RGB') + single_img = self.transform(image) + + if self.train: + return {'image':single_img, 'labels': {"label_hair_dense": label1, + "label_hair_short": label2, + "label_hair_medium": label3, + "label_black_frame": label4, + "label_ruler_mark": label5, + "label_other": label6 + } + } + else: + return {'image':single_img, 'name': self.folder[idx]} + +def criterion(loss_func,outputs,pictures): + losses = 0 + for _, key in enumerate(outputs): + losses += loss_func(outputs[key], pictures['labels'][f'label_{key}'].to(device)) + return losses + +def training(model, device, lr_rate,epochs, train_loader, wandb_flag=True): + num_epochs = epochs + losses = [] + checkpoint_losses = [] + + optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate) + n_total_steps = len(train_loader) + + loss_func = nn.CrossEntropyLoss() + + for epoch in range(num_epochs): + for i, pictures in enumerate(train_loader): + images = pictures['image'].to(device) + pictures = pictures + + outputs = model(images) + + loss = criterion(loss_func,outputs, pictures) + losses.append(loss.item()) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if (i+1) % (int(n_total_steps/1)) == 0: + checkpoint_loss = torch.tensor(losses).mean().item() + checkpoint_losses.append(checkpoint_loss) + print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {checkpoint_loss:.4f}') + if wandb_flag: + wandb.log({f'train/training_loss': checkpoint_loss, 'epoch':epoch+1}) + if (i+1) % (int(n_total_steps/1)) == 0: + n_correct,n_samples,n_class_correct,n_class_samples = validation(model, test_loader, len(images), + classes_hair_dense, classes_hair_short, classes_hair_medium, + classes_black_frame, classes_ruler_mark, classes_other) + class_acc(n_correct, n_samples, n_class_correct, n_class_samples, class_list, wandb_flag) + + return checkpoint_losses, optimizer + +def validation(model, dataloader, batch_size, *args): + + with torch.no_grad(): + n_correct = [] + n_class_correct = [] + n_class_samples = [] + n_samples = 0 + + for arg in args: + n_correct.append(len(arg)) + n_class_correct.append([0 for _ in range(len(arg))]) + n_class_samples.append([0 for _ in range(len(arg))]) + + for pictures in dataloader: + images = pictures['image'].to(device) + outputs = model(images) + labels = [pictures['labels'][picture].to(device) for picture in pictures['labels']] + + for i,out in enumerate(outputs): + _, predicted = torch.max(outputs[out],1) + n_correct[i] += (predicted == labels[i]).sum().item() + + if i == 0: + n_samples += labels[i].size(0) + for k in range(batch_size): + label = labels[i][k] + pred = predicted[k] + if (label == pred): + n_class_correct[i][label] += 1 + n_class_samples[i][label] += 1 + + return n_correct,n_samples,n_class_correct,n_class_samples + +def class_acc(n_correct,n_samples,n_class_correct,n_class_samples,class_list, wandb_flag=True): + for i in range(len(class_list)): + print("-------------------------------------------------") + acc = 100.0 * n_correct[i] / n_samples if n_samples != 0 else 0 + print(color.BOLD + color.RED + f'Overall class performance: {round(acc,1)} %' + color.END) + for k in range(len(class_list[i])): + acc = 100.0 * n_class_correct[i][k] / n_class_samples[i][k] if n_class_samples[i][k] != 0 else 0 + print(f'Accuracy of {class_list[i][k]}: {round(acc,1)} %') + if wandb_flag: + wandb.log({'val/Acc_'+class_list[i][k]: round(acc,1)}) + print("-------------------------------------------------") + +def test(model, dataloader, save_path): + file = open(save_path,"w") + with torch.no_grad(): + for pictures in dataloader: + images = pictures['image'].to(device) + outputs = model(images) + img_labels = [pictures['name']] + for out in outputs: + _, predicted = torch.max(outputs[out],1) + img_labels.append([str(j) for j in predicted.cpu().tolist()]) + file.writelines([','.join(line)+'\n' for line in list(zip(*img_labels))]) + file.close() + +def get_args_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--img_path", help="path to all images") + parser.add_argument( + "--ann_path", + type=str, + default=None, + help="path to annotations (default: None)", + ) + parser.add_argument( + "--mode", + default="val", + choices=["train", "test", "val"], + help="mode for proces which will be done", + ) + parser.add_argument( + "--ratio", type=float, default=0.8, help="train/test ratio (default: 0.8)" + ) + parser.add_argument( + "--lr", type=float, default=0.0001, help="learning rate (default: 1e-4)" + ) + parser.add_argument( + "--epochs", + type=int, + default=30, + metavar="EPOCHS", + help="number of epochs to train (default: 30)", + ) + parser.add_argument( + "--num_workers", type=int, default=4, help="number of workers (default: 4)" + ) + parser.add_argument( + "--batch", type=int, default=16, help="batch size (default: 16)" + ) + parser.add_argument("--model_path", help="path to save or read model", + default="multiclasificator_efficientnet-b2_uGAN.pth") + parser.add_argument("--save_path", help="path to save pseudoannotations", + default="annotations.csv") + parser.add_argument( + "--seed", type=int, default=2022, help="random seed (default: 2022)" + ) + + # wandb settings + parser.add_argument( + "--wandb_flag", + action="store_true", + default=False, + help="Launch experiment and log metrics with wandb", + ) + return parser + + +if __name__ == "__main__": + parser = get_args_parser() + args = parser.parse_args() + + # set the seed for reproducibility + seed = args.seed + torch.manual_seed(seed) + np.random.seed(seed) + + #Getting the data + DATA_DIR = args.img_path + + # model path to load trained model or to save model after training + PATH = args.model_path + + # choose mode beetwen val, train, test + mode = args.mode + label_flag = False if args.mode == 'test' else True + + # source annotation path + annotationfile_path = args.ann_path + + #Pre-processing transformations + data_transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize((256,256)), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) + ]) + + classes_hair_dense = ['None_hair_dense', 'Yes_hair_dense'] + classes_hair_short = ['None_hair_short', 'Yes_hair_short'] + classes_hair_medium = ['None_hair_medium', 'Yes_hair_medium'] + classes_black_frame = ['None_black_frame', 'Yes_black_frame'] + classes_ruler_mark = ['None_ruler_mark', 'Yes_ruler_mark'] + classes_other = ['None_other', 'Yes_other'] + header = [ + "hair_dense", + "hair_short", + "hair_medium", + "black_frame", + "ruler_mark", + "other" + ] + class_list = [classes_hair_dense,classes_hair_short,classes_hair_medium,classes_black_frame,classes_ruler_mark,classes_other] + + dataset = BiasDataset(root_path=DATA_DIR, + annotationfile_path=annotationfile_path, + transform=data_transforms, + train=label_flag) + + if mode == 'train': + #Split the data in training and testing + train_val_ratio = args.ratio + train_len = round(len(dataset) * train_val_ratio) + val_len = len(dataset) - train_len + train_set, val_set = torch.utils.data.random_split(dataset, [train_len, val_len]) + + #Create the dataloader for each dataset + train_loader = DataLoader(train_set, batch_size=args.batch, shuffle=True, + num_workers=args.num_workers, drop_last=True) + test_loader = DataLoader(val_set, batch_size=args.batch, shuffle=False, + num_workers=args.num_workers, drop_last=True) + if args.wandb_flag: + wandb.init(project="dai-healthcare", entity='eyeforai', group='cls_biases', + config={"model": "efficientnet-b2"}) + + # define model + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = MultilabelClassifier().to(device) + checkpoint_losses, optimizer = training(model, device, args.lr, args.epochs, train_loader, + wandb_flag=args.wandb_flag) + + torch.save({ + 'epoch': args.epochs, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': checkpoint_losses[-1], + }, PATH) + n_correct,n_samples,n_class_correct,n_class_samples = validation(model, test_loader, args.batch, + classes_hair_dense, classes_hair_short, classes_hair_medium, + classes_black_frame, classes_ruler_mark, classes_other) + class_acc(n_correct, n_samples, n_class_correct, n_class_samples, class_list, wandb_flag=args.wandb_flag) + else: + test_loader = DataLoader(dataset, batch_size=args.batch, shuffle=False, + num_workers=args.num_workers, drop_last=True) + # define model + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = MultilabelClassifier().to(device) + checkpoint = torch.load(PATH) + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + if mode == 'test': + # path to save annotations + SAVE_PATH = args.save_path + test(model, test_loader, SAVE_PATH) + elif mode == 'val': + n_correct,n_samples,n_class_correct,n_class_samples = validation(model,test_loader, args.batch, + classes_hair_dense, classes_hair_short, classes_hair_medium, + classes_black_frame, classes_ruler_mark, classes_other) + class_acc(n_correct, n_samples, n_class_correct, n_class_samples, class_list, wandb_flag=False) + else: + print("Wrong mode!") diff --git a/train.py b/train.py new file mode 100644 index 0000000..088948b --- /dev/null +++ b/train.py @@ -0,0 +1,238 @@ +import os +import time +import random +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from sklearn.metrics import roc_auc_score +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data.sampler import RandomSampler +from warmup_scheduler import GradualWarmupScheduler +from sklearn.model_selection import train_test_split +from dataset import get_df, get_transforms, MelanomaDataset +from models import Effnet_Melanoma + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--save-name', type=str, required=True) + parser.add_argument('--data-dir', type=str, default='.') + parser.add_argument('--image-size', type=int, required=True) + parser.add_argument('--enet-type', type=str, required=True) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--num-workers', type=int, default=32) + parser.add_argument('--init-lr', type=float, default=3e-5) + parser.add_argument('--out-dim', type=int, default=2) + parser.add_argument('--n-epochs', type=int, default=20) + parser.add_argument('--use-meta', action='store_true') + parser.add_argument('--model-dir', type=str, default='./weights') + parser.add_argument('--log-dir', type=str, default='./logs') + parser.add_argument('--CUDA_VISIBLE_DEVICES', type=str, default='0') + parser.add_argument('--n-meta-dim', type=str, default='512,128') + + args, _ = parser.parse_known_args() + return args + + +def set_seed(seed=0): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + + +class GradualWarmupSchedulerV2(GradualWarmupScheduler): + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): + super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler) + def get_lr(self): + if self.last_epoch > self.total_epoch: + if self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.finished = True + return self.after_scheduler.get_lr() + return [base_lr * self.multiplier for base_lr in self.base_lrs] + if self.multiplier == 1.0: + return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] + else: + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] + + +def train_epoch(model, loader, optimizer): + model.train() + train_loss = [] + bar = tqdm(loader) + for (data, target) in bar: + + optimizer.zero_grad() + + if args.use_meta: + data, meta = data + data, meta, target = data.to(device), meta.to(device), target.to(device) + logits = model(data, meta) + else: + data, target = data.to(device), target.to(device) + logits = model(data) + + loss = criterion(logits, target) + loss.backward() + + if args.image_size in [896,576]: + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optimizer.step() + + loss_np = loss.detach().cpu().numpy() + train_loss.append(loss_np) + smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100) + bar.set_description('loss: %.5f, smth: %.5f' % (loss_np, smooth_loss)) + + train_loss = np.mean(train_loss) + return train_loss + + +def get_trans(img, I): + if I >= 4: + img = img.transpose(2, 3) + if I % 4 == 0: + return img + elif I % 4 == 1: + return img.flip(2) + elif I % 4 == 2: + return img.flip(3) + elif I % 4 == 3: + return img.flip(2).flip(3) + + +def val_epoch(model, loader, mel_idx, n_test=1, get_output=False): + model.eval() + val_loss = [] + LOGITS = [] + PROBS = [] + TARGETS = [] + with torch.no_grad(): + for (data, target) in tqdm(loader): + + if args.use_meta: + data, meta = data + data, meta, target = data.to(device), meta.to(device), target.to(device) + logits = torch.zeros((data.shape[0], args.out_dim)).to(device) + probs = torch.zeros((data.shape[0], args.out_dim)).to(device) + for I in range(n_test): + l = model(get_trans(data, I), meta) + logits += l + probs += l.softmax(1) + else: + data, target = data.to(device), target.to(device) + logits = torch.zeros((data.shape[0], args.out_dim)).to(device) + probs = torch.zeros((data.shape[0], args.out_dim)).to(device) + for I in range(n_test): + l = model(get_trans(data, I)) + logits += l + probs += l.softmax(1) + logits /= n_test + probs /= n_test + + LOGITS.append(logits.detach().cpu()) + PROBS.append(probs.detach().cpu()) + TARGETS.append(target.detach().cpu()) + + loss = criterion(logits, target) + val_loss.append(loss.detach().cpu().numpy()) + + val_loss = np.mean(val_loss) + LOGITS = torch.cat(LOGITS).numpy() + PROBS = torch.cat(PROBS).numpy() + TARGETS = torch.cat(TARGETS).numpy() + + if get_output: + return LOGITS, PROBS + else: + acc = (PROBS.argmax(1) == TARGETS).mean() * 100. + auc = roc_auc_score((TARGETS == mel_idx).astype(float), PROBS[:, mel_idx]) + return val_loss, acc, auc + + +def run(df_train, df_valid, meta_features, n_meta_features, transforms_train, transforms_val, mel_idx): + + dataset_train = MelanomaDataset(df_train, 'train', meta_features, transform=transforms_train) + dataset_valid = MelanomaDataset(df_valid, 'valid', meta_features, transform=transforms_val) + train_loader = torch.utils.data.DataLoader( + dataset_train, batch_size=args.batch_size, + sampler=RandomSampler(dataset_train), num_workers=args.num_workers) + valid_loader = torch.utils.data.DataLoader( + dataset_valid, batch_size=args.batch_size, num_workers=args.num_workers) + + model = Effnet_Melanoma( + args.enet_type, + n_meta_features=n_meta_features, + n_meta_dim=[int(nd) for nd in args.n_meta_dim.split(',')], + out_dim=args.out_dim + ) + model = model.to(device) + + auc_max = 0. + model_file = os.path.join(args.model_dir, f'{args.save_name}_best.pth') + model_file2 = os.path.join(args.model_dir, f'{args.save_name}_final.pth') + + optimizer = optim.Adam(model.parameters(), lr=args.init_lr) + scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, args.n_epochs - 1) + scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=1, after_scheduler=scheduler_cosine) + + print(len(dataset_train), len(dataset_valid)) + + for epoch in range(1, args.n_epochs + 1): + print(time.ctime(), f'Epoch {epoch}') + + train_loss = train_epoch(model, train_loader, optimizer) + val_loss, acc, auc = val_epoch(model, valid_loader, mel_idx) + + content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, train loss: {train_loss:.5f}, valid loss: {(val_loss):.5f}, acc: {(acc):.4f}, auc: {(auc):.6f}.' + print(content) + with open(os.path.join(args.log_dir, f'log_{args.save_name}.txt'), 'a') as appender: + appender.write(content + '\n') + + scheduler_warmup.step() + if epoch==2: scheduler_warmup.step() # bug workaround + + if auc > auc_max: + print('auc_max ({:.6f} --> {:.6f}). Saving model ...'.format(auc_max, auc)) + torch.save(model.state_dict(), model_file) + auc_max = auc + + torch.save(model.state_dict(), model_file2) + + +def main(): + + df_train, df_test, meta_features, n_meta_features, mel_idx = get_df( + args.data_dir, + args.use_meta + ) + + transforms_train, transforms_val = get_transforms(args.image_size) + train_split, valid_split = train_test_split( + df_train, stratify=df_train.target, test_size=0.20, random_state=42) + df_train = pd.DataFrame(train_split) + df_valid = pd.DataFrame(valid_split) + + run(df_train, df_valid, meta_features, n_meta_features, transforms_train, transforms_val, mel_idx) + + +if __name__ == '__main__': + + args = parse_args() + os.makedirs(args.model_dir, exist_ok=True) + os.makedirs(args.log_dir, exist_ok=True) + os.environ['CUDA_VISIBLE_DEVICES'] = args.CUDA_VISIBLE_DEVICES + + set_seed() + + device = torch.device('cuda') + criterion = nn.CrossEntropyLoss() + + main()