Add All4One #382

Merged
merged 11 commits into from
Jan 8, 2024
4 changes: 4 additions & 0 deletions README.md
@@ -59,6 +59,7 @@ The library is self-contained, but it is possible to use the models outside of s
---

## Methods available
* [All4One](https://openaccess.thecvf.com/content/ICCV2023/html/Estepa_All4One_Symbiotic_Neighbour_Contrastive_Learning_via_Self-Attention_and_Redundancy_Reduction_ICCV_2023_paper.html)
* [Barlow Twins](https://arxiv.org/abs/2103.03230)
* [BYOL](https://arxiv.org/abs/2006.07733)
* [DeepCluster V2](https://arxiv.org/abs/2006.09882)
@@ -216,6 +217,7 @@ All pretrained models available can be downloaded directly via the tables below o

| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint |
|--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:|
| All4One | ResNet18 | 1000 | :x: | 93.24 | 99.88 | [:link:](https://drive.google.com/drive/folders/1dtYmZiftruQ7B2PQ8fo44wguCZ0eSzAd?usp=sharing) |
| Barlow Twins | ResNet18 | 1000 | :x: | 92.10 | 99.73 | [:link:](https://drive.google.com/drive/folders/1L5RAM3lCSViD2zEqLtC-GQKVw6mxtxJ_?usp=sharing) |
| BYOL | ResNet18 | 1000 | :x: | 92.58 | 99.79 | [:link:](https://drive.google.com/drive/folders/1KxeYAEE7Ev9kdFFhXWkPZhG-ya3_UwGP?usp=sharing) |
|DeepCluster V2| ResNet18 | 1000 | :x: | 88.85 | 99.58 | [:link:](https://drive.google.com/drive/folders/1tkEbiDQ38vZaQUsT6_vEpxbDxSUAGwF-?usp=sharing) |
@@ -237,6 +239,7 @@ All pretrained models available can be downloaded directly via the tables below o

| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint |
|--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:|
| All4One | ResNet18 | 1000 | :x: | 72.17 | 93.35 | [:link:](https://drive.google.com/drive/folders/1oQcC80XPr-Wxhjs-PEqD_8VhUa_izqeZ?usp=sharing) |
| Barlow Twins | ResNet18 | 1000 | :x: | 70.90 | 91.91 | [:link:](https://drive.google.com/drive/folders/1hDLSApF3zSMAKco1Ck4DMjyNxhsIR2yq?usp=sharing) |
| BYOL | ResNet18 | 1000 | :x: | 70.46 | 91.96 | [:link:](https://drive.google.com/drive/folders/1hwsEdsfsUulD2tAwa4epKK9pkSuvFv6m?usp=sharing) |
|DeepCluster V2| ResNet18 | 1000 | :x: | 63.61 | 88.09 | [:link:](https://drive.google.com/drive/folders/1gAKyMz41mvGh1BBOYdc_xu6JPSkKlWqK?usp=sharing) |
@@ -257,6 +260,7 @@ All pretrained models available can be downloaded directly via the tables below o

| Method | Backbone | Epochs | Dali | Acc@1 (online) | Acc@1 (offline) | Acc@5 (online) | Acc@5 (offline) | Checkpoint |
|-------------------------|:--------:|:------:|:------------------:|:--------------:|:---------------:|:--------------:|:---------------:|:----------:|
| All4One | ResNet18 | 400 | :heavy_check_mark: | 81.93 | - | 96.23 | - | [:link:](https://drive.google.com/drive/folders/1bJCRLP5Rz_JEylNq9C4sY3ccYZSchUGR?usp=sharing) |
| Barlow Twins :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.38 | 80.16 | 95.28 | 95.14 | [:link:](https://drive.google.com/drive/folders/1rj8RbER9E71mBlCHIZEIhKPUFn437D5O?usp=sharing) |
| BYOL :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.16 | 80.32 | 95.02 | 94.94 | [:link:](https://drive.google.com/drive/folders/1riOLjMawD_znO4HYj8LBN2e1X4jXpDE1?usp=sharing) |
| DeepCluster V2 | ResNet18 | 400 | :x: | 75.36 | 75.4 | 93.22 | 93.10 | [:link:](https://drive.google.com/drive/folders/1d5jPuavrQ7lMlQZn5m2KnN5sPMGhHFo8?usp=sharing) |
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -63,6 +63,7 @@ While the library is self contained, it is possible to use the models outside of

   solo/methods/base
   solo/methods/linear
   solo/methods/all4one
   solo/methods/barlow
   solo/methods/byol
   solo/methods/deepclusterv2
48 changes: 48 additions & 0 deletions docs/source/solo/methods/all4one.rst
@@ -0,0 +1,48 @@
All4One
=======

.. automethod:: solo.methods.all4one.All4One.__init__
   :noindex:


add_model_specific_args
~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.add_model_specific_args
   :noindex:

learnable_params
~~~~~~~~~~~~~~~~
.. autoattribute:: solo.methods.all4one.All4One.learnable_params
   :noindex:

dequeue_and_enqueue
~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.dequeue_and_enqueue
   :noindex:

find_nn
~~~~~~~
.. automethod:: solo.methods.all4one.All4One.find_nn
   :noindex:

off_diagonal
~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.off_diagonal
   :noindex:


save_NN
~~~~~~~
.. automethod:: solo.methods.all4one.All4One.save_NN
   :noindex:


forward
~~~~~~~
.. automethod:: solo.methods.all4one.All4One.forward
   :noindex:

training_step
~~~~~~~~~~~~~
.. automethod:: solo.methods.all4one.All4One.training_step
   :noindex:
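
The documented method names `find_nn` and `off_diagonal` suggest a nearest-neighbour lookup in a support queue and an extraction of off-diagonal matrix entries, as commonly used in neighbour-contrastive and redundancy-reduction losses. The following is a minimal pure-Python sketch under that assumption; the bodies and signatures are hypothetical illustrations, not the code added by this PR (which operates on tensors and is not shown in this diff).

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def find_nn(z, queue):
    """Return (index, vector) of the queue entry most similar to z.

    Similarity is the dot product of L2-normalised vectors, i.e. cosine
    similarity. Hypothetical signature, for illustration only.
    """
    zn = l2_normalize(z)
    sims = [sum(a * b for a, b in zip(zn, l2_normalize(q))) for q in queue]
    idx = max(range(len(queue)), key=sims.__getitem__)
    return idx, queue[idx]


def off_diagonal(matrix):
    """Flatten every element of a square matrix except the main diagonal."""
    n = len(matrix)
    assert all(len(row) == n for row in matrix), "matrix must be square"
    return [matrix[i][j] for i in range(n) for j in range(n) if i != j]


# A tiny support queue of three 2-D embeddings.
queue = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
idx, nn = find_nn([0.9, 0.8], queue)
print(idx, nn)
print(off_diagonal([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
```

In a redundancy-reduction term, the off-diagonal entries of a cross-correlation matrix are the ones pushed towards zero, which is why a helper of this shape tends to appear next to the loss.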
100 changes: 100 additions & 0 deletions docs/source/solo/utils.rst
@@ -158,3 +158,103 @@ Whitening

.. automethod:: solo.utils.whitening.Whitening2d.__init__
   :noindex:


PositionalEncoding1D
--------------------
:class:`PositionalEncoding1D` applies positional encoding to the last dimension of a 3D tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding1D.__init__
   :noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding1D.forward
   :noindex:

PositionalEncodingPermute1D
---------------------------
:class:`PositionalEncodingPermute1D` permutes the input tensor and applies 1D positional encoding.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute1D.__init__
   :noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute1D.forward
   :noindex:

PositionalEncoding2D
--------------------
:class:`PositionalEncoding2D` applies positional encoding to the last two dimensions of a 4D tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding2D.__init__
   :noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding2D.forward
   :noindex:

PositionalEncodingPermute2D
---------------------------
:class:`PositionalEncodingPermute2D` permutes the input tensor and applies 2D positional encoding.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute2D.__init__
   :noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute2D.forward
   :noindex:

PositionalEncoding3D
--------------------
:class:`PositionalEncoding3D` applies positional encoding to the last three dimensions of a 5D tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding3D.__init__
   :noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncoding3D.forward
   :noindex:

PositionalEncodingPermute3D
---------------------------
:class:`PositionalEncodingPermute3D` permutes the input tensor and applies 3D positional encoding.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute3D.__init__
   :noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.PositionalEncodingPermute3D.forward
   :noindex:

Summer
------
:class:`Summer` adds positional encoding to the original tensor.

__init__
~~~~~~~~
.. automethod:: solo.utils.positional_encoding.Summer.__init__
   :noindex:

forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.Summer.forward
   :noindex:
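
To make the relationship between a positional-encoding class and `Summer` concrete, here is a small pure-Python sketch of the standard sinusoidal encoding for a single sequence, followed by the "add it to the input" step. The function names, the interleaved sin/cos channel layout, and the list-of-lists representation are assumptions for illustration; the classes documented above work on batched tensors and may arrange channels differently.

```python
import math


def positional_encoding_1d(length, channels):
    """Build a (length x channels) table of sinusoidal position codes.

    Even channels hold sin, odd channels hold cos, and each sin/cos pair
    shares the frequency 1 / 10000 ** (2k / channels).
    """
    table = []
    for pos in range(length):
        row = []
        for ch in range(channels):
            # ch - ch % 2 maps channels (0, 1) -> 0, (2, 3) -> 2, and so on.
            freq = 1.0 / (10000 ** ((ch - ch % 2) / channels))
            angle = pos * freq
            row.append(math.sin(angle) if ch % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def summer(x):
    """Add the positional table to x, a (length x channels) list of lists."""
    enc = positional_encoding_1d(len(x), len(x[0]))
    return [[a + b for a, b in zip(x_row, e_row)] for x_row, e_row in zip(x, enc)]


# Three positions, four channels, all-zero input so the output is the encoding.
x = [[0.0] * 4 for _ in range(3)]
out = summer(x)
print(out[0])  # position 0: sin(0) = 0 and cos(0) = 1 alternate
```

Keeping the encoding and the addition in separate pieces mirrors the split between the `PositionalEncoding*` classes and `Summer`: the former only produce the codes, the latter decides how they are combined with the features.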

58 changes: 58 additions & 0 deletions scripts/pretrain/cifar/all4one.yaml
@@ -0,0 +1,58 @@
defaults:
  - _self_
  - augmentations: asymmetric.yaml
  - wandb: private.yaml
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# disable hydra outputs
hydra:
  output_subdir: null
  run:
    dir: .

name: "All4One-cifar100" # change here for cifar10
method: "all4one"
backbone:
  name: "resnet18"
method_kwargs:
  temperature: 0.2
  proj_hidden_dim: 2048
  pred_hidden_dim: 4096
  proj_output_dim: 256
  queue_size: 98304
momentum:
  base_tau: 0.99
  final_tau: 1.0
data:
  dataset: cifar100 # change here for cifar10
  train_path: "./datasets/"
  val_path: "./datasets/"
  format: "image_folder"
  num_workers: 4
optimizer:
  name: "lars"
  batch_size: 256
  lr: 1.0
  classifier_lr: 0.1
  weight_decay: 1e-5
  kwargs:
    clip_lr: True
    eta: 0.02
    exclude_bias_n_norm: True
scheduler:
  name: "warmup_cosine"
checkpoint:
  enabled: True
  dir: "trained_models"
  frequency: 1
auto_resume:
  enabled: False

# overwrite PL stuff
max_epochs: 1000
devices: [0]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
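
The `momentum` block in this config sets `base_tau: 0.99` and `final_tau: 1.0`. The config does not say how the coefficient moves between the two; the sketch below assumes a cosine schedule over training steps, which is a common choice for momentum encoders. The function name and the step counts are illustrative only.

```python
import math


def cosine_tau(step, max_steps, base_tau=0.99, final_tau=1.0):
    """Cosine interpolation from base_tau (step 0) to final_tau (last step)."""
    # progress goes from 1 at step 0 down to 0 at step == max_steps.
    progress = (math.cos(math.pi * step / max_steps) + 1) / 2
    return final_tau - (final_tau - base_tau) * progress


for step in (0, 500, 1000):
    print(step, round(cosine_tau(step, 1000), 4))
```

With these values the coefficient starts at 0.99, passes 0.995 halfway through, and reaches 1.0 at the end, at which point the momentum network would stop changing.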
55 changes: 55 additions & 0 deletions scripts/pretrain/imagenet-100/all4one.yml
@@ -0,0 +1,55 @@
defaults:
  - _self_
  - augmentations: asymmetric.yaml
  - wandb: private.yaml
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# disable hydra outputs
hydra:
  output_subdir: null
  run:
    dir: .

name: "all4one-imagenet100"
method: "all4one"
backbone:
  name: "resnet18"
method_kwargs:
  temperature: 0.2
  proj_hidden_dim: 2048
  pred_hidden_dim: 4096
  proj_output_dim: 256
  queue_size: 98340
data:
  dataset: imagenet100
  train_path: "./datasets/imagenet-100/train"
  val_path: "./datasets/imagenet-100/val"
  format: "dali"
  num_workers: 4
optimizer:
  name: "lars"
  batch_size: 128
  lr: 1.0
  classifier_lr: 0.1
  weight_decay: 1e-5
  kwargs:
    clip_lr: True
    eta: 0.02
    exclude_bias_n_norm: True
scheduler:
  name: "warmup_cosine"
checkpoint:
  enabled: True
  dir: "trained_models"
  frequency: 1
auto_resume:
  enabled: True

# overwrite PL stuff
max_epochs: 400
devices: [0, 1]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
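
The two configs use slightly different queue sizes (98304 for CIFAR, 98340 for ImageNet-100). If the support queue is refreshed in whole batches gathered across devices, some queue implementations require the queue size to be a multiple of the effective batch size. Whether All4One enforces this is not visible in this diff, so the check below is only a sketch under that assumption; the helper name is hypothetical.

```python
def queue_fits_batches(queue_size, batch_size, num_devices=1):
    """True when whole (gathered) batches tile the queue exactly."""
    effective_batch = batch_size * num_devices
    return queue_size % effective_batch == 0


# Values taken from the two configs in this PR.
print(queue_fits_batches(98304, 256, num_devices=1))  # CIFAR config
print(queue_fits_batches(98340, 128, num_devices=2))  # ImageNet-100 config
```

Under that assumption the CIFAR values divide evenly, while 98340 leaves a remainder of 36 for a batch of either 128 or 256, so it may be worth confirming whether 98304 was intended in the ImageNet-100 config.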
4 changes: 4 additions & 0 deletions solo/methods/__init__.py
@@ -37,6 +37,8 @@
from solo.methods.vibcreg import VIbCReg
from solo.methods.vicreg import VICReg
from solo.methods.wmse import WMSE
from solo.methods.all4one import All4One


METHODS = {
    # base classes
@@ -61,6 +63,7 @@
    "vibcreg": VIbCReg,
    "vicreg": VICReg,
    "wmse": WMSE,
    "all4one": All4One,
}
__all__ = [
    "BarlowTwins",
@@ -83,4 +86,5 @@
    "VIbCReg",
    "VICReg",
    "WMSE",
    "All4One",
]
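
Registering the class under the `"all4one"` key is what lets the `method: "all4one"` entry in the YAML configs resolve to the new class. The sketch below shows that lookup pattern with stand-in objects so it runs on its own; `get_method` is a hypothetical helper, not a function from the library.

```python
class All4One:
    """Stand-in for solo.methods.all4one.All4One (the real class is a model)."""


# Stand-in for the METHODS dict in solo/methods/__init__.py.
METHODS = {"all4one": All4One}


def get_method(name):
    """Look up a method class by its config name, with a readable error."""
    try:
        return METHODS[name]
    except KeyError:
        raise ValueError(
            f"unknown method {name!r}; choose from {sorted(METHODS)}"
        ) from None


print(get_method("all4one").__name__)
```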