Skip to content
This repository has been archived by the owner on Feb 15, 2025. It is now read-only.

Commit

Permalink
illustrate tent by image corruption example
Browse files Browse the repository at this point in the history
to illustrate the tent method and fully test-time adaptation setting,
we provide an example for adaptation to image corruptions.

this is simply *example code* for explanation, not *reference code* for
reproduction. that said, experimenting with this should give results
that are representative.

reference code will follow to reproduce our ImageNet-C results
  • Loading branch information
Dequan Wang authored and shelhamer committed Apr 15, 2021
1 parent 06f7ae0 commit 1ffcd70
Show file tree
Hide file tree
Showing 10 changed files with 628 additions and 0 deletions.
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,36 @@ Dequan Wang\*, Evan Shelhamer\*, Shaoteng Liu, Bruno Olshausen, and Trevor Darre
Tent equips a model to adapt itself to new and different data ☀️ 🌧 ❄️ during testing.
Tent updates online and batch-by-batch to reduce error on dataset shifts like corruptions, simulation-to-real discrepancies, and other differences between training and testing data.

Our **example code** illustrates the method and provides representative results for image corruptions on CIFAR-10-C.
Note that the exact details of the model, optimization, etc. differ from the paper, so this is not for reproduction, but for explanation.

Please check back soon for our **reference code** to reproduce and extend tent!

## Example: Adapting to Image Corruptions on CIFAR-10-C

This example compares a baseline without adaptation (base), test-time normalization that updates feature statistics during testing (norm), and our method for entropy minimization during testing (tent).

- Dataset: [CIFAR-10-C](https://github.com/hendrycks/robustness/), with 15 corruption types and 5 levels.
- Model: [WRN-28-10](https://github.com/RobustBench/robustbench), the default model for RobustBench.

**Usage**:

```python
python cifar10c.py --cfg cfgs/base.yaml
python cifar10c.py --cfg cfgs/norm.yaml
python cifar10c.py --cfg cfgs/tent.yaml
```

**Result**: tent reduces the error (%) across corruption types at the most severe level of corruption (level 5).

| | mean | gauss_noise | shot_noise | impulse_noise | defocus_blur | glass_blur | motion_blur | zoom_blur | snow | frost | fog | brightness | contrast | elastic_trans | pixelate | jpeg |
| ---------------------------------------------------- | ---: | ----------: | ---------: | ------------: | -----------: | ---------: | ----------: | --------: | ---: | ----: | ---: | ---------: | -------: | ------------: | -------: | ---: |
| [base](./cifar10c.py) | 43.5 | 72.3 | 65.7 | 72.9 | 46.9 | 54.3 | 34.8 | 42.0 | 25.1 | 41.3 | 26.0 | 9.3 | 46.7 | 26.6 | 58.5 | 30.3 |
| [norm](./norm.py) | 20.4 | 28.1 | 26.1 | 36.3 | 12.8 | 35.3 | 14.2 | 12.1 | 17.3 | 17.4 | 15.3 | 8.4 | 12.6 | 23.8 | 19.7 | 27.3 |
| [tent](./tent.py) | 18.6 | 24.8 | 23.5 | 33.0 | 11.9 | 31.9 | 13.7 | 10.8 | 15.9 | 16.2 | 13.7 | 7.9 | 12.1 | 22.0 | 17.3 | 24.2 |

See the full results for this example in the [wandb report](https://wandb.ai/tent/cifar10c).

## Correspondence

Please contact Dequan Wang and Evan Shelhamer at dqwang AT cs.berkeley.edu and shelhamer AT google.com.
Expand Down
34 changes: 34 additions & 0 deletions cfgs/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
CORRUPTION:
MODEL: Standard
EVAL_ONLY: True
SEVERITY:
- 5
- 4
- 3
- 2
- 1
TYPE:
- gaussian_noise
- shot_noise
- impulse_noise
- defocus_blur
- glass_blur
- motion_blur
- zoom_blur
- snow
- frost
- fog
- brightness
- contrast
- elastic_transform
- pixelate
- jpeg_compression
BN:
FUNC: FrozenMeanVarBatchNorm2d
OPTIM:
BATCH_SIZE: 200
METHOD: Adam
ITER: 1
BETA: 0.9
LR: 1e-3
WD: 0.
34 changes: 34 additions & 0 deletions cfgs/norm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
CORRUPTION:
MODEL: Standard
EVAL_ONLY: True
SEVERITY:
- 5
- 4
- 3
- 2
- 1
TYPE:
- gaussian_noise
- shot_noise
- impulse_noise
- defocus_blur
- glass_blur
- motion_blur
- zoom_blur
- snow
- frost
- fog
- brightness
- contrast
- elastic_transform
- pixelate
- jpeg_compression
BN:
FUNC: TrainModeBatchNorm2d
OPTIM:
BATCH_SIZE: 200
METHOD: Adam
ITER: 1
BETA: 0.9
LR: 1e-3
WD: 0.
34 changes: 34 additions & 0 deletions cfgs/tent.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
CORRUPTION:
MODEL: Standard
EVAL_ONLY: False
SEVERITY:
- 5
- 4
- 3
- 2
- 1
TYPE:
- gaussian_noise
- shot_noise
- impulse_noise
- defocus_blur
- glass_blur
- motion_blur
- zoom_blur
- snow
- frost
- fog
- brightness
- contrast
- elastic_transform
- pixelate
- jpeg_compression
BN:
FUNC: TrainModeBatchNorm2d
OPTIM:
BATCH_SIZE: 200
METHOD: Adam
ITER: 1
BETA: 0.9
LR: 1e-3
WD: 0.
29 changes: 29 additions & 0 deletions cifar10c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging

import torch

from robustbench.data import load_cifar10c
from robustbench.utils import clean_accuracy as accuracy

from tent import tent
from conf import cfg, load_cfg_fom_args


def evaluate(cfg_file):
load_cfg_fom_args(cfg_file=cfg_file,
description="CIFAR-10-C evaluation.")
logger = logging.getLogger(__name__)
for severity in cfg.CORRUPTION.SEVERITY:
for corruption_type in cfg.CORRUPTION.TYPE:
x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX,
severity, cfg.DATA_DIR, False,
[corruption_type])
x_test, y_test = x_test.cuda(), y_test.cuda()
model = tent(cfg.CORRUPTION.MODEL)
acc = accuracy(model, x_test, y_test, cfg.OPTIM.BATCH_SIZE)
logger.info('accuracy [{}{}]: {:.2%}'.format(
corruption_type, severity, acc))


if __name__ == '__main__':
evaluate('cifar10c.yaml')
33 changes: 33 additions & 0 deletions cifar10c.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
CORRUPTION:
MODEL: Standard
SEVERITY:
- 5
- 4
- 3
- 2
- 1
TYPE:
- gaussian_noise
- shot_noise
- impulse_noise
- defocus_blur
- glass_blur
- motion_blur
- zoom_blur
- snow
- frost
- fog
- brightness
- contrast
- elastic_transform
- pixelate
- jpeg_compression
BN:
FUNC: TrainModeBatchNorm2d
OPTIM:
BATCH_SIZE: 200
METHOD: Adam
ITER: 1
BETA: 0.9
LR: 1e-3
WD: 0.
Loading

0 comments on commit 1ffcd70

Please sign in to comment.