Skip to content
This repository has been archived by the owner on Feb 15, 2025. It is now read-only.

Commit

Permalink
turn adaptation methods into modules, simplify, and clarify wording
Browse files Browse the repository at this point in the history
- clarify status of example and reference code: do try the example!
- turn methods into importable modules, decoupled from their
  configuration and arguments, for ease of experimentation and adoption
- comment and document
  • Loading branch information
shelhamer committed Apr 15, 2021
1 parent 1ffcd70 commit fc5705f
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 268 deletions.
39 changes: 20 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,38 @@
This is the official project repository for [Tent: Fully-Test Time Adaptation by Entropy Minimization](https://openreview.net/forum?id=uXl3bZLkr3c) by
Dequan Wang\*, Evan Shelhamer\*, Shaoteng Liu, Bruno Olshausen, and Trevor Darrell (ICLR 2021, spotlight).

Tent equips a model to adapt itself to new and different data ☀️ 🌧 ❄️ during testing.
Tent updates online and batch-by-batch to reduce error on dataset shifts like corruptions, simulation-to-real discrepancies, and other differences between training and testing data.
⛺️ Tent equips a model to adapt itself to new and different data during testing ☀️ 🌧❄️.
Tented models adapt online and batch-by-batch to reduce error on dataset shifts like corruptions, simulation-to-real discrepancies, and other differences between training and testing data.
This kind of adaptation is effective and efficient: tent makes just one update per batch to not interrupt inference.

Our **example code** illustrates the method and provides representative results for image corruptions on CIFAR-10-C.
Note that the exact details of the model, optimization, etc. differ from the paper, so this is not for reproduction, but for explanation.
To illustrate the tent method and fully test-time adaptation setting we provide **example code** for adapting to image corruptions on CIFAR-10-C.
The purpose of the example is explanation, not reproduction: exact details of the model architecture, optimization settings, etc. may differ from the paper.
That said, the results should be representative, so do give it a try and experiment!

Please check back soon for our **reference code** to reproduce and extend tent!
Please check back soon for **reference code** to exactly reproduce the ImageNet-C results in the paper.

## Example: Adapting to Image Corruptions on CIFAR-10-C

This example compares a baseline without adaptation (base), test-time normalization that updates feature statistics during testing (norm), and our method for entropy minimization during testing (tent).
This example compares a baseline without adaptation (source), test-time normalization for updating feature statistics during testing (norm), and our method for entropy minimization during testing (tent).

- Dataset: [CIFAR-10-C](https://github.com/hendrycks/robustness/), with 15 corruption types and 5 levels.
- Model: [WRN-28-10](https://github.com/RobustBench/robustbench), the default model for RobustBench.

**Usage**:

```python
python cifar10c.py --cfg cfgs/base.yaml
python cifar10c.py --cfg cfgs/source.yaml
python cifar10c.py --cfg cfgs/norm.yaml
python cifar10c.py --cfg cfgs/tent.yaml
```

**Result**: tent reduces the error (%) across corruption types at the most severe level of corruption (level 5).

| | mean | gauss_noise | shot_noise | impulse_noise | defocus_blur | glass_blur | motion_blur | zoom_blur | snow | frost | fog | brightness | contrast | elastic_trans | pixelate | jpeg |
| ---------------------------------------------------- | ---: | ----------: | ---------: | ------------: | -----------: | ---------: | ----------: | --------: | ---: | ----: | ---: | ---------: | -------: | ------------: | -------: | ---: |
| [base](./cifar10c.py) | 43.5 | 72.3 | 65.7 | 72.9 | 46.9 | 54.3 | 34.8 | 42.0 | 25.1 | 41.3 | 26.0 | 9.3 | 46.7 | 26.6 | 58.5 | 30.3 |
| [norm](./norm.py) | 20.4 | 28.1 | 26.1 | 36.3 | 12.8 | 35.3 | 14.2 | 12.1 | 17.3 | 17.4 | 15.3 | 8.4 | 12.6 | 23.8 | 19.7 | 27.3 |
| [tent](./tent.py) | 18.6 | 24.8 | 23.5 | 33.0 | 11.9 | 31.9 | 13.7 | 10.8 | 15.9 | 16.2 | 13.7 | 7.9 | 12.1 | 22.0 | 17.3 | 24.2 |
| | mean | gauss_noise | shot_noise | impulse_noise | defocus_blur | glass_blur | motion_blur | zoom_blur | snow | frost | fog | brightness | contrast | elastic_trans | pixelate | jpeg |
| ---------------------------------------------------------- | ---: | ----------: | ---------: | ------------: | -----------: | ---------: | ----------: | --------: | ---: | ----: | ---: | ---------: | -------: | ------------: | -------: | ---: |
| source [code](./cifar10c.py) [config](./cfgs/source.yaml) | 43.5 | 72.3 | 65.7 | 72.9 | 46.9 | 54.3 | 34.8 | 42.0 | 25.1 | 41.3 | 26.0 | 9.3 | 46.7 | 26.6 | 58.5 | 30.3 |
| norm [code](./norm.py) [config](./cfgs/norm.yaml) | 20.4 | 28.1 | 26.1 | 36.3 | 12.8 | 35.3 | 14.2 | 12.1 | 17.3 | 17.4 | 15.3 | 8.4 | 12.6 | 23.8 | 19.7 | 27.3 |
| tent [code](./tent.py) [config](./cfgs/tent.yaml) | 18.6 | 24.8 | 23.5 | 33.0 | 12.0 | 31.8 | 13.7 | 10.8 | 15.9 | 16.2 | 13.7 | 7.9 | 12.1 | 22.0 | 17.3 | 24.2 |

See the full results for this example in the [wandb report](https://wandb.ai/tent/cifar10c).

Expand All @@ -45,12 +47,11 @@ Please contact Dequan Wang and Evan Shelhamer at dqwang AT cs.berkeley.edu and s
If the tent method or fully test-time adaptation setting are helpful in your research, please consider citing our paper:

```bibtex
@inproceedings{
wang2021tent,
title={Tent: Fully Test-Time Adaptation by Entropy Minimization},
author={Dequan Wang and Evan Shelhamer and Shaoteng Liu and Bruno Olshausen and Trevor Darrell},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=uXl3bZLkr3c}
@inproceedings{wang2021tent,
title={Tent: Fully Test-Time Adaptation by Entropy Minimization},
author={Dequan Wang and Evan Shelhamer and Shaoteng Liu and Bruno Olshausen and Trevor Darrell},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=uXl3bZLkr3c}
}
```
17 changes: 6 additions & 11 deletions cfgs/norm.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
MODEL:
ADAPTATION: norm
ARCH: Standard
TEST:
BATCH_SIZE: 200
CORRUPTION:
MODEL: Standard
EVAL_ONLY: True
DATASET: cifar10
SEVERITY:
- 5
- 4
Expand All @@ -23,12 +27,3 @@ CORRUPTION:
- elastic_transform
- pixelate
- jpeg_compression
BN:
FUNC: TrainModeBatchNorm2d
OPTIM:
BATCH_SIZE: 200
METHOD: Adam
ITER: 1
BETA: 0.9
LR: 1e-3
WD: 0.
17 changes: 6 additions & 11 deletions cfgs/base.yaml → cfgs/source.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
MODEL:
ADAPTATION: source
ARCH: Standard
TEST:
BATCH_SIZE: 200
CORRUPTION:
MODEL: Standard
EVAL_ONLY: True
DATASET: cifar10
SEVERITY:
- 5
- 4
Expand All @@ -23,12 +27,3 @@ CORRUPTION:
- elastic_transform
- pixelate
- jpeg_compression
BN:
FUNC: FrozenMeanVarBatchNorm2d
OPTIM:
BATCH_SIZE: 200
METHOD: Adam
ITER: 1
BETA: 0.9
LR: 1e-3
WD: 0.
13 changes: 7 additions & 6 deletions cfgs/tent.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
MODEL:
ADAPTATION: tent
ARCH: Standard
TEST:
BATCH_SIZE: 200
CORRUPTION:
MODEL: Standard
EVAL_ONLY: False
DATASET: cifar10
SEVERITY:
- 5
- 4
Expand All @@ -23,12 +27,9 @@ CORRUPTION:
- elastic_transform
- pixelate
- jpeg_compression
BN:
FUNC: TrainModeBatchNorm2d
OPTIM:
BATCH_SIZE: 200
METHOD: Adam
ITER: 1
STEPS: 1
BETA: 0.9
LR: 1e-3
WD: 0.
105 changes: 99 additions & 6 deletions cifar10c.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,121 @@
import logging

import torch
import torch.optim as optim

from robustbench.data import load_cifar10c
from robustbench.model_zoo.enums import ThreatModel
from robustbench.utils import load_model
from robustbench.utils import clean_accuracy as accuracy

from tent import tent
import tent
import norm

from conf import cfg, load_cfg_fom_args


logger = logging.getLogger(__name__)


def evaluate(cfg_file):
load_cfg_fom_args(cfg_file=cfg_file,
description="CIFAR-10-C evaluation.")
logger = logging.getLogger(__name__)
# configure model
base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
if cfg.MODEL.ADAPTATION == "source":
logger.info("test-time adaptation: NONE")
model = setup_source(base_model)
if cfg.MODEL.ADAPTATION == "norm":
logger.info("test-time adaptation: NORM")
model = setup_norm(base_model)
if cfg.MODEL.ADAPTATION == "tent":
logger.info("test-time adaptation: TENT")
model = setup_tent(base_model)
# evaluate on each severity and type of corruption in turn
for severity in cfg.CORRUPTION.SEVERITY:
for corruption_type in cfg.CORRUPTION.TYPE:
# reset adaptation for each combination of corruption x severity
# note: for evaluation protocol, but not necessarily needed
try:
model.reset()
logger.info("resetting model")
except:
logger.warning("not resetting model")
x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX,
severity, cfg.DATA_DIR, False,
[corruption_type])
x_test, y_test = x_test.cuda(), y_test.cuda()
model = tent(cfg.CORRUPTION.MODEL)
acc = accuracy(model, x_test, y_test, cfg.OPTIM.BATCH_SIZE)
logger.info('accuracy [{}{}]: {:.2%}'.format(
corruption_type, severity, acc))
acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE)
err = 1. - acc
logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")


def setup_source(model):
"""Set up the baseline source model without adaptation."""
model.eval()
logger.info(f"model for evaluation: %s", model)
return model


def setup_norm(model):
"""Set up test-time normalization adaptation.
Adapt by normalizing features with test batch statistics.
The statistics are measured independently for each batch;
no running average or other cross-batch estimation is used.
"""
norm_model = norm.Norm(model)
logger.info(f"model for adaptation: %s", model)
stats, stat_names = norm.collect_stats(model)
logger.info(f"stats for adaptation: %s", stat_names)
return norm_model


def setup_tent(model):
"""Set up tent adaptation.
Configure the model for training + feature modulation by batch statistics,
collect the parameters for feature modulation by gradient optimization,
set up the optimizer, and then tent the model.
"""
model = tent.configure_model(model)
params, param_names = tent.collect_params(model)
optimizer = setup_optimizer(params)
tent_model = tent.Tent(model, optimizer,
steps=cfg.OPTIM.STEPS,
episodic=cfg.MODEL.EPISODIC)
logger.info(f"model for adaptation: %s", model)
logger.info(f"params for adaptation: %s", param_names)
logger.info(f"optimizer for adaptation: %s", optimizer)
return tent_model


def setup_optimizer(params):
"""Set up optimizer for tent adaptation.
Tent needs an optimizer for test-time entropy minimization.
In principle, tent could make use of any gradient optimizer.
In practice, we advise choosing Adam or SGD+momentum.
For optimization settings, we advise to use the settings from the end of
trainig, if known, or start with a low learning rate (like 0.001) if not.
For best results, try tuning the learning rate and batch size.
"""
if cfg.OPTIM.METHOD == 'Adam':
return optim.Adam(params,
lr=cfg.OPTIM.LR,
betas=(cfg.OPTIM.BETA, 0.999),
weight_decay=cfg.OPTIM.WD)
elif cfg.OPTIM.METHOD == 'SGD':
return optim.SGD(params,
lr=cfg.OPTIM.LR,
momentum=cfg.OPTIM.MOMENTUM,
dampening=cfg.OPTIM.DAMPENING,
weight_decay=cfg.OPTIM.WD,
nesterov=cfg.OPTIM.NESTEROV)
else:
raise NotImplementedError


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit fc5705f

Please sign in to comment.