Merge pull request #63 from qbouniot/dev_mixup
✨ Mixup variants + Cross validation + Temperature scaling in routines
alafage authored Nov 3, 2023
2 parents f7e4c93 + ffa2a37 commit 97695b9
Showing 53 changed files with 1,940 additions and 560 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -38,8 +38,8 @@ repos:
         language: python
         types_or: [python, pyi]
         exclude: ^auto_tutorials_source/
-      - id: pytest-check
-        name: pytest-check
-        entry: pytest
-        language: system
-        pass_filenames: false
+      # - id: pytest-check
+      #   name: pytest-check
+      #   entry: pytest
+      #   language: system
+      #   pass_filenames: false
16 changes: 11 additions & 5 deletions README.md
@@ -51,20 +51,26 @@ A quickstart is available at [torch-uncertainty.github.io/quickstart](https://to
 To date, the following deep learning baselines have been implemented:
 
 - Deep Ensembles
-- MC-Dropout
+- MC-Dropout - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html)
 - BatchEnsemble
 - Masksembles
 - MIMO
-- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873))
-- Bayesian Neural Networks :construction: Work in progress :construction:
+- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html)
+- Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html)
 - Regression with Beta Gaussian NLL Loss
-- Deep Evidential Classification & Regression
+- Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html)
 
+### Augmentation methods
+
+The following data augmentation methods have been implemented:
+
+- Mixup, MixupIO, RegMixup, WarpingMixup
+
 ### Post-processing methods
 
 To date, the following post-processing methods have been implemented:
 
-- Temperature, Vector, & Matrix scaling
+- Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html)
 
 ## Tutorials
 
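The augmentation methods listed above are the headline feature of this pull request. For background, here is a minimal PyTorch sketch of plain Mixup (Zhang et al., ICLR 2018, cited in the references below). It illustrates the technique only and is not torch-uncertainty's implementation; the batch shapes and the `alpha` value are illustrative assumptions.

```python
import torch

def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    """Mix a batch with a shuffled copy of itself.

    x: inputs of shape (B, ...); y: one-hot targets of shape (B, C).
    lam ~ Beta(alpha, alpha) controls the interpolation strength.
    """
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(x.size(0))           # pair each sample with a random partner
    mixed_x = lam * x + (1.0 - lam) * x[index]  # convex combination of inputs
    mixed_y = lam * y + (1.0 - lam) * y[index]  # same combination of the labels
    return mixed_x, mixed_y

# Usage with a dummy CIFAR-sized batch and 10-class one-hot labels.
x = torch.randn(8, 3, 32, 32)
y = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), 10).float()
mixed_x, mixed_y = mixup(x, y, alpha=0.2)
```

RegMixup, MixupIO, and WarpingMixup vary where and how this interpolation is applied; see the references added to docs/source/references.rst below.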
33 changes: 33 additions & 0 deletions docs/source/references.rst
@@ -114,6 +114,39 @@ For Monte-Carlo Dropout, consider citing:
 * Authors: *Yarin Gal and Zoubin Ghahramani*
 * Paper: `ICML 2016 <https://arxiv.org/pdf/1506.02142.pdf>`__.
 
+Data Augmentation Methods
+-------------------------
+
+Mixup
+^^^^^
+
+For Mixup, consider citing:
+
+**mixup: Beyond Empirical Risk Minimization**
+
+* Authors: *Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz*
+* Paper: `ICLR 2018 <https://arxiv.org/pdf/1710.09412.pdf>`__.
+
+MixupIO
+^^^^^^^
+
+For MixupIO, consider citing:
+
+**On the Pitfall of Mixup for Uncertainty Calibration**
+
+* Authors: *Deng-Bao Wang, Lanqing Li, Peilin Zhao, Pheng-Ann Heng, and Min-Ling Zhang*
+* Paper: `CVPR 2023 <https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_On_the_Pitfall_of_Mixup_for_Uncertainty_Calibration_CVPR_2023_paper.pdf>`__.
+
+Warping Mixup
+^^^^^^^^^^^^^
+
+For Warping Mixup, consider citing:
+
+**Tailoring Mixup to Data using Kernel Warping functions**
+
+* Authors: *Quentin Bouniot, Pavlo Mozharovskyi, and Florence d'Alché-Buc*
+* Paper: `ArXiv 2023 <https://arxiv.org/abs/2311.01434>`__.
+
 Post-Processing Methods
 -----------------------
 
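For context on the three citations above: every Mixup variant builds on the same core recipe of training on convex combinations of example pairs, with the mixing weight drawn from a Beta distribution. A sketch of the formulation from the ICLR 2018 paper cited above:

.. math::

   \tilde{x} = \lambda x_i + (1 - \lambda) x_j, \qquad
   \tilde{y} = \lambda y_i + (1 - \lambda) y_j, \qquad
   \lambda \sim \operatorname{Beta}(\alpha, \alpha)

MixupIO and Warping Mixup adjust this recipe; see the cited papers for the exact variants.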
66 changes: 52 additions & 14 deletions experiments/classification/cifar10/resnet.py
@@ -6,6 +6,7 @@
 from torch_uncertainty.baselines import ResNet
 from torch_uncertainty.datamodules import CIFAR10DataModule
 from torch_uncertainty.optimization_procedures import get_procedure
+from torch_uncertainty.utils import csv_writer
 
 if __name__ == "__main__":
     args = init_args(ResNet, CIFAR10DataModule)
@@ -14,22 +15,59 @@
     else:
         root = Path(args.root)
 
-    net_name = f"{args.version}-resnet{args.arch}-cifar10"
+    if args.exp_name == "":
+        args.exp_name = f"{args.version}-resnet{args.arch}-cifar10"
 
     # datamodule
     args.root = str(root / "data")
     dm = CIFAR10DataModule(**vars(args))
 
-    # model
-    model = ResNet(
-        num_classes=dm.num_classes,
-        in_channels=dm.num_channels,
-        loss=nn.CrossEntropyLoss,
-        optimization_procedure=get_procedure(
-            f"resnet{args.arch}", "cifar10", args.version
-        ),
-        style="cifar",
-        **vars(args),
-    )
-
-    cli_main(model, dm, root, net_name, args)
+    if args.opt_temp_scaling:
+        calibration_set = dm.get_test_set
+    elif args.val_temp_scaling:
+        calibration_set = dm.get_val_set
+    else:
+        calibration_set = None
+
+    if args.use_cv:
+        list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over)
+        list_model = []
+        for i in range(len(list_dm)):
+            list_model.append(
+                ResNet(
+                    num_classes=list_dm[i].dm.num_classes,
+                    in_channels=list_dm[i].dm.num_channels,
+                    loss=nn.CrossEntropyLoss,
+                    optimization_procedure=get_procedure(
+                        f"resnet{args.arch}", "cifar10", args.version
+                    ),
+                    style="cifar",
+                    calibration_set=calibration_set,
+                    **vars(args),
+                )
+            )
+
+        results = cli_main(
+            list_model, list_dm, args.exp_dir, args.exp_name, args
+        )
+    else:
+        # model
+        model = ResNet(
+            num_classes=dm.num_classes,
+            in_channels=dm.num_channels,
+            loss=nn.CrossEntropyLoss,
+            optimization_procedure=get_procedure(
+                f"resnet{args.arch}", "cifar10", args.version
+            ),
+            style="cifar",
+            calibration_set=calibration_set,
+            **vars(args),
+        )
+
+        results = cli_main(model, dm, args.exp_dir, args.exp_name, args)
+
+    for dict_result in results:
+        csv_writer(
+            Path(args.exp_dir) / Path(args.exp_name) / "results.csv",
+            dict_result,
+        )
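The `calibration_set` selection above feeds the new temperature-scaling option in the routines. For background, a minimal sketch of post-hoc temperature scaling (Guo et al., 2017): fit a single scalar T > 0 that minimizes the NLL of softmax(logits / T) on held-out data. This illustrates the technique, not torch-uncertainty's scaler API; the `fit_temperature` helper and the tensor shapes are hypothetical.

```python
import torch
from torch import nn, optim

def fit_temperature(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fit T > 0 minimizing the NLL of softmax(logits / T) on held-out data."""
    log_t = nn.Parameter(torch.zeros(1))  # optimize log T so that T stays positive
    optimizer = optim.LBFGS([log_t], lr=0.1, max_iter=50)
    nll = nn.CrossEntropyLoss()

    def closure():
        optimizer.zero_grad()
        loss = nll(logits / log_t.exp(), labels)  # rescale logits by 1/T
        loss.backward()
        return loss

    optimizer.step(closure)
    return float(log_t.exp())

# Usage with dummy validation logits and labels.
logits = torch.randn(128, 10)
labels = torch.randint(0, 10, (128,))
temperature = fit_temperature(logits, labels)
```

Note that, per the branch above, `args.opt_temp_scaling` fits the temperature on the test set while `args.val_temp_scaling` uses the validation split.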
64 changes: 53 additions & 11 deletions experiments/classification/tiny-imagenet/resnet.py
@@ -5,6 +5,8 @@
 from torch_uncertainty import cli_main, init_args
 from torch_uncertainty.baselines import ResNet
 from torch_uncertainty.datamodules import TinyImageNetDataModule
+from torch_uncertainty.optimization_procedures import get_procedure
+from torch_uncertainty.utils import csv_writer
 
 
 def optim_tiny(model: nn.Module) -> dict:
@@ -26,20 +28,60 @@ def optim_tiny(model: nn.Module) -> dict:
     else:
         root = Path(args.root)
 
-    net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet"
+    # net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet"
+    if args.exp_name == "":
+        args.exp_name = f"{args.version}-resnet{args.arch}-tiny-imagenet"
 
     # datamodule
     args.root = str(root / "data")
     dm = TinyImageNetDataModule(**vars(args))
 
-    # model
-    model = ResNet(
-        num_classes=dm.num_classes,
-        in_channels=dm.num_channels,
-        loss=nn.CrossEntropyLoss,
-        optimization_procedure=optim_tiny,
-        style="cifar",
-        **vars(args),
-    )
+    if args.opt_temp_scaling:
+        calibration_set = dm.get_test_set
+    elif args.val_temp_scaling:
+        calibration_set = dm.get_val_set
+    else:
+        calibration_set = None
+
+    if args.use_cv:
+        list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over)
+        list_model = []
+        for i in range(len(list_dm)):
+            list_model.append(
+                ResNet(
+                    num_classes=list_dm[i].dm.num_classes,
+                    in_channels=list_dm[i].dm.num_channels,
+                    loss=nn.CrossEntropyLoss,
+                    optimization_procedure=get_procedure(
+                        f"resnet{args.arch}", "tiny-imagenet", args.version
+                    ),
+                    style="cifar",
+                    calibration_set=calibration_set,
+                    **vars(args),
+                )
+            )
+
+        results = cli_main(
+            list_model, list_dm, args.exp_dir, args.exp_name, args
+        )
+    else:
+        # model
+        model = ResNet(
+            num_classes=dm.num_classes,
+            in_channels=dm.num_channels,
+            loss=nn.CrossEntropyLoss,
+            optimization_procedure=get_procedure(
+                f"resnet{args.arch}", "tiny-imagenet", args.version
+            ),
+            calibration_set=calibration_set,
+            style="cifar",
+            **vars(args),
+        )
+
+        results = cli_main(model, dm, args.exp_dir, args.exp_name, args)
 
-    cli_main(model, dm, root, net_name, args)
+    for dict_result in results:
+        csv_writer(
+            Path(args.exp_dir) / Path(args.exp_name) / "results.csv",
+            dict_result,
+        )
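`args.use_cv` routes both experiment scripts through `dm.make_cross_val_splits(args.n_splits, args.train_over)` and trains one model per split. The datamodule method is torch-uncertainty's own; the underlying idea is ordinary K-fold index splitting, sketched below with a hypothetical `kfold_indices` helper and sample count for illustration.

```python
import torch

def kfold_indices(n_samples: int, n_splits: int):
    """Yield (train_idx, val_idx) index pairs covering the dataset in K folds."""
    perm = torch.randperm(n_samples)   # shuffle once, then slice into folds
    folds = torch.chunk(perm, n_splits)
    for k in range(n_splits):
        val_idx = folds[k]
        train_idx = torch.cat([folds[i] for i in range(n_splits) if i != k])
        yield train_idx, val_idx

# Usage: 5 folds over a 50,000-sample training set (CIFAR-10 sized).
for train_idx, val_idx in kfold_indices(50_000, n_splits=5):
    pass  # build one datamodule and one ResNet per fold, as the scripts above do
```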