From 1ffcd70a290e0047c083d9b918fc5f20ded5247d Mon Sep 17 00:00:00 2001
From: Dequan Wang
Date: Wed, 14 Apr 2021 01:06:37 +0100
Subject: [PATCH] illustrate tent by image corruption example

to illustrate the tent method and the fully test-time adaptation setting,
we provide an example of adaptation to image corruptions.

this is simply *example code* for explanation, not *reference code* for
reproduction. that said, experimenting with it should give representative
results.

reference code to reproduce our ImageNet-C results will follow.
---
 README.md        |  28 ++++++
 cfgs/base.yaml   |  34 ++++++
 cfgs/norm.yaml   |  34 ++++++
 cfgs/tent.yaml   |  34 ++++++
 cifar10c.py      |  29 +++++++
 cifar10c.yaml    |  33 +++++++
 conf.py          | 218 +++++++++++++++++++++++++++++++++++++++++++++++
 norm.py          |  79 +++++++++++++++++
 requirements.txt |   6 ++
 tent.py          | 133 +++++++++++++++++++++++++++++
 10 files changed, 628 insertions(+)
 create mode 100644 cfgs/base.yaml
 create mode 100644 cfgs/norm.yaml
 create mode 100644 cfgs/tent.yaml
 create mode 100644 cifar10c.py
 create mode 100644 cifar10c.yaml
 create mode 100644 conf.py
 create mode 100644 norm.py
 create mode 100644 requirements.txt
 create mode 100644 tent.py

diff --git a/README.md b/README.md
index 4c5012c..5ddba44 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,36 @@ Dequan Wang\*, Evan Shelhamer\*, Shaoteng Liu, Bruno Olshausen, and Trevor Darre
 Tent equips a model to adapt itself to new and different data ☀️ 🌧 ❄️ during testing.
 Tent updates online and batch-by-batch to reduce error on dataset shifts like corruptions, simulation-to-real discrepancies, and other differences between training and testing data.
 
+Our **example code** illustrates the method and provides representative results for image corruptions on CIFAR-10-C.
+Note that the exact details of the model, optimization, etc. differ from the paper, so this code is for explanation, not for reproduction.
+
 Please check back soon for our **reference code** to reproduce and extend tent!
 
+## Example: Adapting to Image Corruptions on CIFAR-10-C
+
+This example compares a baseline without adaptation (base), test-time normalization that updates feature statistics during testing (norm), and our method for entropy minimization during testing (tent).
+
+- Dataset: [CIFAR-10-C](https://github.com/hendrycks/robustness/), with 15 corruption types and 5 severity levels.
+- Model: [WRN-28-10](https://github.com/RobustBench/robustbench), the default model for RobustBench.
+
+**Usage**:
+
+```bash
+python cifar10c.py --cfg cfgs/base.yaml
+python cifar10c.py --cfg cfgs/norm.yaml
+python cifar10c.py --cfg cfgs/tent.yaml
+```
+
+**Result**: tent reduces the error (%) across corruption types at the most severe level of corruption (level 5).
+
+|                       | mean | gauss_noise | shot_noise | impulse_noise | defocus_blur | glass_blur | motion_blur | zoom_blur | snow | frost |  fog | brightness | contrast | elastic_trans | pixelate | jpeg |
+| --------------------- | ---: | ----------: | ---------: | ------------: | -----------: | ---------: | ----------: | --------: | ---: | ----: | ---: | ---------: | -------: | ------------: | -------: | ---: |
+| [base](./cifar10c.py) | 43.5 |        72.3 |       65.7 |          72.9 |         46.9 |       54.3 |        34.8 |      42.0 | 25.1 |  41.3 | 26.0 |        9.3 |     46.7 |          26.6 |     58.5 | 30.3 |
+| [norm](./norm.py)     | 20.4 |        28.1 |       26.1 |          36.3 |         12.8 |       35.3 |        14.2 |      12.1 | 17.3 |  17.4 | 15.3 |        8.4 |     12.6 |          23.8 |     19.7 | 27.3 |
+| [tent](./tent.py)     | 18.6 |        24.8 |       23.5 |          33.0 |         11.9 |       31.9 |        13.7 |      10.8 | 15.9 |  16.2 | 13.7 |        7.9 |     12.1 |          22.0 |     17.3 | 24.2 |
+
+See the full results for this example in the [wandb report](https://wandb.ai/tent/cifar10c).
+
 ## Correspondence
 
 Please contact Dequan Wang and Evan Shelhamer at dqwang AT cs.berkeley.edu and shelhamer AT google.com.
diff --git a/cfgs/base.yaml b/cfgs/base.yaml
new file mode 100644
index 0000000..ee42f86
--- /dev/null
+++ b/cfgs/base.yaml
@@ -0,0 +1,34 @@
+CORRUPTION:
+  MODEL: Standard
+  EVAL_ONLY: True
+  SEVERITY:
+    - 5
+    - 4
+    - 3
+    - 2
+    - 1
+  TYPE:
+    - gaussian_noise
+    - shot_noise
+    - impulse_noise
+    - defocus_blur
+    - glass_blur
+    - motion_blur
+    - zoom_blur
+    - snow
+    - frost
+    - fog
+    - brightness
+    - contrast
+    - elastic_transform
+    - pixelate
+    - jpeg_compression
+BN:
+  FUNC: FrozenMeanVarBatchNorm2d
+OPTIM:
+  BATCH_SIZE: 200
+  METHOD: Adam
+  ITER: 1
+  BETA: 0.9
+  LR: 1e-3
+  WD: 0.
diff --git a/cfgs/norm.yaml b/cfgs/norm.yaml
new file mode 100644
index 0000000..f1bfbc5
--- /dev/null
+++ b/cfgs/norm.yaml
@@ -0,0 +1,34 @@
+CORRUPTION:
+  MODEL: Standard
+  EVAL_ONLY: True
+  SEVERITY:
+    - 5
+    - 4
+    - 3
+    - 2
+    - 1
+  TYPE:
+    - gaussian_noise
+    - shot_noise
+    - impulse_noise
+    - defocus_blur
+    - glass_blur
+    - motion_blur
+    - zoom_blur
+    - snow
+    - frost
+    - fog
+    - brightness
+    - contrast
+    - elastic_transform
+    - pixelate
+    - jpeg_compression
+BN:
+  FUNC: TrainModeBatchNorm2d
+OPTIM:
+  BATCH_SIZE: 200
+  METHOD: Adam
+  ITER: 1
+  BETA: 0.9
+  LR: 1e-3
+  WD: 0.
diff --git a/cfgs/tent.yaml b/cfgs/tent.yaml
new file mode 100644
index 0000000..677b5f3
--- /dev/null
+++ b/cfgs/tent.yaml
@@ -0,0 +1,34 @@
+CORRUPTION:
+  MODEL: Standard
+  EVAL_ONLY: False
+  SEVERITY:
+    - 5
+    - 4
+    - 3
+    - 2
+    - 1
+  TYPE:
+    - gaussian_noise
+    - shot_noise
+    - impulse_noise
+    - defocus_blur
+    - glass_blur
+    - motion_blur
+    - zoom_blur
+    - snow
+    - frost
+    - fog
+    - brightness
+    - contrast
+    - elastic_transform
+    - pixelate
+    - jpeg_compression
+BN:
+  FUNC: TrainModeBatchNorm2d
+OPTIM:
+  BATCH_SIZE: 200
+  METHOD: Adam
+  ITER: 1
+  BETA: 0.9
+  LR: 1e-3
+  WD: 0.
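The three configs above differ in only two fields: `BN.FUNC`, which selects frozen training statistics (base) or current-batch statistics (norm, tent), and `CORRUPTION.EVAL_ONLY`, which disables (base, norm) or enables (tent) optimization during testing. As a minimal sketch of the update that `EVAL_ONLY: False` switches on (illustration only: `adapt_step` is a hypothetical name, and the actual example code follows in `tent.py` below):

```python
import torch

def softmax_entropy(logits):
    # Shannon entropy of the softmax prediction, per test input.
    return -(logits.softmax(1) * logits.log_softmax(1)).sum(1)

def adapt_step(model, optimizer, x):
    # One tent step on a test batch: predict, then minimize the entropy
    # of the predictions. Only the batch norm affine parameters should
    # have requires_grad=True, so only they are updated.
    logits = model(x)
    loss = softmax_entropy(logits).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return logits
```

Entropy is the objective because more confident predictions tend to be more accurate under shift, and restricting the update to the batch norm scale and shift keeps adaptation cheap and stable.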
diff --git a/cifar10c.py b/cifar10c.py
new file mode 100644
index 0000000..ed7cc5c
--- /dev/null
+++ b/cifar10c.py
@@ -0,0 +1,29 @@
+import logging
+
+import torch
+
+from robustbench.data import load_cifar10c
+from robustbench.utils import clean_accuracy as accuracy
+
+from tent import tent
+from conf import cfg, load_cfg_fom_args
+
+
+def evaluate(cfg_file):
+    """Evaluates on each corruption type and severity level of CIFAR-10-C."""
+    load_cfg_fom_args(cfg_file=cfg_file,
+                      description="CIFAR-10-C evaluation.")
+    logger = logging.getLogger(__name__)
+    for severity in cfg.CORRUPTION.SEVERITY:
+        for corruption_type in cfg.CORRUPTION.TYPE:
+            x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX,
+                                           severity, cfg.DATA_DIR, False,
+                                           [corruption_type])
+            x_test, y_test = x_test.cuda(), y_test.cuda()
+            # re-instantiate to reset adaptation for each corruption type
+            model = tent(cfg.CORRUPTION.MODEL)
+            acc = accuracy(model, x_test, y_test, cfg.OPTIM.BATCH_SIZE)
+            logger.info('accuracy [{}{}]: {:.2%}'.format(
+                corruption_type, severity, acc))
+
+
+if __name__ == '__main__':
+    evaluate('cifar10c.yaml')
diff --git a/cifar10c.yaml b/cifar10c.yaml
new file mode 100644
index 0000000..82b1f58
--- /dev/null
+++ b/cifar10c.yaml
@@ -0,0 +1,33 @@
+CORRUPTION:
+  MODEL: Standard
+  SEVERITY:
+    - 5
+    - 4
+    - 3
+    - 2
+    - 1
+  TYPE:
+    - gaussian_noise
+    - shot_noise
+    - impulse_noise
+    - defocus_blur
+    - glass_blur
+    - motion_blur
+    - zoom_blur
+    - snow
+    - frost
+    - fog
+    - brightness
+    - contrast
+    - elastic_transform
+    - pixelate
+    - jpeg_compression
+BN:
+  FUNC: TrainModeBatchNorm2d
+OPTIM:
+  BATCH_SIZE: 200
+  METHOD: Adam
+  ITER: 1
+  BETA: 0.9
+  LR: 1e-3
+  WD: 0.
diff --git a/conf.py b/conf.py
new file mode 100644
index 0000000..48b0f06
--- /dev/null
+++ b/conf.py
@@ -0,0 +1,218 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +"""Configuration file (powered by YACS).""" + +import argparse +import os +import sys +import logging +import random +import torch +import numpy as np +from datetime import datetime +from iopath.common.file_io import g_pathmgr +from yacs.config import CfgNode as CfgNode +# Global config object (example usage: from core.config import cfg) +_C = CfgNode() +cfg = _C + + +# ------------------------------ Corruption options ----------------------------------- # +_C.CORRUPTION = CfgNode() + +_C.CORRUPTION.DATASET = 'cifar10' + +# Check https://github.com/hendrycks/robustness for corruption details +_C.CORRUPTION.TYPE = ['gaussian_noise', 'shot_noise', 'impulse_noise', + 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', + 'snow', 'frost', 'fog', 'brightness', 'contrast', + 'elastic_transform', 'pixelate', 'jpeg_compression'] +_C.CORRUPTION.SEVERITY = [5, 4, 3, 2, 1] + +# Check https://github.com/RobustBench/robustbench for available models +_C.CORRUPTION.MODEL = 'Standard' + +# Accumulate the optimizations or not +_C.CORRUPTION.EVAL_ONLY = False +_C.CORRUPTION.RESET_STATE = False + + +# Number of examples to evaluate (10000 for all samples in CIFAR-10) +_C.CORRUPTION.NUM_EX = 10000 + +# -------------------------------- Batch norm options -------------------------------- # +_C.BN = CfgNode() + +# (BatchNorm2d, TrainModeBatchNorm2d, FrozenMeanVarBatchNorm2d) +_C.BN.FUNC = 'BatchNorm2d' + +# BN epsilon +_C.BN.EPS = 1e-5 + +# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) +_C.BN.MOM = 0.1 + +# -------------------------------- Optimizer options --------------------------------- # +_C.OPTIM = CfgNode() + +# Choices: Adam, AdaMod +_C.OPTIM.METHOD = 'Adam' + +# Number of iterations +_C.OPTIM.ITER = 5 + +# Total mini-batch size +_C.OPTIM.BATCH_SIZE = 128 + +# Learning rate +_C.OPTIM.LR = 1e-3 + +# Beta +_C.OPTIM.BETA = 0.9 + +# Momentum +_C.OPTIM.MOMENTUM = 0.9 + +# Momentum dampening +_C.OPTIM.DAMPENING = 0.0 + +# Nesterov momentum +_C.OPTIM.NESTEROV = True + +# L2 regularization +_C.OPTIM.WD = 0.0 + + +# ---------------------------------- CUDNN options ----------------------------------- # +_C.CUDNN = CfgNode() + +# Perform benchmarking to select fastest CUDNN algorithms (best for fixed input sizes) +_C.CUDNN.BENCHMARK = True + +# ---------------------------------- Weights & Biases options ----------------------------------- # +_C.WANDB = CfgNode() +_C.WANDB.PROJECT = 'tent' +_C.WANDB.ENTITY = 'adaptation' +_C.WANDB.NAME = '' +_C.WANDB.LOG = True + +# ----------------------------------- Misc options ----------------------------------- # + +# Optional description of a config +_C.DESC = "" + +# Note that non-determinism is still present due to non-deterministic GPU ops +_C.RNG_SEED = 1 + +# Output directory +_C.SAVE_DIR = "./output" + +# Data directory +_C.DATA_DIR = "./data" + +# Weight directory +_C.CKPT_DIR = "./ckpt" + +# Log destination (in SAVE_DIR) +_C.LOG_DEST = "log.txt" + +# Log datetime +_C.LOG_TIME = '' + +# # Config destination (in SAVE_DIR) +# _C.CFG_DEST = "cfg.yaml" + +# ---------------------------------- Default config ---------------------------------- # +_CFG_DEFAULT = _C.clone() +_CFG_DEFAULT.freeze() + + +def assert_and_infer_cfg(): + """Checks config values invariants.""" + err_str = "The first lr step must start at 0" + assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str + data_splits = ["train", "val", "test"] + err_str = "Data split '{}' not supported" + assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT) + assert 
+
+
+def merge_from_file(cfg_file):
+    with g_pathmgr.open(cfg_file, "r") as f:
+        new_cfg = _C.load_cfg(f)
+    _C.merge_from_other_cfg(new_cfg)
+
+
+def dump_cfg():
+    """Dumps the config to the output directory."""
+    cfg_file = os.path.join(_C.SAVE_DIR, _C.CFG_DEST)
+    with g_pathmgr.open(cfg_file, "w") as f:
+        _C.dump(stream=f)
+
+
+def load_cfg(out_dir, cfg_dest="config.yaml"):
+    """Loads config from specified output directory."""
+    cfg_file = os.path.join(out_dir, cfg_dest)
+    merge_from_file(cfg_file)
+
+
+def reset_cfg():
+    """Reset config to initial state."""
+    cfg.merge_from_other_cfg(_CFG_DEFAULT)
+
+
+def load_cfg_fom_args(cfg_file='conf.yaml', description="Config file options."):
+    """Load config from command line arguments and set any specified options."""
+    current_time = datetime.now().strftime("%y%m%d_%H%M%S")
+    parser = argparse.ArgumentParser(description=description)
+    parser.add_argument("--cfg", dest="cfg_file", type=str, default=None,
+                        help="Config file location")
+    parser.add_argument("opts", default=None, nargs=argparse.REMAINDER,
+                        help="See conf.py for all options")
+    if len(sys.argv) == 1:
+        parser.print_help()
+        sys.exit(1)
+    args = parser.parse_args()
+
+    if args.cfg_file is None:
+        args.cfg_file = cfg_file
+    merge_from_file(args.cfg_file)
+    cfg.merge_from_list(args.opts)
+
+    # name the log after the model (default config) or after the config file
+    if args.cfg_file == cfg_file:
+        log_dest = '{}_{}.txt'.format(cfg.CORRUPTION.MODEL, current_time)
+    else:
+        log_dest = os.path.basename(args.cfg_file)
+        log_dest = log_dest.replace('.yaml', '_{}.txt'.format(current_time))
+
+    g_pathmgr.mkdirs(cfg.SAVE_DIR)
+    cfg.LOG_TIME, cfg.LOG_DEST = current_time, log_dest
+    if cfg.WANDB.NAME == '':
+        cfg.WANDB.NAME = log_dest[:-4]
+    cfg.freeze()
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format="[%(asctime)s] [%(filename)s: %(lineno)4d]: %(message)s",
+        datefmt="%y/%m/%d %H:%M:%S",
+        handlers=[
+            logging.FileHandler(os.path.join(cfg.SAVE_DIR, cfg.LOG_DEST)),
+            logging.StreamHandler()
+        ])
+
+    np.random.seed(cfg.RNG_SEED)
+    torch.manual_seed(cfg.RNG_SEED)
+    random.seed(cfg.RNG_SEED)
+    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
+
+    logger = logging.getLogger(__name__)
+    version = [torch.__version__, torch.version.cuda,
+               torch.backends.cudnn.version()]
+    logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
+    logger.info(cfg)
diff --git a/norm.py b/norm.py
new file mode 100644
index 0000000..f8e15a0
--- /dev/null
+++ b/norm.py
@@ -0,0 +1,79 @@
+import torch
+
+from torch import nn
+# BatchNorm2d is imported so that get_norm() can look it up via globals()
+from torch.nn import BatchNorm2d
+
+from conf import cfg
+
+
+def get_norm(out_channels):
+    """Instantiates the batch norm type chosen by cfg.BN.FUNC.
+
+    Args:
+        out_channels (int): number of feature channels to normalize;
+            the norm type (BatchNorm2d, TrainModeBatchNorm2d, or
+            FrozenMeanVarBatchNorm2d) is read from cfg.BN.FUNC.
+    Returns:
+        nn.Module: the normalization module.
+ """ + return globals()[cfg.BN.FUNC](out_channels, cfg.BN.EPS, cfg.BN.MOM) + + +class TrainModeBatchNorm2d(nn.Module): + __constants__ = ['eps', 'num_features'] + + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + self.num_features, self.eps = num_features, eps + self.register_parameter("weight", + nn.Parameter(torch.ones(num_features))) + self.register_parameter("bias", + nn.Parameter(torch.zeros(num_features))) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features) - eps) + + def extra_repr(self): + return '{num_features}, eps={eps}, affine=True'.format(**self.__dict__) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + current_mean = x.mean([0, 2, 3]) + current_var = x.var([0, 2, 3], unbiased=False) + scale = self.weight * (current_var + self.eps).rsqrt() + bias = self.bias - current_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + return x * scale + bias + + +class FrozenMeanVarBatchNorm2d(nn.Module): + __constants__ = ['eps', 'num_features'] + + def __init__(self, num_features, eps=1e-5, momentum=0.1): + super().__init__() + self.num_features, self.eps = num_features, eps + self.register_parameter("weight", + nn.Parameter(torch.ones(num_features))) + self.register_parameter("bias", + nn.Parameter(torch.zeros(num_features))) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features) - eps) + + def extra_repr(self): + return '{num_features}, eps={eps}, affine=True'.format(**self.__dict__) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + scale = self.weight * (self.running_var + self.eps).rsqrt() + bias = self.bias - self.running_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + return x * scale + bias diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..093aae9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +numpy>=1.19.5 +torch==1.8.1 +torchvision==0.9.1 +yacs==0.1.8 +iopath==0.1.8 +git+https://github.com/robustbench/robustbench@v0.1#egg=robustbench diff --git a/tent.py b/tent.py new file mode 100644 index 0000000..00f8372 --- /dev/null +++ b/tent.py @@ -0,0 +1,133 @@ +from copy import deepcopy +import logging + +import torch +import torch.jit + +from torch import nn +from torch.nn import Identity +from torch.optim import Adam, SGD + +from robustbench.utils import load_model +from robustbench.model_zoo.enums import ThreatModel + +from conf import cfg + + +logger = logging.getLogger(__name__) + + +def collect_bn_params(model, logging_bn_params): + bn_params, bn_names, bn_names_module = [], set(), set() + # only optimize the affine parameters after normalization + for name, _ in model.named_buffers(): + if "running_mean" in name: + bn_names.add(name.replace('running_mean', 'weight')) + bn_names.add(name.replace('running_mean', 'bias')) + bn_names_module.add(name[:-13]) # len(".running_mean") == 13 + + for name, params in model.named_parameters(): + if name 
in bn_names: # if "bn" in name: + params.requires_grad = True + bn_params.append(params) + else: + params.requires_grad = False + + if logging_bn_params: + logger.info('test-time optimized parameters: ') + logger.info(bn_names) + + return bn_params + + +def construct_optimizer(optim_param): + if cfg.OPTIM.METHOD == 'Adam': + return Adam(optim_param, + lr=cfg.OPTIM.LR, + betas=(cfg.OPTIM.BETA, 0.999), + weight_decay=cfg.OPTIM.WD) + elif cfg.OPTIM.METHOD == 'SGD': + return SGD(optim_param, + lr=cfg.OPTIM.LR, + momentum=cfg.OPTIM.MOMENTUM, + dampening=cfg.OPTIM.DAMPENING, + weight_decay=cfg.OPTIM.WD, + nesterov=cfg.OPTIM.NESTEROV) + else: + raise NotImplementedError + + +@torch.jit.script +def softmax_entropy(x: torch.Tensor) -> torch.Tensor: + return -(x.softmax(1) * x.log_softmax(1)).sum(1) + + +class SoftmaxEntropy(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return softmax_entropy(x).mean(0) + + +def save_model_and_optimizer(model, optimizer): + """Saves the state dicts of model and optimizer.""" + model_state = deepcopy(model.state_dict()) + optimizer_state = deepcopy(optimizer.state_dict()) + return model_state, optimizer_state + + +def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): + """Loads the state dicts of model and optimizer.""" + model.load_state_dict(model_state, strict=True) + optimizer.load_state_dict(optimizer_state) + + +@torch.enable_grad() +def optim_model(inputs, model, optimizer, loss_fun): + # Perform the forward pass + preds = model(inputs) + # Compute the loss + loss = loss_fun(preds) + # Perform the backward pass + optimizer.zero_grad() + loss.backward(retain_graph=True) + # Update the parameters + optimizer.step() + return preds + + +class tent(nn.Module): + + def __init__(self, model_name): + super().__init__() + self.model = load_model(model_name, cfg.CKPT_DIR, + cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda() + self.iter = cfg.OPTIM.ITER + self.eval_only = cfg.CORRUPTION.EVAL_ONLY + self.reset_state = cfg.CORRUPTION.RESET_STATE + self.loss_fun = SoftmaxEntropy().cuda() + self.optimizer = construct_optimizer( + collect_bn_params(self.model, False)) + self.model_state, self.optimizer_state = \ + save_model_and_optimizer(self.model, self.optimizer) + + def _reset_state(self, x): + load_model_and_optimizer(self.model, self.optimizer, + self.model_state, self.optimizer_state) + + def forward(self, x): + if self.reset_state: + self._reset_state(x) + + if not self.eval_only: + self.model.train() + for _ in range(self.iter): + x = optim_model( + x, self.model, + self.optimizer, + self.loss_fun) + else: + x = self.model(x) + + return x