initial commit of the branch SSL_comparison #205

Merged · 3 commits · Dec 13, 2024
123 changes: 123 additions & 0 deletions collections/_posts/2024-12-02-SSL_comparison.md
---
layout: review
title: "Systematic comparison of semi-supervised and self-supervised learning for medical image classification"
tags: classification, semi-supervision, self-supervision
author: "Juliette Moreau"
cite:
  authors: "Zhe Huang, Ruijie Jiang, Shuchin Aeron and Michael C. Hughes"
title: "Systematic comparison of semi-supervised and self-supervised learning for medical image classification"
venue: "CVPR, 2024"
pdf: "https://arxiv.org/pdf/2307.08919"
---


# Highlights

* Benchmark of SSL methods on medical images.
* Implementation of a realistic hyperparameter optimization protocol.
* Comparison of semi-supervised and self-supervised learning.

# Note

Code is available here: https://github.com/tufts-ml/SSL-vs-SSL-benchmark

# Introduction

In the medical domain, annotated data for training deep learning models is often scarce, while a large amount of unlabeled data exists in health records. Semi-supervised and self-supervised learning methods have been developed to leverage these unlabeled images and improve classification accuracy. The former trains classifiers jointly with two loss terms, while the latter follows a two-stage approach: first learning deep representations, then fine-tuning the classifier. However, the two families of methods are rarely compared. The first question the authors try to answer is: **Which recent semi- or self-supervised methods are likely to be most effective?** \

However, these methods are very sensitive to hyperparameters, and existing benchmarks rarely account for this: they either skip hyperparameter tuning altogether or tune on a huge labeled validation set, larger than the training set, which is not realistic. In this paper, the authors take this into account by answering the question: **Given limited available labeled data and limited compute, is hyperparameter tuning worthwhile?**


# Methods

## Definition of SSL

A unified perspective on the different learning paradigms is presented, based on two available datasets: a small labeled dataset $$L$$ and an unlabeled dataset $$U$$ with $$\lvert U \rvert \gg \lvert L \rvert$$. The learning process can be written as:

$$ v^*, w^* \leftarrow \underset{v,w}{\arg\min} \sum_{x,y \in L} \lambda^L \, l^L(y, g_w(f_v(x))) + \sum_{x \in U} \lambda^U \, l^U(x, f_v, g_w) $$

where $$l^L$$ and $$l^U$$ are the loss functions associated with the labeled and unlabeled datasets and $$\lambda^L$$ and $$\lambda^U$$ their respective weights. In addition, $$f_v(\cdot)$$ denotes a neural network backbone with parameters $$v$$ that produces an embedding of the input images, and $$g_w(\cdot)$$ denotes a final linear softmax classification layer with parameters $$w$$.


| Method | $$\lambda^L$$ | $$\lambda^U$$ |
|-----------------|---------------|---------------|
| Supervised | 1 | 0 |
| Semi-supervised | >0 | >0 |
| Self-supervised | 0 (pretraining), then 1 (fine-tuning) | 1 (pretraining), then 0 (fine-tuning) |
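
To make the notation concrete, here is a minimal PyTorch-style sketch of one training step under this unified objective. The names (`backbone`, `classifier`, `unlabeled_loss`) are placeholders rather than the paper's code, and the unlabeled loss is left abstract since it differs between methods.

```python
import torch.nn.functional as F

def ssl_objective(backbone, classifier, labeled_batch, unlabeled_batch,
                  unlabeled_loss, lambda_l=1.0, lambda_u=1.0):
    """One step of the unified objective lambda_L * l_L + lambda_U * l_U.

    `backbone` plays the role of f_v, `classifier` of g_w, and `unlabeled_loss`
    of any l_U (pseudo-labeling, consistency, contrastive, ...). Setting
    lambda_u = 0 recovers plain supervised training; running (0, 1) then (1, 0)
    mimics the self-supervised pretrain-then-fine-tune recipe.
    """
    x_l, y_l = labeled_batch
    logits = classifier(backbone(x_l))
    loss_l = F.cross_entropy(logits, y_l)                            # l^L
    loss_u = unlabeled_loss(unlabeled_batch, backbone, classifier)   # l^U
    return lambda_l * loss_l + lambda_u * loss_u
```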

## Compared methods

**Supervised learning**

Three supervised baselines are trained using only the labeled dataset $$(L)$$:

* Sup: a classifier trained with the classical multiclass cross-entropy loss.
* MixUp: also multiclass cross-entropy, but with mixup data augmentation, which creates new training samples by linearly combining two samples and their corresponding labels (see the sketch after this list).
* SupCon: uses a supervised contrastive learning loss (pulling together samples of the same class in the embedding space while pushing them away from samples of other classes).
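
As referenced above, here is a minimal sketch of the mixup augmentation; the variable names and the Beta parameter are illustrative, not taken from the paper.

```python
import torch
import torch.nn.functional as F

def mixup_batch(x, y, num_classes, alpha=0.2):
    """Mixup: build new samples as convex combinations of pairs of samples
    and of their one-hot labels."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    x_mixed = lam * x + (1.0 - lam) * x[perm]
    y_onehot = F.one_hot(y, num_classes).float()
    y_mixed = lam * y_onehot + (1.0 - lam) * y_onehot[perm]
    return x_mixed, y_mixed
```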


**Semi-supervised learning**

They selected 6 methods from recent years, covering a wide spectrum of strategies and computational costs, that take advantage of the unlabeled $$(U)$$ and labeled $$(L)$$ datasets jointly.

* Pseudo-label: assigns labels to unlabeled data based on the class with the highest predicted probability (see the sketch after this list).
* Mean teacher: the teacher weights are updated as an exponential moving average of the student weights after each sample, whether that sample is labeled or not.
* MixMatch: uses mixup data augmentation to combine labeled and unlabeled data and computes labels for these new input images.
* FixMatch: combines the pseudo-label method, which assigns a label to an unlabeled image when the prediction confidence is high enough, with strong data augmentation, whose prediction has to match the previously assigned label.
* FlexMatch: adds a curriculum notion to FixMatch to account for the different learning statuses and difficulties of the classes, essentially adjusting the confidence threshold used to assign a class to an unlabeled image.
* CoMatch: jointly learns class probabilities and low-dimensional embeddings; the embeddings impose a smoothness constraint that improves the pseudo-labels, while the pseudo-labels in turn regularize the structure of the embeddings through graph-based contrastive learning.
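
As referenced in the Pseudo-label item, here is a minimal sketch of confidence-thresholded pseudo-labeling in the FixMatch style; `x_weak` and `x_strong` stand for weakly and strongly augmented views of the same unlabeled images, and the threshold value is illustrative.

```python
import torch
import torch.nn.functional as F

def pseudo_label_loss(model, x_weak, x_strong, threshold=0.95):
    """The class predicted on the weakly augmented view becomes the target
    for the strongly augmented view, but only for confident predictions."""
    with torch.no_grad():
        probs = F.softmax(model(x_weak), dim=1)
        confidence, pseudo_labels = probs.max(dim=1)
        mask = (confidence >= threshold).float()
    per_sample = F.cross_entropy(model(x_strong), pseudo_labels, reduction="none")
    return (mask * per_sample).mean()
```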

**Self-supervised learning**

In addition, 7 self-supervised methods are compared:

* SimCLR: a contrastive self-supervised method based on data augmentation, using a learnable nonlinear transformation between the representation and the contrastive loss to improve the learned representations (see the sketch after this list).
* MoCo v2: short for momentum contrast; uses a dynamic dictionary with a queue and a moving-averaged encoder to facilitate contrastive unsupervised learning.
* SwAV: contrastive learning that does not compare features directly but compares different augmentations of the same image.
* BYOL: based on two networks (online and target) that learn from each other; the online network learns to predict the representation produced by the target network from a different augmentation of the same image.
* SimSiam: Siamese networks maximizing the similarity between two augmentations of one image, with one branch passing through an additional prediction head.
* DINO: named after self-distillation with no labels; several transformations of an image are passed through student and teacher networks, the similarity of the extracted features is used as the loss, and the teacher weights are again updated as an exponential moving average of the student parameters.
* Barlow Twins: introduces an objective function that pushes the cross-correlation matrix between the outputs of two identical networks, fed with distorted versions of the same image, as close as possible to the identity matrix, reducing redundancy between the components of the representation vectors.
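
As referenced in the SimCLR item, here is a minimal sketch of the NT-Xent contrastive loss over two augmented views of a batch; the encoder, projection head and augmentations are omitted, and the temperature value is illustrative.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    """SimCLR-style contrastive loss: for each embedding, the other augmented
    view of the same image is the positive, every other embedding in the batch
    is a negative."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)        # (2N, d), unit norm
    sim = z @ z.t() / temperature                              # cosine similarities
    sim.masked_fill_(torch.eye(2 * n, dtype=torch.bool, device=z.device),
                     float("-inf"))                            # drop self-similarity
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```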

## Datasets and training parameters

Four datasets with different characteristics are used to compare the 16 methods.

| Dataset     | Image type                  | Number of classes | Resolution |
|-------------|-----------------------------|-------------------|------------|
| TissueMNIST | Kidney cortex cells         | 8                 | 28×28      |
| PathMNIST   | Colorectal cancer histology | 9                 | 28×28      |
| TMED-2      | Echocardiogram US images    | 4                 | 112×112    |
| AIROGS      | Retina color fundus         | 2                 | 384×384    |

The details of the image distribution are given in the following table, with the number of labeled and unlabeled images per category. The first two datasets are entirely labeled, so part of the images were treated as unlabeled to fit the SSL training paradigm.

![](/collections/images/SSL_comparison/datasets_description.jpg)

Here are some example images.

![](/collections/images/SSL_comparison/datasets_example.jpg)

# Experimental setup

A ResNet-18 is used for the low-resolution images, a WideResNet-28-2 for TMED-2, and both ResNet-18 and ResNet-50 are tested on the AIROGS dataset. \
A maximum of 200 epochs is performed and training is stopped if the balanced accuracy plateaus for 20 consecutive epochs.

Hyperparameter optimization (learning rate, weight decay and unlabeled loss weight, in addition to the parameters specific to each method) is performed for each method and dataset. To fit real-life constraints, a fixed number of hours on an NVIDIA A100 GPU is allotted to each hyperparameter optimization (25h for PathMNIST, 50h for TissueMNIST and 100h for TMED-2 and AIROGS). Within this fixed budget, a random search over hyperparameters is performed, taking the best configuration according to the validation set.
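
A minimal sketch of such a time-budgeted random search could look as follows; `sample_config` and `train_and_eval` are placeholders for drawing a random configuration and for training a model and returning its validation balanced accuracy.

```python
import time

def budgeted_random_search(sample_config, train_and_eval, budget_hours):
    """Random hyperparameter search under a fixed wall-clock budget, keeping
    the configuration with the best validation balanced accuracy."""
    deadline = time.time() + budget_hours * 3600
    best_score, best_config = float("-inf"), None
    while time.time() < deadline:
        config = sample_config()          # e.g. log-uniform lr, weight decay, lambda_U
        score = train_and_eval(config)    # trains one model, returns val balanced accuracy
        if score > best_score:
            best_score, best_config = score, config
    return best_config, best_score
```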


# Results

Five separate trials were performed, and the mean balanced accuracy is computed every 30 or 60 minutes on the validation and test sets.
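
For reference, the balanced accuracy used throughout is the mean of the per-class recalls, which accounts for the class imbalance of these datasets:

$$ \text{balanced accuracy} = \frac{1}{C} \sum_{c=1}^{C} \frac{TP_c}{TP_c + FN_c} $$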

The balanced accuracy on the test set over time is shown for each dataset and method. As a roughly monotonic improvement in test performance is observed over time despite using a realistically sized validation set, the authors conclude that checkpoint selection and hyperparameter optimization can be effective with such a validation set.


![](/collections/images/SSL_comparison/accuracy_evolution.jpg)

All the methods are compared, and none of them clearly stands out from the others; the best method depends on the dataset. The authors therefore measure the relative gain in balanced accuracy over the best supervised method for each dataset. MixMatch represents the best overall choice.


![](/collections/images/SSL_comparison/results_comparison.jpg)

The tuning strategy, even though it reserves data for a validation set, is competitive with transferring hyperparameters from another dataset (natural or medical images), as may have been done in other benchmark papers.

![](/collections/images/SSL_comparison/hyperparameters_strategy_evaluation.jpg)


# Discussion

SSL provides a real performance gain, but it depends on the dataset (a smaller improvement is observed for TMED-2). However, the authors insist on the importance of hyperparameter selection and of a proper evaluation protocol adapted to specific needs. In particular, MixMatch requires careful hyperparameter tuning when applied to a new dataset.
