From 4928a7d26e09dfdd76fad57fe5fecc856cddaa37 Mon Sep 17 00:00:00 2001
From: Quentin Bouniot
Date: Tue, 10 Oct 2023 16:35:01 +0200
Subject: [PATCH 01/27] add mixup

---
 torch_uncertainty/models/resnet/std.py       |  11 ++
 torch_uncertainty/routines/classification.py |  82 +++++++++-
 torch_uncertainty/transforms/__init__.py     |   2 +
 torch_uncertainty/transforms/mixup.py        | 155 +++++++++++++++++++
 4 files changed, 242 insertions(+), 8 deletions(-)
 create mode 100644 torch_uncertainty/transforms/mixup.py

diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py
index f179613f..cb7a8204 100644
--- a/torch_uncertainty/models/resnet/std.py
+++ b/torch_uncertainty/models/resnet/std.py
@@ -310,6 +310,17 @@ def forward(self, x: Tensor) -> Tensor:
         out = self.linear(out)
         return out
 
+    def feats_forward(self, x: Tensor) -> Tensor:
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.optional_pool(out)
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = self.pool(out)
+        out = self.flatten(out)
+        return out
+
 
 def resnet18(
     in_channels: int,
diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py
index f7b457bf..6208b7e1 100644
--- a/torch_uncertainty/routines/classification.py
+++ b/torch_uncertainty/routines/classification.py
@@ -9,7 +9,7 @@
 from einops import rearrange
 from pytorch_lightning.utilities.memory import get_model_size_mb
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
-from timm.data import Mixup
+from timm.data import Mixup as timm_Mixup
 from torch import nn
 from torchmetrics import Accuracy, CalibrationError, MetricCollection
 from torchmetrics.classification import (
@@ -30,6 +30,7 @@
     NegativeLogLikelihood,
     VariationRatio,
 )
+from ..transforms import Mixup, MixupIO, RegMixup, WarpingMixup
 
 # fmt:on
 
@@ -59,6 +60,11 @@ def __init__(
         loss: Type[nn.Module],
         optimization_procedure: Any,
         format_batch_fn: nn.Module = nn.Identity(),
+        mixtype: str = "erm",
+        mixmode: str = "elem",
+        dist_sim: str = "emb",
+        kernel_tau_max: float = 1.0,
+        kernel_tau_std: float = 0.5,
         mixup_alpha: float = 0,
         cutmix_alpha: float = 0,
         ood_detection: bool = False,
@@ -143,12 +149,47 @@ def __init__(
                 "Cutmix alpha and Mixup alpha must be positive."
                 f"Got {mixup_alpha} and {cutmix_alpha}."
             )
-        elif mixup_alpha > 0 or cutmix_alpha > 0:
+
+        self.mixtype = mixtype
+        self.mixmode = mixmode
+        self.dist_sim = dist_sim
+
+        if self.mixtype == "erm":
+            self.mixup = lambda x, y: (x, y)
+        elif self.mixtype == "timm":
+            self.mixup = timm_Mixup(
+                mixup_alpha=mixup_alpha,
+                cutmix_alpha=cutmix_alpha,
+                mode=self.mixmode,
+                num_classes=self.num_classes,
+            )
+        elif self.mixtype == "mixup":
             self.mixup = Mixup(
-                mixup_alpha=mixup_alpha, cutmix_alpha=cutmix_alpha
+                alpha=mixup_alpha,
+                mode=self.mixmode,
+                num_classes=self.num_classes,
+            )
+        elif self.mixtype == "mixup_io":
+            self.mixup = MixupIO(
+                alpha=mixup_alpha,
+                mode=self.mixmode,
+                num_classes=self.num_classes,
+            )
+        elif self.mixtype == "regmixup":
+            self.mixup = RegMixup(
+                alpha=mixup_alpha,
+                mode=self.mixmode,
+                num_classes=self.num_classes,
+            )
+        elif self.mixtype == "kernel_warping":
+            self.mixup = WarpingMixup(
+                alpha=mixup_alpha,
+                mode=self.mixmode,
+                num_classes=self.num_classes,
+                apply_kernel=True,
+                tau_max=kernel_tau_max,
+                tau_std=kernel_tau_std,
             )
-        else:
-            self.mixup = lambda x, y: (x, y)
 
         # Handle ELBO special cases
         self.is_elbo = (
@@ -193,7 +234,17 @@ def on_train_start(self) -> None:
     def training_step(
         self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
     ) -> STEP_OUTPUT:
-        batch = self.mixup(*batch)
+        if self.mixtype == "kernel_warping":
+            if self.dist_sim == "emb":
+                with torch.no_grad():
+                    feats = self.model.feats_forward(batch[0])
+
+                self.mixup(*batch, feats)
+            elif self.dist_sim == "inp":
+                self.mixup(*batch, batch[0])
+        else:
+            batch = self.mixup(*batch)
+
         inputs, targets = self.format_batch_fn(batch)
 
         if self.is_elbo:
@@ -299,10 +350,10 @@ def add_model_specific_args(
             - ``--logits``: sets :attr:`use_logits` to ``True``.
         """
         parent_parser.add_argument(
-            "--mixup", dest="mixup_alpha", type=float, default=0
+            "--mixup_alpha", dest="mixup_alpha", type=float, default=0
         )
         parent_parser.add_argument(
-            "--cutmix", dest="cutmix_alpha", type=float, default=0
+            "--cutmix_alpha", dest="cutmix_alpha", type=float, default=0
         )
         parent_parser.add_argument(
             "--entropy", dest="use_entropy", action="store_true"
@@ -310,6 +361,21 @@ def add_model_specific_args(
         parent_parser.add_argument(
             "--logits", dest="use_logits", action="store_true"
         )
+        parent_parser.add_argument(
+            "--mixtype", dest="mixtype", type=str, default="erm"
+        )
+        parent_parser.add_argument(
+            "--mixmode", dest="mixmode", type=str, default="elem"
+        )
+        parent_parser.add_argument(
+            "--dist_sim", dest="dist_sim", type=str, default="emb"
+        )
+        parent_parser.add_argument(
+            "--kernel_tau_max", dest="kernel_tau_max", type=float, default=1.0
+        )
+        parent_parser.add_argument(
+            "--kernel_tau_std", dest="kernel_tau_std", type=float, default=0.5
+        )
 
         return parent_parser
 
diff --git a/torch_uncertainty/transforms/__init__.py b/torch_uncertainty/transforms/__init__.py
index 9e42bee6..5e89af19 100644
--- a/torch_uncertainty/transforms/__init__.py
+++ b/torch_uncertainty/transforms/__init__.py
@@ -29,3 +29,5 @@
     Color,
     Sharpness,
 ]
+
+from .mixup import Mixup, MixupIO, RegMixup, WarpingMixup
diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py
new file mode 100644
index 00000000..b4f26fef
--- /dev/null
+++ b/torch_uncertainty/transforms/mixup.py
@@ -0,0 +1,155 @@
+import scipy
+import torch
+import torch.nn.functional as F
+
+import numpy as np
+
+# TODO: torch beta warping (with tensor linspace + approx beta cdf using trapz)
+# TODO: Mixup with roll to be more efficient (remove sampling of index)
+
+
+def beta_warping(x, alpha_cdf=1.0, eps=1e-12):
+    return scipy.stats.beta.cdf(x, a=alpha_cdf + eps, b=alpha_cdf + eps)
+
+
+def sim_gauss_kernel(dist, tau_max=1.0, tau_std=0.5):
+    dist_rate = tau_max * np.exp(
+        -(dist - 1) / (np.mean(dist) * 2 * tau_std * tau_std)
+    )
+
+    return 1 / (dist_rate + 1e-12)
+
+
+class AbstractMixup:
+    def __init__(self, alpha=1.0, mode="batch", num_classes=1000) -> None:
+        self.alpha = alpha
+        self.num_classes = num_classes
+        self.mode = mode
+
+    def _get_params(self, batch_size: int, device: torch.device):
+        if self.mode == "batch":
+            lam = np.random.beta(self.alpha, self.alpha)
+        else:
+            lam = torch.tensor(
+                np.random.beta(self.alpha, self.alpha, batch_size),
+                device=device,
+            )
+
+        index = torch.randperm(batch_size, device=device)
+
+        return lam, index
+
+    def _linear_mixing(
+        self,
+        lam: torch.Tensor | float,
+        inp: torch.Tensor,
+        index: torch.Tensor,
+    ) -> torch.Tensor:
+        if isinstance(lam, torch.Tensor):
+            lam = lam.view(-1, *[1 for _ in range(inp.ndim - 1)]).float()
+
+        return lam * inp + (1 - lam) * inp[index, :]
+
+    def _mix_target(
+        self,
+        lam: torch.Tensor | float,
+        target: torch.Tensor,
+        index: torch.Tensor,
+    ) -> torch.Tensor:
+        y1 = F.one_hot(target, self.num_classes)
+        y2 = F.one_hot(target[index], self.num_classes)
+        if isinstance(lam, torch.Tensor):
+            lam = lam.view(-1, *[1 for _ in range(y1.ndim - 1)]).float()
+
+        if isinstance(lam, torch.Tensor) and lam.dtype == torch.bool:
+            return lam * y1 + (~lam) * y2
+        else:
+            return lam * y1 + (1 - lam) * y2
+
+    def __call__(
+        self, x: torch.Tensor, y: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return x, y
+
+
+class Mixup(AbstractMixup):
+    def __call__(
+        self, x: torch.Tensor, y: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        lam, index = self._get_params(x.size()[0], x.device)
+
+        mixed_x = self._linear_mixing(lam, x, index)
+
+        mixed_y = self._mix_target(lam, y, index)
+
+        return mixed_x, mixed_y
+
+
+class MixupIO(AbstractMixup):
+    def __call__(
+        self, x: torch.Tensor, y: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        lam, index = self._get_params(x.size()[0], x.device)
+
+        mixed_x = self._linear_mixing(lam, x, index)
+
+        mixed_y = self._mix_target((lam > 0.5), y, index)
+
+        return mixed_x, mixed_y
+
+
+class RegMixup(AbstractMixup):
+    def __call__(
+        self, x: torch.Tensor, y: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        lam, index = self._get_params(x.size()[0], x.device)
+
+        part_x = self._linear_mixing(lam, x, index)
+
+        part_y = self._mix_target(lam, y, index)
+
+        mixed_x = torch.cat([x, part_x], dim=0)
+        mixed_y = torch.cat([F.one_hot(y, self.num_classes), part_y], dim=0)
+
+        return mixed_x, mixed_y
+
+
+class WarpingMixup(AbstractMixup):
+    def __init__(
+        self,
+        alpha=1.0,
+        mode="batch",
+        num_classes=1000,
+        apply_kernel=True,
+        tau_max=1.0,
+        tau_std=0.5,
+    ) -> None:
+        super().__init__(alpha, mode, num_classes)
+        self.apply_kernel = apply_kernel
+        self.tau_max = tau_max
+        self.tau_std = tau_std
+
+    def __call__(
+        self, x: torch.Tensor, y: torch.Tensor, feats, warp_param=1.0
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        lam, index = self._get_params(x.size()[0], x.device)
+
+        if self.apply_kernel:
+            l2_dist = (
+                (
+                    (feats - feats[index])
+                    .pow(2)
+                    .sum([i for i in range(len(feats.size())) if i > 0])
+                )
+                .cpu()
+                .numpy()
+            )
+            warp_param = sim_gauss_kernel(l2_dist, self.tau_max, self.tau_std)
+
+        k_lam = torch.tensor(beta_warping(lam, warp_param), device=x.device)
+
+        mixed_x = self._linear_mixing(k_lam, x, index)
+
+        mixed_y = self._mix_target(k_lam, y, index)
+
+        return mixed_x, mixed_y
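A minimal usage sketch of the transforms added in this patch (not part of the patch series itself). It assumes Python 3.10+, torch, numpy and scipy, and an installed torch_uncertainty that already contains this mixup module; the batch shapes, the ten-class setup and the use of flattened inputs as the features handed to WarpingMixup are illustrative assumptions only:

import scipy.stats  # ensure scipy.stats is loaded; the module above only runs `import scipy`
import torch

from torch_uncertainty.transforms import Mixup, WarpingMixup

# Dummy batch: 8 RGB images of size 32x32 with integer labels in [0, 10).
x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))

# Plain mixup with a single lambda per batch; returns mixed inputs and
# soft (mixed one-hot) targets of shape (8, 10).
mixed_x, mixed_y = Mixup(alpha=1.0, mode="batch", num_classes=10)(x, y)

# Kernel-warping mixup additionally takes per-sample features used to
# compute pairwise L2 distances; the flattened inputs stand in for
# embeddings here (the routine above uses model.feats_forward when
# dist_sim == "emb").
warping = WarpingMixup(
    alpha=1.0,
    mode="batch",
    num_classes=10,
    apply_kernel=True,
    tau_max=1.0,
    tau_std=0.5,
)
warp_x, warp_y = warping(x, y, x.flatten(1))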
From 3961fd1a29a9efbdf341e5acf1ea94fc0e4a5d91 Mon Sep 17 00:00:00 2001
From: Quentin Bouniot
Date: Wed, 11 Oct 2023 10:04:45 +0200
Subject: [PATCH 02/27] fix kernel warping

---
 torch_uncertainty/transforms/mixup.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py
index b4f26fef..b862269d 100644
--- a/torch_uncertainty/transforms/mixup.py
+++ b/torch_uncertainty/transforms/mixup.py
@@ -129,6 +129,16 @@ def __init__(
         self.tau_max = tau_max
         self.tau_std = tau_std
 
+    def _get_params(self, batch_size: int, device: torch.device):
+        if self.mode == "batch":
+            lam = np.random.beta(self.alpha, self.alpha)
+        else:
+            lam = np.random.beta(self.alpha, self.alpha, batch_size)
+
+        index = torch.randperm(batch_size, device=device)
+
+        return lam, index
+
     def __call__(
         self, x: torch.Tensor, y: torch.Tensor, feats, warp_param=1.0
     ) -> tuple[torch.Tensor, torch.Tensor]:

From a22675f0965dbfa920970f0e9621dd1c4e6185a1 Mon Sep 17 00:00:00 2001
From: Quentin Bouniot
Date: Thu, 12 Oct 2023 20:02:04 +0200
Subject: [PATCH 03/27] temp scaling

---
 experiments/classification/cifar10/resnet.py |  9 ++-
 torch_uncertainty/__init__.py                | 18 ++++-
 torch_uncertainty/datamodules/cifar10.py     |  6 ++
 .../post_processing/calibration/scaler.py    | 22 +++---
 .../calibration/temperature_scaler.py        | 13 ++--
 torch_uncertainty/routines/classification.py | 46 ++++++++++-
 torch_uncertainty/transforms/mixup.py        | 77 +++++++++++++++++--
 7 files changed, 163 insertions(+), 28 deletions(-)

diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py
index 4958ce8f..7ec09de8 100644
--- a/experiments/classification/cifar10/resnet.py
+++ b/experiments/classification/cifar10/resnet.py
@@ -22,6 +22,13 @@
     args.root = str(root / "data")
     dm = CIFAR10DataModule(**vars(args))
 
+    if args.opt_temp_scaling:
+        args.calibration_set = dm.get_test_set
+    elif args.val_temp_scaling:
+        args.calibration_set = dm.get_val_set
+    else:
+        args.calibration_set = None
+
     # model
     model = ResNet(
         num_classes=dm.num_classes,
@@ -34,4 +41,4 @@
         **vars(args),
     )
 
-    cli_main(model, dm, root, net_name, args)
+    cli_main(model, dm, args.exp_dir, net_name, args)
diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py
index ec557934..30bd191a 100644
--- a/torch_uncertainty/__init__.py
+++ b/torch_uncertainty/__init__.py
@@ -52,7 +52,18 @@ def init_args(
         action="store_true",
         help="Allow resuming the training (save optimizer's states)",
     )
-
+    parser.add_argument(
+        "--exp_dir",
+        type=str,
+        default="logs/",
+        help="Directory to store experiment files",
+    )
+    parser.add_argument(
+        "--opt_temp_scaling", action="store_true", default=False
+    )
+    parser.add_argument(
+        "--val_temp_scaling", action="store_true", default=False
+    )
     parser = pl.Trainer.add_argparse_args(parser)
     if network is not None:
         parser = network.add_model_specific_args(parser)
@@ -98,7 +109,7 @@ def cli_main(
 
     # logger
     tb_logger = TensorBoardLogger(
-        str(root / "logs"),
+        str(root),
         name=net_name,
         default_hp_metric=False,
         log_graph=args.log_graph,
@@ -125,6 +136,7 @@ def cli_main(
         callbacks=callbacks,
         logger=tb_logger,
         deterministic=(args.seed is not None),
+        inference_mode=args.opt_temp_scaling or args.val_temp_scaling,
     )
 
     if args.summary:
@@ -133,7 +145,7 @@ def cli_main(
     elif args.test is not None:
        if args.test >= 0:
             ckpt_file, _ = get_version(
-                root=(root / "logs" / net_name), version=args.test
+                root=(root / 
net_name), version=args.test ) test_values = trainer.test( network, datamodule=datamodule, ckpt_path=str(ckpt_file) diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index ac06a6cc..0bc07b43 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -244,6 +244,12 @@ def _data_loader( persistent_workers=self.persistent_workers, ) + def get_test_set(self) -> Dataset: + return self.test + + def get_val_set(self) -> Dataset: + return self.val + @classmethod def add_argparse_args( cls, diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 17afabe0..5f031e3d 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -79,17 +79,20 @@ def fit( logits = torch.cat(logits_list).detach().to(self.device) labels = torch.cat(labels_list).detach().to(self.device) - optimizer = optim.LBFGS( - self.temperature, lr=self.lr, max_iter=self.max_iter - ) + with torch.enable_grad(): + optimizer = optim.LBFGS( + self.temperature, lr=self.lr, max_iter=self.max_iter + ) def calib_eval() -> float: - optimizer.zero_grad() - loss = self.criterion(self._scale(logits), labels) - loss.backward() - return loss - - optimizer.step(calib_eval) + with torch.enable_grad(): + optimizer.zero_grad() + loss = self.criterion(self._scale(logits), labels) + loss.backward() + return loss + + with torch.enable_grad(): + optimizer.step(calib_eval) self.trained = True if save_logits: self.logits = logits @@ -105,6 +108,7 @@ def forward(self, logits: torch.Tensor) -> torch.Tensor: ) return self._scale(logits) + @torch.enable_grad() def _scale(self, logits: torch.Tensor) -> torch.Tensor: """ Scale the logits with the optimal temperature. diff --git a/torch_uncertainty/post_processing/calibration/temperature_scaler.py b/torch_uncertainty/post_processing/calibration/temperature_scaler.py index 36d27183..7f40c208 100644 --- a/torch_uncertainty/post_processing/calibration/temperature_scaler.py +++ b/torch_uncertainty/post_processing/calibration/temperature_scaler.py @@ -64,12 +64,13 @@ def _scale(self, logits: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Scaled logits. 
""" - temperature = ( - self.temperature[0] - .unsqueeze(1) - .expand(logits.size(0), logits.size(1)) - ) - return logits / temperature + with torch.enable_grad(): + temperature = ( + self.temperature[0] + .unsqueeze(1) + .expand(logits.size(0), logits.size(1)) + ) + return logits / temperature @property def temperature(self) -> list: diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 6208b7e1..96852bdb 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -11,6 +11,7 @@ from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from timm.data import Mixup as timm_Mixup from torch import nn +from torch.utils.data import Dataset from torchmetrics import Accuracy, CalibrationError, MetricCollection from torchmetrics.classification import ( BinaryAccuracy, @@ -30,6 +31,7 @@ NegativeLogLikelihood, VariationRatio, ) +from ..post_processing import TemperatureScaler from ..transforms import Mixup, MixupIO, RegMixup, WarpingMixup @@ -70,6 +72,7 @@ def __init__( ood_detection: bool = False, use_entropy: bool = False, use_logits: bool = False, + calibration_set: Dataset | None = None, **kwargs, ) -> None: super().__init__() @@ -91,6 +94,8 @@ def __init__( self.use_logits = use_logits self.use_entropy = use_entropy + self.calibration_set = calibration_set + self.binary_cls = num_classes == 1 # model @@ -130,6 +135,9 @@ def __init__( self.val_cls_metrics = cls_metrics.clone(prefix="hp/val_") self.test_cls_metrics = cls_metrics.clone(prefix="hp/test_") + if self.calibration_set is not None: + self.ts_cls_metrics = cls_metrics.clone(prefix="hp/ts_") + self.test_entropy_id = Entropy() if self.ood_detection: @@ -228,6 +236,9 @@ def on_train_start(self) -> None: "hp/test_aupr": 0, "hp/test_auroc": 0, "hp/test_fpr95": 0, + "hp/ts_test_nll": 0, + "hp/ts_test_ece": 0, + "hp/ts_test_brier": 0, }, ) @@ -237,11 +248,11 @@ def training_step( if self.mixtype == "kernel_warping": if self.dist_sim == "emb": with torch.no_grad(): - feats = self.model.feats_forward(batch[0]) + feats = self.model.feats_forward(batch[0]).detach() - self.mixup(*batch, feats) + batch = self.mixup(*batch, feats) elif self.dist_sim == "inp": - self.mixup(*batch, batch[0]) + batch = self.mixup(*batch, batch[0]) else: batch = self.mixup(*batch) @@ -301,6 +312,15 @@ def test_step( else: ood_values = -confs + if ( + self.calibration_set is not None + and self.scaler is not None + and self.cal_model is not None + ): + cal_logits = self.cal_model(inputs) + cal_probs = F.softmax(cal_logits, dim=-1) + self.ts_cls_metrics.update(cal_probs, targets) + if dataloader_idx == 0: self.test_cls_metrics.update(probs, targets) self.test_entropy_id(probs) @@ -332,12 +352,31 @@ def test_epoch_end( ) self.test_cls_metrics.reset() + if ( + self.calibration_set is not None + and self.scaler is not None + and self.cal_model is not None + ): + self.log_dict(self.ts_cls_metrics.compute()) + self.ts_cls_metrics.reset() + if self.ood_detection: self.log_dict( self.test_ood_metrics.compute(), ) self.test_ood_metrics.reset() + def on_test_start(self) -> None: + if self.calibration_set is not None: + with torch.enable_grad(): + self.scaler = TemperatureScaler(device=self.device).fit( + model=self.model, calibration_set=self.calibration_set() + ) + self.cal_model = torch.nn.Sequential(self.model, self.scaler) + else: + self.scaler = None + self.cal_model = None + @staticmethod def add_model_specific_args( parent_parser: ArgumentParser, @@ 
-376,6 +415,7 @@ def add_model_specific_args( parent_parser.add_argument( "--kernel_tau_std", dest="kernel_tau_std", type=float, default=0.5 ) + return parent_parser diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index b862269d..d5fc165b 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -1,11 +1,13 @@ import scipy import torch import torch.nn.functional as F +from torch import Tensor import numpy as np # TODO: torch beta warping (with tensor linspace + approx beta cdf using trapz) # TODO: Mixup with roll to be more efficient (remove sampling of index) +# TODO: MIT and Rank Mixup def beta_warping(x, alpha_cdf=1.0, eps=1e-12): @@ -16,6 +18,67 @@ def sim_gauss_kernel(dist, tau_max=1.0, tau_std=0.5): dist_rate = tau_max * np.exp( -(dist - 1) / (np.mean(dist) * 2 * tau_std * tau_std) ) + return 1 / (dist_rate + 1e-12) + + +def tensor_linspace(start: Tensor, stop: Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly + spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / ( + num - 1 + ) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] + # to allow for broadcastings + # using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here + # but torchscript + # "cannot statically infer the expected size of a list in this contex", + # hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out + + +def torch_beta_cdf( + x: Tensor, c1: Tensor | float, c2: Tensor | float, npts=100, eps=1e-12 +): + if isinstance(c1, float): + if c1 == c2: + c1 = torch.tensor([c1], device=x.device) + c2 = c1 + else: + c1 = torch.tensor([c1], device=x.device) + if isinstance(c2, float): + c2 = torch.tensor([c2], device=x.device) + bt = torch.distributions.Beta(c1, c2) + + if isinstance(x, float): + x = torch.tensor(x) + + X = tensor_linspace(torch.zeros_like(x) + eps, x, npts) + return torch.trapezoid(bt.log_prob(X).exp(), X, dim=0) + + +def torch_beta_warping( + x: Tensor, alpha_cdf: float | Tensor = 1.0, eps=1e-12, npts=100 +): + return torch_beta_cdf( + x=x, c1=alpha_cdf + eps, c2=alpha_cdf + eps, npts=npts, eps=eps + ) + + +def torch_sim_gauss_kernel(dist: Tensor, tau_max=1.0, tau_std=0.5): + dist_rate = tau_max * torch.exp( + -(dist - 1) / (torch.mean(dist) * 2 * tau_std * tau_std) + ) return 1 / (dist_rate + 1e-12) @@ -140,17 +203,19 @@ def _get_params(self, batch_size: int, device: torch.device): return lam, index def __call__( - self, x: torch.Tensor, y: torch.Tensor, feats, warp_param=1.0 + self, + x: torch.Tensor, + y: torch.Tensor, + feats: torch.Tensor, + warp_param=1.0, ) -> tuple[torch.Tensor, torch.Tensor]: lam, index = self._get_params(x.size()[0], x.device) if self.apply_kernel: l2_dist = ( - ( - (feats - feats[index]) - .pow(2) - .sum([i for i in range(len(feats.size())) if i > 0]) - ) + (feats - feats[index]) + .pow(2) + .sum([i for i in range(len(feats.size())) if i > 0]) .cpu() .numpy() ) From 0130f52b62bb90c0d5b0364a33560569dbd244a3 Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Fri, 20 Oct 2023 12:36:25 +0200 Subject: [PATCH 04/27] fix temp scaler + add cross val + abstract datamodule --- 
experiments/classification/cifar10/resnet.py | 57 +++-- .../classification/tiny-imagenet/resnet.py | 57 ++++- tests/_dummies/baseline.py | 10 +- tests/_dummies/datamodule.py | 85 ++----- torch_uncertainty/__init__.py | 182 ++++++++++----- .../baselines/classification/resnet.py | 2 +- .../baselines/classification/vgg.py | 2 +- .../baselines/classification/wideresnet.py | 2 +- torch_uncertainty/baselines/regression/mlp.py | 6 +- torch_uncertainty/datamodules/abstract.py | 211 ++++++++++++++++++ torch_uncertainty/datamodules/cifar10.py | 86 +++---- torch_uncertainty/datamodules/cifar100.py | 81 +++---- .../datamodules/tiny_imagenet.py | 80 ++----- .../datamodules/uci_regression.py | 80 ++----- .../calibration/matrix_scaler.py | 2 +- .../post_processing/calibration/scaler.py | 24 +- .../calibration/temperature_scaler.py | 15 +- .../calibration/vector_scaler.py | 2 +- torch_uncertainty/routines/classification.py | 12 +- torch_uncertainty/routines/regression.py | 6 +- torch_uncertainty/utils/__init__.py | 1 + torch_uncertainty/utils/misc.py | 19 ++ 22 files changed, 606 insertions(+), 416 deletions(-) create mode 100644 torch_uncertainty/datamodules/abstract.py create mode 100644 torch_uncertainty/utils/misc.py diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index 7ec09de8..a7fe361d 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -7,6 +7,7 @@ from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import CIFAR10DataModule from torch_uncertainty.optimization_procedures import get_procedure +from torch_uncertainty.utils import csv_writter # fmt: on if __name__ == "__main__": @@ -16,7 +17,8 @@ else: root = Path(args.root) - net_name = f"{args.version}-resnet{args.arch}-cifar10" + if args.exp_name == "": + args.exp_name = f"{args.version}-resnet{args.arch}-cifar10" # datamodule args.root = str(root / "data") @@ -29,16 +31,43 @@ else: args.calibration_set = None - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=get_procedure( - f"resnet{args.arch}", "cifar10", args.version - ), - style="cifar", - **vars(args), - ) - - cli_main(model, dm, args.exp_dir, net_name, args) + if args.use_cv: + list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=get_procedure( + f"resnet{args.arch}", "cifar10", args.version + ), + style="cifar", + **vars(args), + ) + ) + + results = cli_main( + list_model, list_dm, args.exp_dir, args.exp_name, args + ) + else: + # model + model = ResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=get_procedure( + f"resnet{args.arch}", "cifar10", args.version + ), + style="cifar", + **vars(args), + ) + + results = cli_main(model, dm, args.exp_dir, args.exp_name, args) + + for dict_result in results: + csv_writter( + Path(args.exp_dir) / Path(args.exp_name) / "results.csv", + dict_result, + ) diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index 58ea4a5f..39c886ed 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -6,6 
+6,7 @@ from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import TinyImageNetDataModule +from torch_uncertainty.utils import csv_writter # fmt: on @@ -28,20 +29,54 @@ def optim_tiny(model: nn.Module) -> dict: else: root = Path(args.root) - net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet" + # net_name = f"{args.version}-resnet{args.arch}-tiny-imagenet" + if args.exp_name == "": + args.exp_name = f"{args.version}-resnet{args.arch}-cifar10" # datamodule args.root = str(root / "data") dm = TinyImageNetDataModule(**vars(args)) - # model - model = ResNet( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_tiny, - style="cifar", - **vars(args), - ) + if args.opt_temp_scaling: + args.calibration_set = dm.get_test_set + elif args.val_temp_scaling: + args.calibration_set = dm.get_val_set + else: + args.calibration_set = None + + if args.use_cv: + list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_tiny, + style="cifar", + **vars(args), + ) + ) + + results = cli_main( + list_model, list_dm, args.exp_dir, args.exp_name, args + ) + else: + # model + model = ResNet( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_tiny, + style="cifar", + **vars(args), + ) + + results = cli_main(model, dm, args.exp_dir, args.exp_name, args) - cli_main(model, dm, root, net_name, args) + for dict_result in results: + csv_writter( + Path(args.exp_dir) / Path(args.exp_name) / "results.csv", + dict_result, + ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index f5cfa2f5..1ed33ac1 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -1,6 +1,6 @@ # fmt: off from argparse import ArgumentParser -from typing import Any +from typing import Any, Type from pytorch_lightning import LightningModule from torch import nn @@ -19,12 +19,12 @@ # fmt: on -class DummyClassificationBaseline: +class DummyClassificationBaseline(LightningModule): def __new__( cls, num_classes: int, in_channels: int, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, baseline_type: str = "single", **kwargs, @@ -64,12 +64,12 @@ def add_model_specific_args( return parser -class DummyRegressionBaseline: +class DummyRegressionBaseline(LightningModule): def __new__( cls, in_features: int, out_features: int, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, baseline_type: str = "single", dist_estimation: int = 1, diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index c8ca4334..aa488efd 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -4,14 +4,15 @@ from typing import Any, List, Optional, Union import torchvision.transforms as T -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader + +from torch_uncertainty.datamodules.abstract import AbstractDataModule from .dataset import DummyClassificationDataset, DummyRegressionDataset # fmt: on -class DummyClassificationDataModule(LightningDataModule): +class DummyClassificationDataModule(AbstractDataModule): num_channels = 1 image_size: 
int = 8 training_task = "classification" @@ -27,17 +28,16 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() - - root = Path(root) + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - self.root: Path = root self.ood_detection = ood_detection - self.batch_size = batch_size self.num_classes = num_classes - self.num_workers = num_workers - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers self.dataset = DummyClassificationDataset self.ood_dataset = DummyClassificationDataset @@ -80,47 +80,26 @@ def setup(self, stage: Optional[str] = None) -> None: transform=self.transform_test, ) - def train_dataloader(self) -> DataLoader: - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - return self._data_loader(self.val) - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: dataloader = [self._data_loader(self.test)] if self.ood_detection: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) - @classmethod def add_argparse_args( cls, parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=2) - p.add_argument("--num_workers", type=int, default=1) + p = super().add_argparse_args(parent_parser) p.add_argument( "--evaluate_ood", dest="ood_detection", action="store_true" ) return parent_parser -class DummyRegressionDataModule(LightningDataModule): +class DummyRegressionDataModule(AbstractDataModule): in_features = 4 training_task = "regression" @@ -135,17 +114,16 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - if isinstance(root, str): - root = Path(root) - self.root: Path = root self.ood_detection = ood_detection - self.batch_size = batch_size self.out_features = out_features - self.num_workers = num_workers - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers self.dataset = DummyRegressionDataset self.ood_dataset = DummyRegressionDataset @@ -181,40 +159,19 @@ def setup(self, stage: Optional[str] = None) -> None: transform=self.transform_test, ) - def train_dataloader(self) -> DataLoader: - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - return self._data_loader(self.val) - def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: dataloader = [self._data_loader(self.test)] if self.ood_detection: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) - @classmethod def add_argparse_args( cls, parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = 
parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=2) - p.add_argument("--num_workers", type=int, default=1) + p = super().add_argparse_args(parent_parser) p.add_argument( "--evaluate_ood", dest="ood_detection", action="store_true" ) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 30bd191a..9e7c6c61 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -1,6 +1,7 @@ # fmt: off # flake8: noqa from argparse import ArgumentParser, Namespace +from collections import defaultdict from pathlib import Path from typing import Dict, Optional, Type, Union @@ -14,6 +15,7 @@ import numpy as np +from .datamodules.abstract import AbstractDataModule from .utils import get_version @@ -58,6 +60,12 @@ def init_args( default="logs/", help="Directory to store experiment files", ) + parser.add_argument( + "--exp_name", + type=str, + default="", + help="Name of the experiment folder", + ) parser.add_argument( "--opt_temp_scaling", action="store_true", default=False ) @@ -75,16 +83,19 @@ def init_args( def cli_main( - network: pl.LightningModule, - datamodule: pl.LightningDataModule, + network: pl.LightningModule | list[pl.LightningModule], + datamodule: AbstractDataModule | list[AbstractDataModule], root: Union[Path, str], net_name: str, args: Namespace, -) -> Dict: +) -> list[Dict]: if isinstance(root, str): root = Path(root) - training_task = datamodule.training_task + if isinstance(datamodule, list): + training_task = datamodule[0].dm.training_task + else: + training_task = datamodule.training_task if training_task == "classification": monitor = "hp/val_acc" mode = "max" @@ -105,58 +116,125 @@ def cli_main( pl.seed_everything(args.seed, workers=True) if args.channels_last: - network = network.to(memory_format=torch.channels_last) - - # logger - tb_logger = TensorBoardLogger( - str(root), - name=net_name, - default_hp_metric=False, - log_graph=args.log_graph, - version=args.test, - ) + if isinstance(network, list): + for i in range(len(network)): + network[i] = network[i].to(memory_format=torch.channels_last) + else: + network = network.to(memory_format=torch.channels_last) - # callbacks - save_checkpoints = ModelCheckpoint( - monitor=monitor, - mode=mode, - save_last=True, - save_weights_only=not args.enable_resume, - ) + if args.use_cv: + test_values = [] + for i in range(len(datamodule)): + print( + f"Starting fold {i} out of {args.train_over} of a {args.n_splits}-fold CV." 
+ ) - # Select the best model, monitor the lr and stop if NaN - callbacks = [ - save_checkpoints, - LearningRateMonitor(logging_interval="step"), - EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), - ] - # trainer - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=callbacks, - logger=tb_logger, - deterministic=(args.seed is not None), - inference_mode=args.opt_temp_scaling or args.val_temp_scaling, - ) + # logger + tb_logger = TensorBoardLogger( + str(root), + name=net_name, + default_hp_metric=False, + log_graph=args.log_graph, + version=args.test, + ) + + # callbacks + save_checkpoints = ModelCheckpoint( + monitor=monitor, + mode=mode, + save_last=True, + save_weights_only=not args.enable_resume, + ) - if args.summary: - summary(network, input_size=list(datamodule.input_shape).insert(0, 1)) - test_values = {} - elif args.test is not None: - if args.test >= 0: - ckpt_file, _ = get_version( - root=(root / net_name), version=args.test + # Select the best model, monitor the lr and stop if NaN + callbacks = [ + save_checkpoints, + LearningRateMonitor(logging_interval="step"), + EarlyStopping( + monitor=monitor, patience=np.inf, check_finite=True + ), + ] + + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=callbacks, + logger=tb_logger, + deterministic=(args.seed is not None), + inference_mode=not ( + args.opt_temp_scaling or args.val_temp_scaling + ), ) - test_values = trainer.test( - network, datamodule=datamodule, ckpt_path=str(ckpt_file) + trainer.fit(network[i], datamodule[i]) + test_values.append( + trainer.test(datamodule=datamodule[i], ckpt_path="last")[0] ) - else: - test_values = trainer.test(network, datamodule=datamodule) + + all_test_values = defaultdict(list) + for test_value in test_values: + for key in test_value: + all_test_values[key].append(test_value[key]) + + avg_test_values = {} + for key in all_test_values: + avg_test_values[key] = np.mean(all_test_values[key]) + + return [avg_test_values] else: - # training and testing - trainer.fit(network, datamodule) - if args.fast_dev_run is False: - test_values = trainer.test(datamodule=datamodule, ckpt_path="best") + # logger + tb_logger = TensorBoardLogger( + str(root), + name=net_name, + default_hp_metric=False, + log_graph=args.log_graph, + version=args.test, + ) + + # callbacks + save_checkpoints = ModelCheckpoint( + monitor=monitor, + mode=mode, + save_last=True, + save_weights_only=not args.enable_resume, + ) + + # Select the best model, monitor the lr and stop if NaN + callbacks = [ + save_checkpoints, + LearningRateMonitor(logging_interval="step"), + EarlyStopping(monitor=monitor, patience=np.inf, check_finite=True), + ] + + # trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=callbacks, + logger=tb_logger, + deterministic=(args.seed is not None), + inference_mode=not (args.opt_temp_scaling or args.val_temp_scaling), + ) + if args.summary: + summary( + network, + input_size=list(datamodule.input_shape).insert(0, 1), + ) + test_values = [{}] + elif args.test is not None: + if args.test >= 0: + ckpt_file, _ = get_version( + root=(root / net_name), version=args.test + ) + test_values = trainer.test( + network, datamodule=datamodule, ckpt_path=str(ckpt_file) + ) + else: + test_values = trainer.test(network, datamodule=datamodule) else: - test_values = {} - return test_values + # training and testing + trainer.fit(network, datamodule) + if args.fast_dev_run is False: + test_values = trainer.test( + datamodule=datamodule, ckpt_path="best" + ) + else: + test_values = 
[{}] + return test_values diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 383bb142..ff2dc7aa 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -52,7 +52,7 @@ # fmt: on -class ResNet: +class ResNet(LightningModule): r"""ResNet backbone baseline for classification providing support for various versions and architectures. diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 6bfbaf93..45bd1b6e 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -32,7 +32,7 @@ # fmt: on -class VGG: +class VGG(LightningModule): r"""VGG backbone baseline for classification providing support for various versions and architectures. diff --git a/torch_uncertainty/baselines/classification/wideresnet.py b/torch_uncertainty/baselines/classification/wideresnet.py index b5c70877..4a1d49e2 100644 --- a/torch_uncertainty/baselines/classification/wideresnet.py +++ b/torch_uncertainty/baselines/classification/wideresnet.py @@ -32,7 +32,7 @@ # fmt: on -class WideResNet: +class WideResNet(LightningModule): r"""Wide-ResNet28x10 backbone baseline for classification providing support for various versions. diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 49a90461..e671ad43 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -1,7 +1,7 @@ # fmt: off from argparse import ArgumentParser from pathlib import Path -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Type, Union import torch from pytorch_lightning import LightningModule @@ -17,7 +17,7 @@ # fmt: on -class MLP: +class MLP(LightningModule): r"""MLP baseline for regression providing support for various versions.""" single = ["vanilla"] @@ -28,7 +28,7 @@ def __new__( cls, num_outputs: int, in_features: int, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, version: Literal["vanilla", "packed"], hidden_dims: List[int], diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py new file mode 100644 index 00000000..1fa2175a --- /dev/null +++ b/torch_uncertainty/datamodules/abstract.py @@ -0,0 +1,211 @@ +from argparse import ArgumentParser +from pathlib import Path +from typing import Any, List, Optional, Union + +from pytorch_lightning import LightningDataModule +from sklearn.model_selection import StratifiedKFold +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.sampler import SubsetRandomSampler + +from numpy.typing import ArrayLike + + +class AbstractDataModule(LightningDataModule): + training_task = "" + + def __init__( + self, + root: Union[str, Path], + batch_size: int, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + **kwargs, + ) -> None: + super().__init__() + + if isinstance(root, str): + root = Path(root) + self.root: Path = root + self.batch_size = batch_size + self.num_workers = num_workers + + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + + def setup(self, stage: Optional[str] = None) -> None: + self.train = Dataset() + self.val = Dataset() + self.test = Dataset() + + def get_train_set(self) -> Dataset: + return self.train + + def 
get_test_set(self) -> Dataset: + return self.test + + def get_val_set(self) -> Dataset: + return self.val + + def train_dataloader(self) -> DataLoader: + r"""Get the training dataloader. + + Return: + DataLoader: training dataloader. + """ + return self._data_loader(self.train, shuffle=True) + + def val_dataloader(self) -> DataLoader: + r"""Get the validation dataloader. + + Return: + DataLoader: validation dataloader. + """ + return self._data_loader(self.val) + + def test_dataloader(self) -> List[DataLoader]: + r"""Get test dataloaders. + + Return: + List[DataLoader]: test set for in distribution data + and out-of-distribution data. + """ + dataloader = [self._data_loader(self.test)] + return dataloader + + def _data_loader( + self, dataset: Dataset, shuffle: bool = False + ) -> DataLoader: + """Create a dataloader for a given dataset. + + Args: + dataset (Dataset): Dataset to create a dataloader for. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults + to False. + + Return: + DataLoader: Dataloader for the given dataset. + """ + return DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=shuffle, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + + # These two functions have to be defined in each datamodule + # by setting the correct path to the matrix of data for each dataset. + # It is generally "Dataset.samples" or "Dataset.data" + # They are used for constructing cross validation splits + def _get_train_data(self) -> ArrayLike: + pass + + def _get_train_targets(self) -> ArrayLike: + pass + + def make_cross_val_splits(self, n_splits=10, train_over=4) -> list: + self.setup("fit") + skf = StratifiedKFold(n_splits) + cv_dm = [] + + for fold, (train_idx, val_idx) in enumerate( + skf.split(self._get_train_data(), self._get_train_targets()) + ): + if fold >= train_over: + break + + fold_dm = CrossValDataModule( + root=self.root, + train_idx=train_idx, + val_idx=val_idx, + datamodule=self, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + + cv_dm.append(fold_dm) + + return cv_dm + + @classmethod + def add_argparse_args( + cls, + parent_parser: ArgumentParser, + **kwargs: Any, + ) -> ArgumentParser: + p = parent_parser.add_argument_group("datamodule") + p.add_argument("--root", type=str, default="./data/") + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--val_split", type=float, default=0.0) + p.add_argument("--num_workers", type=int, default=4) + p.add_argument("--use_cv", action="store_true") + p.add_argument("--n_splits", type=int, default=10) + p.add_argument("--train_over", type=int, default=4) + return parent_parser + + +class CrossValDataModule(AbstractDataModule): + def __init__( + self, + root: str | Path, + train_idx: ArrayLike, + val_idx: ArrayLike, + datamodule: AbstractDataModule, + batch_size: int, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + **kwargs, + ) -> None: + super().__init__( + root, + batch_size, + num_workers, + pin_memory, + persistent_workers, + **kwargs, + ) + + self.train_idx = train_idx + self.val_idx = val_idx + self.dm = datamodule + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + self.train = self.dm.train + self.val = self.dm.val + elif stage == "test": + self.test = self.val + + def _data_loader(self, dataset: Dataset, idx: ArrayLike) -> DataLoader: + return 
DataLoader( + dataset=dataset, + sampler=SubsetRandomSampler(idx), + shuffle=False, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + ) + + def train_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.train_idx) + + def val_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.val_idx) + + def test_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.val_idx) + + def get_train_set(self) -> Dataset: + return self.dm.train + + def get_test_set(self) -> Dataset: + return self.dm.val + + def get_val_set(self) -> Dataset: + return self.dm.val diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index aabb045b..b7437ef3 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -4,19 +4,22 @@ from typing import Any, List, Literal, Optional, Union import torchvision.transforms as T -from pytorch_lightning import LightningDataModule from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR10, SVHN +import numpy as np +from numpy.typing import ArrayLike + from ..datasets import AggregatedDataset from ..datasets.classification import CIFAR10C, CIFAR10H from ..transforms import Cutout +from .abstract import AbstractDataModule # fmt: on -class CIFAR10DataModule(LightningDataModule): +class CIFAR10DataModule(AbstractDataModule): """DataModule for CIFAR10. Args: @@ -60,19 +63,17 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - if isinstance(root, str): - root = Path(root) - self.root: Path = root - self.ood_detection = ood_detection - self.batch_size = batch_size self.val_split = val_split - self.num_workers = num_workers self.num_dataloaders = num_dataloaders - - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers + self.ood_detection = ood_detection if test_alt == "c": self.dataset = CIFAR10C @@ -202,54 +203,23 @@ def train_dataloader(self) -> DataLoader: else: return self._data_loader(self.train, shuffle=True) - def val_dataloader(self) -> DataLoader: - r"""Gets the validation dataloader for CIFAR10. - - Returns: - DataLoader: CIFAR10 validation dataloader. - """ - return self._data_loader(self.val) - def test_dataloader(self) -> List[DataLoader]: - r"""Get the test dataloaders for CIFAR10. + r"""Get test dataloaders. Return: - List[DataLoader]: Dataloaders of the CIFAR10 test set (in - distribution data) and SVHN test split (out-of-distribution - data). + List[DataLoader]: test set for in distribution data + and out-of-distribution data. """ dataloader = [self._data_loader(self.test)] if self.ood_detection: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - """Create a dataloader for a given dataset. + def _get_train_data(self) -> ArrayLike: + return self.train.dataset.data[self.train.indices] - Args: - dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. 
Defaults - to False. - - Return: - DataLoader: Dataloader for the given dataset. - """ - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) - - def get_test_set(self) -> Dataset: - return self.test - - def get_val_set(self) -> Dataset: - return self.val + def _get_train_targets(self) -> ArrayLike: + return np.array(self.train.dataset.targets)[self.train.indices] @classmethod def add_argparse_args( @@ -257,18 +227,16 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--val_split", type=float, default=0.0) - p.add_argument("--num_workers", type=int, default=4) - p.add_argument( - "--evaluate_ood", dest="ood_detection", action="store_true" - ) + p = super().add_argparse_args(parent_parser) + + # Arguments for CIFAR10 p.add_argument("--cutout", type=int, default=0) p.add_argument("--auto_augment", type=str) p.add_argument("--test_alt", choices=["c", "h"], default=None) p.add_argument( "--severity", dest="corruption_severity", type=int, default=None ) + p.add_argument( + "--evaluate_ood", dest="ood_detection", action="store_true" + ) return parent_parser diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index 8392fc3e..66b80fd2 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -5,19 +5,22 @@ import torch import torchvision.transforms as T -from pytorch_lightning import LightningDataModule from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR100, SVHN +import numpy as np +from numpy.typing import ArrayLike + from ..datasets import AggregatedDataset from ..datasets.classification import CIFAR100C from ..transforms import Cutout +from .abstract import AbstractDataModule # fmt: on -class CIFAR100DataModule(LightningDataModule): +class CIFAR100DataModule(AbstractDataModule): """DataModule for CIFAR100. Args: @@ -62,20 +65,17 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() - - if isinstance(root, str): - root = Path(root) + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - self.root: Path = root - self.ood_detection = ood_detection - self.batch_size = batch_size self.val_split = val_split - self.num_workers = num_workers self.num_dataloaders = num_dataloaders - - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers + self.ood_detection = ood_detection if test_alt == "c": self.dataset = CIFAR100C @@ -203,48 +203,23 @@ def train_dataloader(self) -> DataLoader: else: return self._data_loader(self.train, shuffle=True) - def val_dataloader(self) -> DataLoader: - """Get the validation dataloader for CIFAR100. - - Return: - DataLoader: CIFAR100 validation dataloader. - """ - return self._data_loader(self.val) - def test_dataloader(self) -> List[DataLoader]: - """Get the test dataloaders for CIFAR100. + r"""Get test dataloaders. 
Return: - List[DataLoader]: Dataloaders of the CIFAR100 test set (in - distribution data) and SVHN test split (out-of-distribution - data). + List[DataLoader]: test set for in distribution data + and out-of-distribution data. """ dataloader = [self._data_loader(self.test)] if self.ood_detection: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - """Create a dataloader for a given dataset. + def _get_train_data(self) -> ArrayLike: + return self.train.dataset.data[self.train.indices] - Args: - dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults - to False. - - Return: - DataLoader: Dataloader for the given dataset. - """ - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) + def _get_train_targets(self) -> ArrayLike: + return np.array(self.train.dataset.targets)[self.train.indices] @classmethod def add_argparse_args( @@ -252,14 +227,9 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--val_split", type=float, default=0.0) - p.add_argument("--num_workers", type=int, default=4) - p.add_argument( - "--evaluate_ood", dest="ood_detection", action="store_true" - ) + p = super().add_argparse_args(parent_parser) + + # Arguments for CIFAR100 p.add_argument("--cutout", type=int, default=0) p.add_argument( "--randaugment", dest="enable_randaugment", action="store_true" @@ -269,4 +239,7 @@ def add_argparse_args( p.add_argument( "--severity", dest="corruption_severity", type=int, default=1 ) + p.add_argument( + "--evaluate_ood", dest="ood_detection", action="store_true" + ) return parent_parser diff --git a/torch_uncertainty/datamodules/tiny_imagenet.py b/torch_uncertainty/datamodules/tiny_imagenet.py index beebe077..e30d0619 100644 --- a/torch_uncertainty/datamodules/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/tiny_imagenet.py @@ -3,17 +3,19 @@ from typing import Any, List, Optional, Union import torchvision.transforms as T -from pytorch_lightning import LightningDataModule from timm.data.auto_augment import rand_augment_transform from torch import nn -from torch.utils.data import ConcatDataset, DataLoader, Dataset +from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN +from numpy.typing import ArrayLike + from ..datasets.classification import ImageNetO, TinyImageNet +from .abstract import AbstractDataModule # fmt: on -class TinyImageNetDataModule(LightningDataModule): +class TinyImageNetDataModule(AbstractDataModule): num_classes = 200 num_channels = 3 training_task = "classification" @@ -30,17 +32,16 @@ def __init__( persistent_workers: bool = True, **kwargs, ) -> None: - super().__init__() + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) # TODO: COMPUTE STATS - if isinstance(root, str): - root = Path(root) - self.root: Path = root self.ood_detection = ood_detection - self.batch_size = batch_size - self.num_workers = num_workers - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers 
self.ood_ds = ood_ds self.dataset = TinyImageNet @@ -166,55 +167,23 @@ def setup(self, stage: Optional[str] = None) -> None: transform=self.transform_test, ) - def train_dataloader(self) -> DataLoader: - r"""Get the training dataloader for TinyImageNet. - - Return: - DataLoader: TinyImageNet training dataloader. - """ - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - r"""Get the validation dataloader for TinyImageNet. - - Return: - DataLoader: TinyImageNet validation dataloader. - """ - return self._data_loader(self.val) - def test_dataloader(self) -> List[DataLoader]: - r"""Get test dataloaders for TinyImageNet. + r"""Get test dataloaders. Return: - List[DataLoader]: TinyImageNet test set (in distribution data) and - SVHN test split (out-of-distribution data). + List[DataLoader]: test set for in distribution data + and out-of-distribution data. """ dataloader = [self._data_loader(self.test)] if self.ood_detection: dataloader.append(self._data_loader(self.ood)) return dataloader - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - """Create a dataloader for a given dataset. + def _get_train_data(self) -> ArrayLike: + return self.train.samples - Args: - dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults - to False. - - Return: - DataLoader: Dataloader for the given dataset. - """ - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) + def _get_train_targets(self) -> ArrayLike: + return self.train.label_data @classmethod def add_argparse_args( @@ -222,14 +191,13 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=256) - p.add_argument("--num_workers", type=int, default=4) + p = super().add_argparse_args(parent_parser) + + # Arguments for Tiny Imagenet p.add_argument( - "--evaluate_ood", dest="ood_detection", action="store_true" + "--rand_augment", dest="rand_augment_opt", type=str, default=None ) p.add_argument( - "--rand_augment", dest="rand_augment_opt", type=str, default=None + "--evaluate_ood", dest="ood_detection", action="store_true" ) return parent_parser diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index 025d45ff..285d1388 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -4,15 +4,15 @@ from pathlib import Path from typing import Any, Optional, Tuple, Union -from pytorch_lightning import LightningDataModule from torch import Generator -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import random_split from ..datasets.regression import UCIRegression +from .abstract import AbstractDataModule # fmt: on -class UCIDataModule(LightningDataModule): +class UCIDataModule(AbstractDataModule): """The UCI regression datasets. 
Args: @@ -47,17 +47,15 @@ def __init__( split_seed: int = 42, **kwargs, ) -> None: - super().__init__() + super().__init__( + root=root, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) - if isinstance(root, str): - root = Path(root) - self.root: Path = root - self.batch_size = batch_size self.val_split = val_split - self.num_workers = num_workers - - self.pin_memory = pin_memory - self.persistent_workers = persistent_workers self.dataset = partial( UCIRegression, dataset_name=dataset_name, seed=split_seed @@ -89,51 +87,14 @@ def setup(self, stage: Optional[str] = None) -> None: if self.val_split == 0: self.val = self.test - def train_dataloader(self) -> DataLoader: - """Get the training dataloader for UCI Regression. - - Return: - DataLoader: UCI Regression training dataloader. - """ - return self._data_loader(self.train, shuffle=True) - - def val_dataloader(self) -> DataLoader: - """Get the validation dataloader for UCI Regression. - - Return: - DataLoader: UCI Regression validation dataloader. - """ - return self._data_loader(self.val) - - def test_dataloader(self) -> DataLoader: - """Get the test dataloader for UCI Regression. + # Change by default test_dataloader -> List[DataLoader] + # def test_dataloader(self) -> DataLoader: + # """Get the test dataloader for UCI Regression. - Return: - DataLoader: UCI Regression test dataloader. - """ - return self._data_loader(self.test) - - def _data_loader( - self, dataset: Dataset, shuffle: bool = False - ) -> DataLoader: - """Create a dataloader for a given dataset. - - Args: - dataset (Dataset): Dataset to create a dataloader for. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults - to False. - - Return: - DataLoader: Dataloader for the given dataset. - """ - return DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=shuffle, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - persistent_workers=self.persistent_workers, - ) + # Return: + # DataLoader: UCI Regression test dataloader. 
+ # """ + # return self._data_loader(self.test) @classmethod def add_argparse_args( @@ -141,9 +102,6 @@ def add_argparse_args( parent_parser: ArgumentParser, **kwargs: Any, ) -> ArgumentParser: - p = parent_parser.add_argument_group("datamodule") - p.add_argument("--root", type=str, default="./data/") - p.add_argument("--batch_size", type=int, default=128) - p.add_argument("--val_split", type=float, default=0) - p.add_argument("--num_workers", type=int, default=4) + super().add_argparse_args(parent_parser) + return parent_parser diff --git a/torch_uncertainty/post_processing/calibration/matrix_scaler.py b/torch_uncertainty/post_processing/calibration/matrix_scaler.py index 1472188a..598776ab 100644 --- a/torch_uncertainty/post_processing/calibration/matrix_scaler.py +++ b/torch_uncertainty/post_processing/calibration/matrix_scaler.py @@ -37,7 +37,7 @@ def __init__( init_b: float = 0, lr: float = 0.1, max_iter: int = 200, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = None, ) -> None: super().__init__(lr=lr, max_iter=max_iter, device=device) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index 5f031e3d..2bb6749e 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -31,7 +31,7 @@ def __init__( self, lr: float = 0.1, max_iter: int = 100, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = None, ) -> None: super().__init__() self.device = device @@ -79,20 +79,17 @@ def fit( logits = torch.cat(logits_list).detach().to(self.device) labels = torch.cat(labels_list).detach().to(self.device) - with torch.enable_grad(): - optimizer = optim.LBFGS( - self.temperature, lr=self.lr, max_iter=self.max_iter - ) + optimizer = optim.LBFGS( + self.temperature, lr=self.lr, max_iter=self.max_iter + ) def calib_eval() -> float: - with torch.enable_grad(): - optimizer.zero_grad() - loss = self.criterion(self._scale(logits), labels) - loss.backward() - return loss - - with torch.enable_grad(): - optimizer.step(calib_eval) + optimizer.zero_grad() + loss = self.criterion(self._scale(logits), labels) + loss.backward() + return loss + + optimizer.step(calib_eval) self.trained = True if save_logits: self.logits = logits @@ -108,7 +105,6 @@ def forward(self, logits: torch.Tensor) -> torch.Tensor: ) return self._scale(logits) - @torch.enable_grad() def _scale(self, logits: torch.Tensor) -> torch.Tensor: """ Scale the logits with the optimal temperature. diff --git a/torch_uncertainty/post_processing/calibration/temperature_scaler.py b/torch_uncertainty/post_processing/calibration/temperature_scaler.py index 7f40c208..47a1a656 100644 --- a/torch_uncertainty/post_processing/calibration/temperature_scaler.py +++ b/torch_uncertainty/post_processing/calibration/temperature_scaler.py @@ -31,7 +31,7 @@ def __init__( init_val: float = 1, lr: float = 0.1, max_iter: int = 100, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = None, ) -> None: super().__init__(lr=lr, max_iter=max_iter, device=device) @@ -64,13 +64,12 @@ def _scale(self, logits: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Scaled logits. 
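For intuition, a minimal sketch using only PyTorch (independent of the scaler classes in this patch): dividing logits by a temperature greater than 1 — which is all `_scale` does with its learned temperature — softens the softmax confidences.

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])
for temperature in (1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    # the maximum probability shrinks as the temperature grows
    print(f"T={temperature}: max prob = {probs.max().item():.3f}")
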
""" - with torch.enable_grad(): - temperature = ( - self.temperature[0] - .unsqueeze(1) - .expand(logits.size(0), logits.size(1)) - ) - return logits / temperature + temperature = ( + self.temperature[0] + .unsqueeze(1) + .expand(logits.size(0), logits.size(1)) + ) + return logits / temperature @property def temperature(self) -> list: diff --git a/torch_uncertainty/post_processing/calibration/vector_scaler.py b/torch_uncertainty/post_processing/calibration/vector_scaler.py index e2147bc8..8daff2f5 100644 --- a/torch_uncertainty/post_processing/calibration/vector_scaler.py +++ b/torch_uncertainty/post_processing/calibration/vector_scaler.py @@ -37,7 +37,7 @@ def __init__( init_b: float = 0, lr: float = 0.1, max_iter: int = 200, - device: Optional[Literal["cpu", "cuda"]] = None, + device: Optional[Literal["cpu", "cuda"] | torch.device] = None, ) -> None: super().__init__(lr=lr, max_iter=max_iter, device=device) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 96852bdb..beb7cf15 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -1,7 +1,7 @@ # fmt: off from argparse import ArgumentParser, Namespace from functools import partial -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import pytorch_lightning as pl import torch @@ -11,7 +11,6 @@ from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from timm.data import Mixup as timm_Mixup from torch import nn -from torch.utils.data import Dataset from torchmetrics import Accuracy, CalibrationError, MetricCollection from torchmetrics.classification import ( BinaryAccuracy, @@ -72,7 +71,7 @@ def __init__( ood_detection: bool = False, use_entropy: bool = False, use_logits: bool = False, - calibration_set: Dataset | None = None, + calibration_set: Optional[Callable] = None, **kwargs, ) -> None: super().__init__() @@ -368,10 +367,9 @@ def test_epoch_end( def on_test_start(self) -> None: if self.calibration_set is not None: - with torch.enable_grad(): - self.scaler = TemperatureScaler(device=self.device).fit( - model=self.model, calibration_set=self.calibration_set() - ) + self.scaler = TemperatureScaler(device=self.device).fit( + model=self.model, calibration_set=self.calibration_set() + ) self.cal_model = torch.nn.Sequential(self.model, self.scaler) else: self.scaler = None diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index e551674d..6140b40c 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -1,6 +1,6 @@ # fmt: off from argparse import ArgumentParser, Namespace -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Type, Union import pytorch_lightning as pl import torch @@ -19,7 +19,7 @@ class RegressionSingle(pl.LightningModule): def __init__( self, model: nn.Module, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, dist_estimation: int, **kwargs, @@ -211,7 +211,7 @@ class RegressionEnsemble(RegressionSingle): def __init__( self, model: nn.Module, - loss: nn.Module, + loss: Type[nn.Module], optimization_procedure: Any, dist_estimation: int, num_estimators: int, diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index 709c9697..13388a47 100644 --- a/torch_uncertainty/utils/__init__.py +++ 
b/torch_uncertainty/utils/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .checkpoints import get_version from .hub import load_hf +from .misc import csv_writter diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py new file mode 100644 index 00000000..ccf15055 --- /dev/null +++ b/torch_uncertainty/utils/misc.py @@ -0,0 +1,19 @@ +import csv + + +def csv_writter(path, dic): + # Check if the file already exists + if path.is_file(): + append_mode = True + rw_mode = "a" + else: + append_mode = False + rw_mode = "w" + + # Write dic + with open(path, rw_mode) as csvfile: + writer = csv.writer(csvfile, delimiter=",") + # Do not write header in append mode + if append_mode is False: + writer.writerow(dic.keys()) + writer.writerow([f"{elem:.4f}" for elem in dic.values()]) From 0103ce963c294fe5bd63ce003bafe1b428ff2d72 Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Fri, 20 Oct 2023 16:50:28 +0200 Subject: [PATCH 05/27] fix calibration in cv --- experiments/classification/cifar10/resnet.py | 7 +- .../classification/tiny-imagenet/resnet.py | 12 ++- torch_uncertainty/__init__.py | 3 +- torch_uncertainty/optimization_procedures.py | 99 +++++++++++++++++++ torch_uncertainty/routines/classification.py | 1 + 5 files changed, 114 insertions(+), 8 deletions(-) diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index a7fe361d..249b8080 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -25,11 +25,11 @@ dm = CIFAR10DataModule(**vars(args)) if args.opt_temp_scaling: - args.calibration_set = dm.get_test_set + calibration_set = dm.get_test_set elif args.val_temp_scaling: - args.calibration_set = dm.get_val_set + calibration_set = dm.get_val_set else: - args.calibration_set = None + calibration_set = None if args.use_cv: list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) @@ -44,6 +44,7 @@ f"resnet{args.arch}", "cifar10", args.version ), style="cifar", + calibration_set=calibration_set, **vars(args), ) ) diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index 39c886ed..632507ba 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -6,6 +6,7 @@ from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import TinyImageNetDataModule +from torch_uncertainty.optimization_procedures import get_procedure from torch_uncertainty.utils import csv_writter @@ -38,11 +39,11 @@ def optim_tiny(model: nn.Module) -> dict: dm = TinyImageNetDataModule(**vars(args)) if args.opt_temp_scaling: - args.calibration_set = dm.get_test_set + calibration_set = dm.get_test_set elif args.val_temp_scaling: - args.calibration_set = dm.get_val_set + calibration_set = dm.get_val_set else: - args.calibration_set = None + calibration_set = None if args.use_cv: list_dm = dm.make_cross_val_splits(args.n_splits, args.train_over) @@ -53,8 +54,11 @@ def optim_tiny(model: nn.Module) -> dict: num_classes=list_dm[i].dm.num_classes, in_channels=list_dm[i].dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_tiny, + optimization_procedure=get_procedure( + f"resnet{args.arch}", "tiny-imagenet", args.version + ), style="cifar", + calibration_set=calibration_set, **vars(args), ) ) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 9e7c6c61..380e663c 
100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -135,11 +135,12 @@ def cli_main( name=net_name, default_hp_metric=False, log_graph=args.log_graph, - version=args.test, + version=f"fold_{i}", ) # callbacks save_checkpoints = ModelCheckpoint( + dirpath=tb_logger.log_dir, monitor=monitor, mode=mode, save_last=True, diff --git a/torch_uncertainty/optimization_procedures.py b/torch_uncertainty/optimization_procedures.py index 7481518a..0c66193a 100644 --- a/torch_uncertainty/optimization_procedures.py +++ b/torch_uncertainty/optimization_procedures.py @@ -19,6 +19,10 @@ "optim_imagenet_resnet50", "optim_imagenet_resnet50_A3", "optim_regression", + "optim_cifar10_resnet34", + "optim_cifar100_resnet34", + "optim_tinyimagenet_resnet34", + "optim_tinyimagenet_resnet50", ] @@ -229,6 +233,92 @@ def optim_imagenet_resnet50_A3( } +def optim_cifar10_resnet34( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[100, 150], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +def optim_cifar100_resnet34( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[100, 150], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +def optim_tinyimagenet_resnet34( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + """Optimization procedure from 'The Devil is in the Margin: Margin-based + Label Smoothing for Network Calibration', + (CVPR 2022, https://arxiv.org/abs/2111.15430): + "We train for 100 epochs with a learning rate of 0.1 for the first + 40 epochs, of 0.01 for the next 20 epochs and of 0.001 for the last + 40 epochs." + """ + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[40, 60], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + +def optim_tinyimagenet_resnet50( + model: nn.Module, +) -> Dict[str, Union[Optimizer, LRScheduler]]: + """Optimization procedure from 'The Devil is in the Margin: Margin-based + Label Smoothing for Network Calibration', + (CVPR 2022, https://arxiv.org/abs/2111.15430): + "We train for 100 epochs with a learning rate of 0.1 for the first + 40 epochs, of 0.01 for the next 20 epochs and of 0.001 for the last + 40 epochs." 
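As a standalone sanity check of the schedule quoted above (a minimal sketch; the nn.Linear model is only a stand-in for illustration), MultiStepLR with milestones [40, 60] and gamma 0.1 gives a learning rate of 0.1 for the first 40 epochs, 0.01 for the next 20, and 0.001 for the last 40.

from torch import nn, optim

model = nn.Linear(8, 2)  # stand-in module for the sketch
optimizer = optim.SGD(
    model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True
)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[40, 60], gamma=0.1
)
lrs = []
for epoch in range(100):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    optimizer.step()
    scheduler.step()
assert lrs[0] == 0.1
assert abs(lrs[50] - 0.01) < 1e-9 and abs(lrs[99] - 0.001) < 1e-9
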
+ """ + optimizer = optim.SGD( + model.parameters(), + lr=0.1, + momentum=0.9, + weight_decay=1e-4, + nesterov=True, + ) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[40, 60], + gamma=0.1, + ) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + def optim_regression( model: nn.Module, learning_rate: float = 1e-2, @@ -310,11 +400,20 @@ def get_procedure( procedure = optim_cifar100_resnet18 else: raise NotImplementedError(f"Dataset {ds_name} not implemented.") + elif arch_name == "resnet34": + if ds_name == "cifar10": + procedure = optim_cifar10_resnet34 + elif ds_name == "cifar100": + procedure = optim_cifar100_resnet34 + elif ds_name == "tiny-imagenet": + procedure = optim_tinyimagenet_resnet34 elif arch_name == "resnet50": if ds_name == "cifar10": procedure = optim_cifar10_resnet50 elif ds_name == "cifar100": procedure = optim_cifar100_resnet50 + elif ds_name == "tiny-imagenet": + procedure = optim_tinyimagenet_resnet50 elif ds_name == "imagenet": if imagenet_recipe is not None and imagenet_recipe == "A3": procedure = optim_imagenet_resnet50_A3 diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index beb7cf15..44f61f41 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -82,6 +82,7 @@ def __init__( "loss", "optimization_procedure", "format_batch_fn", + "calibration_set", ] ) From 45eea9d7585b97d2b3a153c7dc1989cc253056b7 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 23 Oct 2023 17:56:19 +0200 Subject: [PATCH 06/27] :heavy_plus_sign: Add sklearn as dependency --- poetry.lock | 97 ++++++++++++++++++++++++++++++++++++-------------- pyproject.toml | 1 + 2 files changed, 71 insertions(+), 27 deletions(-) diff --git a/poetry.lock b/poetry.lock index 6550a42f..83a4cd93 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1092,6 +1092,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "joblib" +version = "1.3.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, + {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, +] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -1280,16 +1291,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = 
"MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -2187,7 +2188,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2195,15 +2195,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2220,7 +2213,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2228,7 +2220,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -2314,36 +2305,30 @@ python-versions = ">=3.6" files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = 
"sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d92f81886165cb14d7b067ef37e142256f1c6a90a65cd156b063a43da1708cfd"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b5edda50e5e9e15e54a6a8a0070302b00c518a9d32accc2346ad6c984aacd279"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:7048c338b6c86627afb27faecf418768acb6331fc24cfa56c93e8c9780f815fa"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3fcc54cb0c8b811ff66082de1680b4b14cf8a81dce0d4fbf665c2265a81e07a1"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file 
= "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:665f58bfd29b167039f714c6998178d27ccd83984084c286110ef26b230f259f"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:9eb5dee2772b0f704ca2e45b1713e4e5198c18f515b52743576d196348f374d3"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, @@ -2469,6 +2454,53 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scikit-learn" +version = "1.3.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.8" +files = [ + {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"}, + {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"}, + {file = 
"scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"}, + {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"}, + {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"}, + {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"}, + {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"}, + {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"}, +] + +[package.dependencies] +joblib = ">=1.1.1" +numpy = ">=1.17.3,<2.0" +scipy = ">=1.5.0" +threadpoolctl = ">=2.0.0" + +[package.extras] +benchmark = ["matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "pandas (>=1.0.5)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)", "sphinx (>=6.0.0)", "sphinx-copybutton (>=0.5.2)", "sphinx-gallery (>=0.10.1)", "sphinx-prompt (>=1.3.0)", "sphinxext-opengraph (>=0.4.2)"] +examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn 
(>=0.9.0)"] +tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"] + [[package]] name = "scipy" version = "1.11.3" @@ -2815,6 +2847,17 @@ files = [ {file = "tensorboard_data_server-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594"}, ] +[[package]] +name = "threadpoolctl" +version = "3.2.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"}, + {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, +] + [[package]] name = "timm" version = "0.9.7" @@ -3228,4 +3271,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "f36e5e3f9f5a237e0ddca967fe121ef0738aa85b4ea51e0379408abfcafabd50" +content-hash = "811c606402079cae4f9c1b178449d2b6be7697d975333a5ca5124f35160202b6" diff --git a/pyproject.toml b/pyproject.toml index c6b16571..18c60476 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ torchinfo = ">=1.7.1" scipy = "^1.10.0" huggingface-hub = "^0.14.1" pandas = "^2.0.3" +scikit-learn = "^1.3.2" [tool.poetry.group.dev] optional = true From 44fdd9eb9997375f58559b95991fba352b0ee9b6 Mon Sep 17 00:00:00 2001 From: Olivier Date: Mon, 23 Oct 2023 20:23:08 +0200 Subject: [PATCH 07/27] :shirt: Fix flake8 import issue on cls.py --- torch_uncertainty/routines/classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index a48dac11..e28266a9 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -31,9 +31,9 @@ NegativeLogLikelihood, VariationRatio, ) +from ..plotting_utils import CalibrationPlot, plot_hist from ..post_processing import TemperatureScaler from ..transforms import Mixup, MixupIO, RegMixup, WarpingMixup -from ..plotting_utils import CalibrationPlot, plot_hist # fmt:on From 9ab5487f893cb65adfc1bf86e37021952e61208e Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Fri, 27 Oct 2023 11:31:29 +0200 Subject: [PATCH 08/27] fix transforms in tiny-imagenet + qol changes --- experiments/classification/cifar10/resnet.py | 1 + .../classification/tiny-imagenet/resnet.py | 5 +++- torch_uncertainty/__init__.py | 26 ++++++++++--------- .../datamodules/tiny_imagenet.py | 3 +-- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index 249b8080..b58796d1 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -62,6 +62,7 @@ f"resnet{args.arch}", "cifar10", args.version ), style="cifar", + calibration_set=calibration_set, **vars(args), ) diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index 632507ba..47a54bef 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -72,7 +72,10 @@ def optim_tiny(model: nn.Module) -> dict: 
num_classes=dm.num_classes, in_channels=dm.num_channels, loss=nn.CrossEntropyLoss, - optimization_procedure=optim_tiny, + optimization_procedure=get_procedure( + f"resnet{args.arch}", "tiny-imagenet", args.version + ), + calibration_set=calibration_set, style="cifar", **vars(args), ) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 1ed191f9..ee845ab1 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -37,6 +37,9 @@ def init_args( default=None, help="Run in test mode. Set to the checkpoint version number to test.", ) + parser.add_argument( + "--ckpt", type=int, default=None, help="The number of the checkpoint" + ) parser.add_argument( "--summary", dest="summary", @@ -67,10 +70,16 @@ def init_args( help="Name of the experiment folder", ) parser.add_argument( - "--opt_temp_scaling", action="store_true", default=False + "--opt_temp_scaling", + action="store_true", + default=False, + help="Compute optimal temperature on the test set", ) parser.add_argument( - "--val_temp_scaling", action="store_true", default=False + "--val_temp_scaling", + action="store_true", + default=False, + help="Compute temperature on the validation set", ) parser = pl.Trainer.add_argparse_args(parser) if network is not None: @@ -147,15 +156,6 @@ def cli_main( save_weights_only=not args.enable_resume, ) - if args.summary: - summary(network, input_size=list(datamodule.input_shape).insert(0, 1)) - test_values = {} - elif args.test is not None: # coverage: ignore - if args.test >= 0: - ckpt_file, _ = get_version( - root=(root / "logs" / net_name), version=args.test - ) - # Select the best model, monitor the lr and stop if NaN callbacks = [ save_checkpoints, @@ -231,7 +231,9 @@ def cli_main( elif args.test is not None: if args.test >= 0: ckpt_file, _ = get_version( - root=(root / net_name), version=args.test + root=(root / net_name), + version=args.test, + checkpoint=args.ckpt, ) test_values = trainer.test( network, datamodule=datamodule, ckpt_path=str(ckpt_file) diff --git a/torch_uncertainty/datamodules/tiny_imagenet.py b/torch_uncertainty/datamodules/tiny_imagenet.py index 1fa44f81..2149d3b3 100644 --- a/torch_uncertainty/datamodules/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/tiny_imagenet.py @@ -74,8 +74,7 @@ def __init__( self.transform_test = T.Compose( [ - T.Resize(72), - T.CenterCrop(64), + T.Resize(64), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ] From 9c29f4dfcf38307793873ef2920771698369575b Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 31 Oct 2023 11:28:47 +0100 Subject: [PATCH 09/27] :bug: Fix logs not in 'log' folder & missing argument --- torch_uncertainty/__init__.py | 2 +- torch_uncertainty/datamodules/cifar100.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 4aa794de..336339d7 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -190,7 +190,7 @@ def cli_main( else: # logger tb_logger = TensorBoardLogger( - str(root), + str(root / "logs"), name=net_name, default_hp_metric=False, log_graph=args.log_graph, diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index dc3729e5..ac13a3e1 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -71,6 +71,7 @@ def __init__( persistent_workers=persistent_workers, ) + self.evaluate_ood = evaluate_ood self.val_split = val_split self.num_dataloaders = 
num_dataloaders From de8a8c52a4899c180ce5c71aaaeb89cdf70ec8c7 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 31 Oct 2023 11:33:26 +0100 Subject: [PATCH 10/27] :fire: Remove cutmix linked test & useless light.modules --- tests/_dummies/baseline.py | 4 ++-- tests/routines/test_classification.py | 21 --------------------- torch_uncertainty/__init__.py | 4 ++-- 3 files changed, 4 insertions(+), 25 deletions(-) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index e6552e70..e43bfed1 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -17,7 +17,7 @@ from .model import dummy_model -class DummyClassificationBaseline(LightningModule): +class DummyClassificationBaseline: def __new__( cls, num_classes: int, @@ -62,7 +62,7 @@ def add_model_specific_args( return parser -class DummyRegressionBaseline(LightningModule): +class DummyRegressionBaseline: def __new__( cls, in_features: int, diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index fe5b0d2f..af547807 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -109,27 +109,6 @@ def test_cli_main_dummy_ood(self): ) cli_main(model, dm, root, "dummy", args) - with ArgvContext( - "file.py", "--evaluate_ood", "--entropy", "--cutmix", "0.5" - ): - args = init_args( - DummyClassificationBaseline, DummyClassificationDataModule - ) - - args.root = str(root / "data") - dm = DummyClassificationDataModule(**vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=DECLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - **vars(args), - ) - with pytest.raises(NotImplementedError): - cli_main(model, dm, root, "dummy", args) - def test_classification_failures(self): with pytest.raises(ValueError): ClassificationSingle( diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 336339d7..d91a6e00 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -2,7 +2,7 @@ from argparse import ArgumentParser, Namespace from collections import defaultdict from pathlib import Path -from typing import Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union import pytorch_lightning as pl import torch @@ -19,7 +19,7 @@ def init_args( - network: Optional[Type[pl.LightningModule]] = None, + network: Any = None, datamodule: Optional[Type[pl.LightningDataModule]] = None, ) -> Namespace: parser = ArgumentParser("torch-uncertainty") From ac3d307b93666fb0bc0df3a77d0099f4b2c10c52 Mon Sep 17 00:00:00 2001 From: Olivier Date: Tue, 31 Oct 2023 12:37:38 +0100 Subject: [PATCH 11/27] :bug: Fix test logs, types, lint, & format --- tests/routines/test_classification.py | 20 +++++++++---------- tests/routines/test_regression.py | 8 ++++---- torch_uncertainty/__init__.py | 10 +++++----- .../baselines/classification/resnet.py | 4 +++- .../baselines/classification/vgg.py | 4 +++- torch_uncertainty/baselines/regression/mlp.py | 4 +++- torch_uncertainty/datamodules/cifar10.py | 4 +++- torch_uncertainty/datamodules/cifar100.py | 8 ++++++-- torch_uncertainty/transforms/__init__.py | 3 +-- torch_uncertainty/transforms/mixup.py | 11 +++++----- 10 files changed, 44 insertions(+), 32 deletions(-) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 948cf215..3659463f 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -42,7 
+42,7 @@ def test_cli_main_dummy_binary(self): baseline_type="single", **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext("file.py", "--logits"): args = init_args( @@ -60,7 +60,7 @@ def test_cli_main_dummy_binary(self): baseline_type="single", **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_cli_main_dummy_ood(self): root = Path(__file__).parent.absolute().parents[0] @@ -85,7 +85,7 @@ def test_cli_main_dummy_ood(self): baseline_type="single", **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext( "file.py", @@ -107,7 +107,7 @@ def test_cli_main_dummy_ood(self): baseline_type="single", **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext( "file.py", @@ -134,7 +134,7 @@ def test_cli_main_dummy_ood(self): **vars(args), ) with pytest.raises(NotImplementedError): - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_classification_failures(self): with pytest.raises(ValueError): @@ -174,7 +174,7 @@ def test_cli_main_dummy_binary(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext("file.py", "--mutual_information"): args = init_args( @@ -194,7 +194,7 @@ def test_cli_main_dummy_binary(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_cli_main_dummy_ood(self): root = Path(__file__).parent.absolute().parents[0] @@ -216,7 +216,7 @@ def test_cli_main_dummy_ood(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext("file.py", "--evaluate_ood", "--entropy"): args = init_args( @@ -236,7 +236,7 @@ def test_cli_main_dummy_ood(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) with ArgvContext("file.py", "--evaluate_ood", "--variation_ratio"): args = init_args( @@ -256,7 +256,7 @@ def test_cli_main_dummy_ood(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_classification_failures(self): with pytest.raises(ValueError): diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index f67e4507..12f9dbb9 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -36,7 +36,7 @@ def test_cli_main_dummy_dist(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_cli_main_dummy_dist_der(self): root = Path(__file__).parent.absolute().parents[0] @@ -62,7 +62,7 @@ def test_cli_main_dummy_dist_der(self): **vars(args), ) - cli_main(model, dm, root, "dummy_der", args) + cli_main(model, dm, root, "logs/dummy_der", args) def test_cli_main_dummy(self): root = Path(__file__).parent.absolute().parents[0] @@ -82,7 +82,7 @@ def test_cli_main_dummy(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) def test_regression_failures(self): with pytest.raises(ValueError): @@ -134,4 +134,4 @@ def test_cli_main_dummy(self): **vars(args), ) - cli_main(model, dm, root, "dummy", args) + cli_main(model, dm, root, "logs/dummy", args) diff --git a/torch_uncertainty/__init__.py 
b/torch_uncertainty/__init__.py index f2d9100d..ad315821 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -2,7 +2,7 @@ from argparse import ArgumentParser, Namespace from collections import defaultdict from pathlib import Path -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import pytorch_lightning as pl import torch @@ -90,12 +90,12 @@ def init_args( def cli_main( - network: pl.LightningModule | list[pl.LightningModule], - datamodule: AbstractDataModule | list[AbstractDataModule], + network: pl.LightningModule | List[pl.LightningModule], + datamodule: AbstractDataModule | List[AbstractDataModule], root: Union[Path, str], net_name: str, args: Namespace, -) -> list[Dict]: +) -> List[Dict]: if isinstance(root, str): root = Path(root) @@ -190,7 +190,7 @@ def cli_main( else: # logger tb_logger = TensorBoardLogger( - str(root / "logs"), + str(root), name=net_name, default_hp_metric=False, log_graph=args.log_graph, diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index e283f983..4921ea91 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -292,7 +292,9 @@ def load_from_checkpoint( elif extension.lower() in ("yml", "yaml"): hparams = load_hparams_from_yaml(hparams_file) else: - raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + raise ValueError( + ".csv, .yml or .yaml is required for `hparams_file`" + ) hparams.update(kwargs) checkpoint = torch.load(checkpoint_path) diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index d6a30099..af49ccb0 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -198,7 +198,9 @@ def load_from_checkpoint( elif extension.lower() in ("yml", "yaml"): hparams = load_hparams_from_yaml(hparams_file) else: - raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + raise ValueError( + ".csv, .yml or .yaml is required for `hparams_file`" + ) hparams.update(kwargs) checkpoint = torch.load(checkpoint_path) diff --git a/torch_uncertainty/baselines/regression/mlp.py b/torch_uncertainty/baselines/regression/mlp.py index 20a9fd33..6bda52d4 100644 --- a/torch_uncertainty/baselines/regression/mlp.py +++ b/torch_uncertainty/baselines/regression/mlp.py @@ -89,7 +89,9 @@ def load_from_checkpoint( elif extension.lower() in ("yml", "yaml"): hparams = load_hparams_from_yaml(hparams_file) else: - raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + raise ValueError( + ".csv, .yml or .yaml is required for `hparams_file`" + ) hparams.update(kwargs) checkpoint = torch.load(checkpoint_path) diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index 9d9a8b43..ab7e2091 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -233,6 +233,8 @@ def add_argparse_args( p.add_argument("--cutout", type=int, default=0) p.add_argument("--auto_augment", type=str) p.add_argument("--test_alt", choices=["c", "h"], default=None) - p.add_argument("--severity", dest="corruption_severity", type=int, default=None) + p.add_argument( + "--severity", dest="corruption_severity", type=int, default=None + ) p.add_argument("--evaluate_ood", action="store_true") return parent_parser diff --git 
a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index ac13a3e1..c1869836 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -86,7 +86,9 @@ def __init__( self.corruption_severity = corruption_severity - if (cutout is not None) + randaugment + int(auto_augment is not None) > 1: + if (cutout is not None) + randaugment + int( + auto_augment is not None + ) > 1: raise ValueError( "Only one data augmentation can be chosen at a time. Raise a " "GitHub issue if needed." @@ -232,6 +234,8 @@ def add_argparse_args( p.add_argument("--randaugment", dest="randaugment", action="store_true") p.add_argument("--auto_augment", type=str) p.add_argument("--test_alt", choices=["c"], default=None) - p.add_argument("--severity", dest="corruption_severity", type=int, default=1) + p.add_argument( + "--severity", dest="corruption_severity", type=int, default=1 + ) p.add_argument("--evaluate_ood", action="store_true") return parent_parser diff --git a/torch_uncertainty/transforms/__init__.py b/torch_uncertainty/transforms/__init__.py index 511d3d21..a55bdad6 100644 --- a/torch_uncertainty/transforms/__init__.py +++ b/torch_uncertainty/transforms/__init__.py @@ -1,5 +1,6 @@ # ruff: noqa: F401 from .cutout import Cutout +from .mixup import Mixup, MixupIO, RegMixup, WarpingMixup from .transforms import ( AutoContrast, Brightness, @@ -29,5 +30,3 @@ Color, Sharpness, ] - -from .mixup import Mixup, MixupIO, RegMixup, WarpingMixup diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index d5fc165b..81b52265 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -1,3 +1,4 @@ +from typing import Tuple import scipy import torch import torch.nn.functional as F @@ -131,14 +132,14 @@ def _mix_target( def __call__( self, x: torch.Tensor, y: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: return x, y class Mixup(AbstractMixup): def __call__( self, x: torch.Tensor, y: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: lam, index = self._get_params(x.size()[0], x.device) mixed_x = self._linear_mixing(lam, x, index) @@ -151,7 +152,7 @@ def __call__( class MixupIO(AbstractMixup): def __call__( self, x: torch.Tensor, y: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: lam, index = self._get_params(x.size()[0], x.device) mixed_x = self._linear_mixing(lam, x, index) @@ -164,7 +165,7 @@ def __call__( class RegMixup(AbstractMixup): def __call__( self, x: torch.Tensor, y: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: lam, index = self._get_params(x.size()[0], x.device) part_x = self._linear_mixing(lam, x, index) @@ -208,7 +209,7 @@ def __call__( y: torch.Tensor, feats: torch.Tensor, warp_param=1.0, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: lam, index = self._get_params(x.size()[0], x.device) if self.apply_kernel: From 210105be2438e5faf94df6d048c4a4142d3b5f10 Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Tue, 31 Oct 2023 17:37:12 +0100 Subject: [PATCH 12/27] fix cross-val for c10 / c100 + test on mixup --- experiments/classification/cifar10/resnet.py | 4 +- .../classification/tiny-imagenet/resnet.py | 4 +- tests/transforms/test_mixup.py | 70 ++++++++++ torch_uncertainty/__init__.py | 15 ++- torch_uncertainty/datamodules/cifar10.py 
| 10 +- torch_uncertainty/datamodules/cifar100.py | 10 +- torch_uncertainty/transforms/mixup.py | 125 +++++++++--------- torch_uncertainty/utils/__init__.py | 2 +- torch_uncertainty/utils/misc.py | 2 +- 9 files changed, 163 insertions(+), 79 deletions(-) create mode 100644 tests/transforms/test_mixup.py diff --git a/experiments/classification/cifar10/resnet.py b/experiments/classification/cifar10/resnet.py index 99d74a00..1540c0bf 100644 --- a/experiments/classification/cifar10/resnet.py +++ b/experiments/classification/cifar10/resnet.py @@ -6,7 +6,7 @@ from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import CIFAR10DataModule from torch_uncertainty.optimization_procedures import get_procedure -from torch_uncertainty.utils import csv_writter +from torch_uncertainty.utils import csv_writer if __name__ == "__main__": args = init_args(ResNet, CIFAR10DataModule) @@ -67,7 +67,7 @@ results = cli_main(model, dm, args.exp_dir, args.exp_name, args) for dict_result in results: - csv_writter( + csv_writer( Path(args.exp_dir) / Path(args.exp_name) / "results.csv", dict_result, ) diff --git a/experiments/classification/tiny-imagenet/resnet.py b/experiments/classification/tiny-imagenet/resnet.py index 8825f26d..2503b3da 100644 --- a/experiments/classification/tiny-imagenet/resnet.py +++ b/experiments/classification/tiny-imagenet/resnet.py @@ -6,7 +6,7 @@ from torch_uncertainty.baselines import ResNet from torch_uncertainty.datamodules import TinyImageNetDataModule from torch_uncertainty.optimization_procedures import get_procedure -from torch_uncertainty.utils import csv_writter +from torch_uncertainty.utils import csv_writer def optim_tiny(model: nn.Module) -> dict: @@ -81,7 +81,7 @@ def optim_tiny(model: nn.Module) -> dict: results = cli_main(model, dm, args.exp_dir, args.exp_name, args) for dict_result in results: - csv_writter( + csv_writer( Path(args.exp_dir) / Path(args.exp_name) / "results.csv", dict_result, ) diff --git a/tests/transforms/test_mixup.py b/tests/transforms/test_mixup.py new file mode 100644 index 00000000..7873f52f --- /dev/null +++ b/tests/transforms/test_mixup.py @@ -0,0 +1,70 @@ +from typing import Tuple + +import pytest +import torch + +from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup + + +@pytest.fixture +def batch_input() -> Tuple[torch.Tensor, torch.Tensor]: + imgs = torch.rand(2, 3, 28, 28) + return imgs, torch.tensor([0, 1]) + + +class TestMixup: + """Testing Mixup augmentation""" + + def test_batch_mixup(self, batch_input): + mixup = Mixup(alpha=1.0, mode="batch", num_classes=2) + _ = mixup(*batch_input) + + def test_elem_mixup(self, batch_input): + mixup = Mixup(alpha=1.0, mode="elem", num_classes=2) + _ = mixup(*batch_input) + + +class TestMixupIO: + """Testing MixupIO augmentation""" + + def test_batch_mixupio(self, batch_input): + mixup = MixupIO(alpha=1.0, mode="batch", num_classes=2) + _ = mixup(*batch_input) + + def test_elem_mixupio(self, batch_input): + mixup = MixupIO(alpha=1.0, mode="elem", num_classes=2) + _ = mixup(*batch_input) + + +class TestRegMixup: + """Testing RegMixup augmentation""" + + def test_batch_regmixup(self, batch_input): + mixup = RegMixup(alpha=1.0, mode="batch", num_classes=2) + _ = mixup(*batch_input) + + def test_elem_regmixup(self, batch_input): + mixup = RegMixup(alpha=1.0, mode="elem", num_classes=2) + _ = mixup(*batch_input) + + +class TestWarpingMixup: + """Testing WarpingMixup augmentation""" + + def test_batch_kernel_warpingmixup(self, batch_input): + mixup = 
WarpingMixup( + alpha=1.0, mode="batch", num_classes=2, apply_kernel=True + ) + _ = mixup(*batch_input, batch_input[0]) + + def test_elem_kernel_warpingmixup(self, batch_input): + mixup = WarpingMixup( + alpha=1.0, mode="elem", num_classes=2, apply_kernel=True + ) + _ = mixup(*batch_input, batch_input[0]) + + def test_elem_warpingmixup(self, batch_input): + mixup = WarpingMixup( + alpha=1.0, mode="elem", num_classes=2, apply_kernel=False + ) + _ = mixup(*batch_input, batch_input[0]) diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index ad315821..9b7874fb 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -172,10 +172,17 @@ def cli_main( args.opt_temp_scaling or args.val_temp_scaling ), ) - trainer.fit(network[i], datamodule[i]) - test_values.append( - trainer.test(datamodule=datamodule[i], ckpt_path="last")[0] - ) + if args.summary: + summary( + network[i], + input_size=list(datamodule[i].dm.input_shape).insert(0, 1), + ) + test_values.append({}) + else: + trainer.fit(network[i], datamodule[i]) + test_values.append( + trainer.test(datamodule=datamodule[i], ckpt_path="last")[0] + ) all_test_values = defaultdict(list) for test_value in test_values: diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index ab7e2091..09c978b2 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -216,10 +216,16 @@ def test_dataloader(self) -> List[DataLoader]: return dataloader def _get_train_data(self) -> ArrayLike: - return self.train.dataset.data[self.train.indices] + if self.val_split: + return self.train.dataset.data[self.train.indices] + else: + return self.train.data def _get_train_targets(self) -> ArrayLike: - return np.array(self.train.dataset.targets)[self.train.indices] + if self.val_split: + return np.array(self.train.dataset.targets)[self.train.indices] + else: + return np.array(self.train.targets) @classmethod def add_argparse_args( diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index c1869836..b8314976 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -216,10 +216,16 @@ def test_dataloader(self) -> List[DataLoader]: return dataloader def _get_train_data(self) -> ArrayLike: - return self.train.dataset.data[self.train.indices] + if self.val_split: + return self.train.dataset.data[self.train.indices] + else: + return self.train.data def _get_train_targets(self) -> ArrayLike: - return np.array(self.train.dataset.targets)[self.train.indices] + if self.val_split: + return np.array(self.train.dataset.targets)[self.train.indices] + else: + return np.array(self.train.targets) @classmethod def add_argparse_args( diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index 81b52265..5cb0e82c 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -2,14 +2,9 @@ import scipy import torch import torch.nn.functional as F -from torch import Tensor import numpy as np -# TODO: torch beta warping (with tensor linspace + approx beta cdf using trapz) -# TODO: Mixup with roll to be more efficient (remove sampling of index) -# TODO: MIT and Rank Mixup - def beta_warping(x, alpha_cdf=1.0, eps=1e-12): return scipy.stats.beta.cdf(x, a=alpha_cdf + eps, b=alpha_cdf + eps) @@ -22,66 +17,66 @@ def sim_gauss_kernel(dist, tau_max=1.0, tau_std=0.5): return 1 / (dist_rate + 1e-12) -def 
tensor_linspace(start: Tensor, stop: Tensor, num: int): - """ - Creates a tensor of shape [num, *start.shape] whose values are evenly - spaced from start to end, inclusive. - Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. - """ - # create a tensor of 'num' steps from 0 to 1 - steps = torch.arange(num, dtype=torch.float32, device=start.device) / ( - num - 1 - ) - - # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] - # to allow for broadcastings - # using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here - # but torchscript - # "cannot statically infer the expected size of a list in this contex", - # hence the code below - for i in range(start.ndim): - steps = steps.unsqueeze(-1) - - # the output starts at 'start' and increments until 'stop' in each dimension - out = start[None] + steps * (stop - start)[None] - - return out - - -def torch_beta_cdf( - x: Tensor, c1: Tensor | float, c2: Tensor | float, npts=100, eps=1e-12 -): - if isinstance(c1, float): - if c1 == c2: - c1 = torch.tensor([c1], device=x.device) - c2 = c1 - else: - c1 = torch.tensor([c1], device=x.device) - if isinstance(c2, float): - c2 = torch.tensor([c2], device=x.device) - bt = torch.distributions.Beta(c1, c2) - - if isinstance(x, float): - x = torch.tensor(x) - - X = tensor_linspace(torch.zeros_like(x) + eps, x, npts) - return torch.trapezoid(bt.log_prob(X).exp(), X, dim=0) - - -def torch_beta_warping( - x: Tensor, alpha_cdf: float | Tensor = 1.0, eps=1e-12, npts=100 -): - return torch_beta_cdf( - x=x, c1=alpha_cdf + eps, c2=alpha_cdf + eps, npts=npts, eps=eps - ) - - -def torch_sim_gauss_kernel(dist: Tensor, tau_max=1.0, tau_std=0.5): - dist_rate = tau_max * torch.exp( - -(dist - 1) / (torch.mean(dist) * 2 * tau_std * tau_std) - ) - - return 1 / (dist_rate + 1e-12) +# def tensor_linspace(start: Tensor, stop: Tensor, num: int): +# """ +# Creates a tensor of shape [num, *start.shape] whose values are evenly +# spaced from start to end, inclusive. +# Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. 
+# """ +# # create a tensor of 'num' steps from 0 to 1 +# steps = torch.arange(num, dtype=torch.float32, device=start.device) / ( +# num - 1 +# ) + +# # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] +# # to allow for broadcastings +# # using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here +# # but torchscript +# # "cannot statically infer the expected size of a list in this contex", +# # hence the code below +# for i in range(start.ndim): +# steps = steps.unsqueeze(-1) + +# # the output starts at 'start' and increments until 'stop' in each dimension +# out = start[None] + steps * (stop - start)[None] + +# return out + + +# def torch_beta_cdf( +# x: Tensor, c1: Tensor | float, c2: Tensor | float, npts=100, eps=1e-12 +# ): +# if isinstance(c1, float): +# if c1 == c2: +# c1 = torch.tensor([c1], device=x.device) +# c2 = c1 +# else: +# c1 = torch.tensor([c1], device=x.device) +# if isinstance(c2, float): +# c2 = torch.tensor([c2], device=x.device) +# bt = torch.distributions.Beta(c1, c2) + +# if isinstance(x, float): +# x = torch.tensor(x) + +# X = tensor_linspace(torch.zeros_like(x) + eps, x, npts) +# return torch.trapezoid(bt.log_prob(X).exp(), X, dim=0) + + +# def torch_beta_warping( +# x: Tensor, alpha_cdf: float | Tensor = 1.0, eps=1e-12, npts=100 +# ): +# return torch_beta_cdf( +# x=x, c1=alpha_cdf + eps, c2=alpha_cdf + eps, npts=npts, eps=eps +# ) + + +# def torch_sim_gauss_kernel(dist: Tensor, tau_max=1.0, tau_std=0.5): +# dist_rate = tau_max * torch.exp( +# -(dist - 1) / (torch.mean(dist) * 2 * tau_std * tau_std) +# ) + +# return 1 / (dist_rate + 1e-12) class AbstractMixup: diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index ed7334eb..e6b70312 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -1,4 +1,4 @@ # ruff: noqa: F401 from .checkpoints import get_version from .hub import load_hf -from .misc import csv_writter +from .misc import csv_writer diff --git a/torch_uncertainty/utils/misc.py b/torch_uncertainty/utils/misc.py index ccf15055..19feab91 100644 --- a/torch_uncertainty/utils/misc.py +++ b/torch_uncertainty/utils/misc.py @@ -1,7 +1,7 @@ import csv -def csv_writter(path, dic): +def csv_writer(path, dic): # Check if the file already exists if path.is_file(): append_mode = True From 75c28169daf7a5a85cfa187703d0bd0382b269ca Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Tue, 31 Oct 2023 20:44:48 +0100 Subject: [PATCH 13/27] add tests crossval --- tests/_dummies/dataset.py | 3 ++ tests/datamodules/test_cifar100_datamodule.py | 17 ++++++++ tests/datamodules/test_cifar10_datamodule.py | 17 ++++++++ .../test_tiny_imagenet_datamodule.py | 19 +++++++- tests/test_cli.py | 43 +++++++++++++++++++ 5 files changed, 98 insertions(+), 1 deletion(-) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 4f5c4286..3ba358b5 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -70,6 +70,9 @@ def __init__( num_images // (num_classes) + 1 )[:num_images] + self.samples = self.data # for compatibility with TinyImagenet + self.label_data = self.targets + def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: diff --git a/tests/datamodules/test_cifar100_datamodule.py b/tests/datamodules/test_cifar100_datamodule.py index 80e44fba..11156ff5 100644 --- a/tests/datamodules/test_cifar100_datamodule.py +++ b/tests/datamodules/test_cifar100_datamodule.py @@ -77,3 +77,20 @@ def test_cifar100(self): args.auto_augment = "rand-m9-n2-mstd0.5" dm = 
CIFAR100DataModule(**vars(args)) + + def test_cifar100_cv(self): + parser = ArgumentParser() + parser = CIFAR100DataModule.add_argparse_args(parser) + + # Simulate that cutout is set to 8 + args = parser.parse_args("") + + dm = CIFAR100DataModule(**vars(args)) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) + dm.make_cross_val_splits(2,1) + + + args.val_split = 0.1 + dm = CIFAR100DataModule(**vars(args)) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) + dm.make_cross_val_splits(2,1) diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/test_cifar10_datamodule.py index fef639b9..20dc0644 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/test_cifar10_datamodule.py @@ -67,3 +67,20 @@ def test_CIFAR10_cutout(self): args.cutout = None args.auto_augment = "rand-m9-n2-mstd0.5" dm = CIFAR10DataModule(**vars(args)) + + def test_cifar10_cv(self): + parser = ArgumentParser() + parser = CIFAR10DataModule.add_argparse_args(parser) + + # Simulate that cutout is set to 8 + args = parser.parse_args("") + + dm = CIFAR10DataModule(**vars(args)) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) + dm.make_cross_val_splits(2,1) + + + args.val_split = 0.1 + dm = CIFAR10DataModule(**vars(args)) + dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) + dm.make_cross_val_splits(2,1) diff --git a/tests/datamodules/test_tiny_imagenet_datamodule.py b/tests/datamodules/test_tiny_imagenet_datamodule.py index f7d4a865..bfafea72 100644 --- a/tests/datamodules/test_tiny_imagenet_datamodule.py +++ b/tests/datamodules/test_tiny_imagenet_datamodule.py @@ -13,7 +13,7 @@ class TestTinyImageNetDataModule: """Testing the TinyImageNetDataModule datamodule class.""" - def test_imagenet(self): + def test_tiny_imagenet(self): parser = ArgumentParser() parser = TinyImageNetDataModule.add_argparse_args(parser) @@ -48,3 +48,20 @@ def test_imagenet(self): dm.prepare_data() dm.setup("test") dm.test_dataloader() + + def test_tiny_imagenet_cv(self): + parser = ArgumentParser() + parser = TinyImageNetDataModule.add_argparse_args(parser) + + # Simulate that cutout is set to 8 + args = parser.parse_args("") + + dm = TinyImageNetDataModule(**vars(args)) + dm.dataset = lambda root, split, transform: DummyClassificationDataset(root, split=split, transform=transform, num_images=20) + dm.make_cross_val_splits(2,1) + + + args.val_split = 0.1 + dm = TinyImageNetDataModule(**vars(args)) + dm.dataset = lambda root, split, transform: DummyClassificationDataset(root, split=split, transform=transform, num_images=20) + dm.make_cross_val_splits(2,1) diff --git a/tests/test_cli.py b/tests/test_cli.py index 73d6d74f..18b4f5a1 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -16,6 +16,8 @@ optim_regression, ) +from ._dummies.dataset import DummyClassificationDataset + class TestCLI: """Testing the CLI function.""" @@ -163,6 +165,47 @@ def test_cli_other_training_task(self): with pytest.raises(ValueError): cli_main(model, dm, root, "std", args) + def test_cli_cv_ts(self): + root = Path(__file__).parent.absolute().parents[0] + with ArgvContext("file.py", "--use_cv"): + args = 
init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + def test_init_args_void(self): with ArgvContext("file.py"): init_args() From 57c07c27e94138beef53757e573be73011574794 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 1 Nov 2023 12:36:59 +0100 Subject: [PATCH 14/27] :white_check_mark: Add some tests --- tests/datamodules/test_abstract_datamodule.py | 35 ++++ tests/datamodules/test_cifar10_datamodule.py | 7 +- tests/test_cli.py | 158 ++++++++++++++++++ tests/test_optimization_procedures.py | 25 ++- torch_uncertainty/datamodules/abstract.py | 32 ++-- 5 files changed, 231 insertions(+), 26 deletions(-) create mode 100644 tests/datamodules/test_abstract_datamodule.py diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py new file mode 100644 index 00000000..62315bd7 --- /dev/null +++ b/tests/datamodules/test_abstract_datamodule.py @@ -0,0 +1,35 @@ +import pytest + +from torch_uncertainty.datamodules.abstract import ( + AbstractDataModule, + CrossValDataModule, +) + +from .._dummies.dataset import DummyClassificationDataset + + +class TestAbstractDataModule: + """Testing the AbstractDataModule class.""" + + def test_errors(self): + dm = AbstractDataModule("root", 128, 4, True, True) + with pytest.raises(NotImplementedError): + dm.setup() + dm._get_train_data() + dm._get_train_targets() + + +class TestCrossValDataModule: + """Testing the CrossValDataModule class.""" + + def test_errors(self): + dm = AbstractDataModule("root", 128, 4, True, True) + ds = DummyClassificationDataset("root") + dm.train = ds + dm.val = ds + dm.test = ds + cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True) + with pytest.raises(NotImplementedError): + cv_dm.setup() + cv_dm._get_train_data() + cv_dm._get_train_targets() diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/test_cifar10_datamodule.py index 20dc0644..a5a10dd8 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/test_cifar10_datamodule.py @@ -14,7 +14,7 @@ class TestCIFAR10DataModule: """Testing the CIFAR10DataModule datamodule class.""" - def test_CIFAR10_cutout(self): + def test_CIFAR10_main(self): parser = ArgumentParser() parser = CIFAR10DataModule.add_argparse_args(parser) @@ -34,6 +34,11 @@ def test_CIFAR10_cutout(self): dm.setup() dm.setup("test") + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() diff --git a/tests/test_cli.py b/tests/test_cli.py index 18b4f5a1..185007b2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -206,6 +206,164 @@ def test_cli_cv_ts(self): cli_main(list_model, list_dm, root, "std", args) + with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup"): + args = init_args(ResNet, 
CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + + with ArgvContext("file.py", "--use_cv", "--mixtype", "mixup_io"): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + + with ArgvContext("file.py", "--use_cv", "--mixtype", "regmixup"): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + + with ArgvContext( + "file.py", "--use_cv", "--mixtype", "kernel_warping" + ): + args = init_args(ResNet, CIFAR10DataModule) + + # datamodule + args.root = str(root / "data") + dm = CIFAR10DataModule(**vars(args)) + + # Simulate that summary is True & the only argument + args.summary = True + + dm.dataset = ( + lambda root, + train, + download, + transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + ResNet( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + style="cifar", + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "std", args) + def test_init_args_void(self): with ArgvContext("file.py"): init_args() diff --git a/tests/test_optimization_procedures.py b/tests/test_optimization_procedures.py index 8fb63011..3ecd3506 100644 --- 
a/tests/test_optimization_procedures.py +++ b/tests/test_optimization_procedures.py @@ -1,7 +1,7 @@ # ruff: noqa: F401 import pytest -from torch_uncertainty.models.resnet import resnet18, resnet50 +from torch_uncertainty.models.resnet import resnet18, resnet34, resnet50 from torch_uncertainty.models.vgg import vgg16 from torch_uncertainty.models.wideresnet import wideresnet28x10 from torch_uncertainty.optimization_procedures import ( @@ -11,46 +11,53 @@ class TestOptProcedures: - def test_optim_cifar10_resnet18(self): + def test_optim_cifar10(self): procedure = get_procedure("resnet18", "cifar10", "standard") model = resnet18(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar10_resnet50(self): procedure = get_procedure("resnet50", "cifar10", "packed") model = resnet50(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar10_wideresnet(self): procedure = get_procedure("wideresnet28x10", "cifar10", "batched") model = wideresnet28x10(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar10_vgg16(self): procedure = get_procedure("vgg16", "cifar10", "standard") model = vgg16(in_channels=3, num_classes=10) procedure(model) - def test_optim_cifar100_resnet18(self): + def test_optim_cifar100(self): procedure = get_procedure("resnet18", "cifar100", "masked") model = resnet18(in_channels=3, num_classes=100) procedure(model) - def test_optim_cifar100_resnet50(self): + procedure = get_procedure("resnet34", "cifar100", "masked") + model = resnet34(in_channels=3, num_classes=100) + procedure(model) + procedure = get_procedure("resnet50", "cifar100") model = resnet50(in_channels=3, num_classes=100) procedure(model) - def test_optim_cifar100_wideresnet(self): procedure = get_procedure("wideresnet28x10", "cifar100") model = wideresnet28x10(in_channels=3, num_classes=100) procedure(model) - def test_optim_cifar100_vgg16(self): procedure = get_procedure("vgg16", "cifar100", "standard") model = vgg16(in_channels=3, num_classes=100) procedure(model) + def test_optim_tinyimagenet(self): + procedure = get_procedure("resnet34", "tiny-imagenet", "standard") + model = resnet34(in_channels=3, num_classes=1000) + procedure(model) + + procedure = get_procedure("resnet50", "tiny-imagenet", "standard") + model = resnet50(in_channels=3, num_classes=1000) + procedure(model) + def test_optim_imagenet_resnet50(self): procedure = get_procedure("resnet50", "imagenet", "standard", "A3") model = resnet50(in_channels=3, num_classes=1000) diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 1fa2175a..98356f23 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -11,7 +11,7 @@ class AbstractDataModule(LightningDataModule): - training_task = "" + training_task: str def __init__( self, @@ -34,9 +34,7 @@ def __init__( self.persistent_workers = persistent_workers def setup(self, stage: Optional[str] = None) -> None: - self.train = Dataset() - self.val = Dataset() - self.test = Dataset() + raise NotImplementedError() def get_train_set(self) -> Dataset: return self.train @@ -100,12 +98,14 @@ def _data_loader( # It is generally "Dataset.samples" or "Dataset.data" # They are used for constructing cross validation splits def _get_train_data(self) -> ArrayLike: - pass + raise NotImplementedError() def _get_train_targets(self) -> ArrayLike: - pass + raise NotImplementedError() - def make_cross_val_splits(self, n_splits=10, train_over=4) -> list: + def make_cross_val_splits( + self, 
n_splits: int = 10, train_over: int = 4 + ) -> List: self.setup("fit") skf = StratifiedKFold(n_splits) cv_dm = [] @@ -192,15 +192,6 @@ def _data_loader(self, dataset: Dataset, idx: ArrayLike) -> DataLoader: persistent_workers=self.persistent_workers, ) - def train_dataloader(self) -> DataLoader: - return self._data_loader(self.dm.get_train_set(), self.train_idx) - - def val_dataloader(self) -> DataLoader: - return self._data_loader(self.dm.get_train_set(), self.val_idx) - - def test_dataloader(self) -> DataLoader: - return self._data_loader(self.dm.get_train_set(), self.val_idx) - def get_train_set(self) -> Dataset: return self.dm.train @@ -209,3 +200,12 @@ def get_test_set(self) -> Dataset: def get_val_set(self) -> Dataset: return self.dm.val + + def train_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.train_idx) + + def val_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.val_idx) + + def test_dataloader(self) -> DataLoader: + return self._data_loader(self.dm.get_train_set(), self.val_idx) From 8321322f7f16327de3a6ccda0b2da6c80a7ebe71 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 1 Nov 2023 12:40:46 +0100 Subject: [PATCH 15/27] :bug: Fix format error --- tests/_dummies/dataset.py | 2 +- tests/datamodules/test_cifar100_datamodule.py | 25 +++++++++++++++---- tests/datamodules/test_cifar10_datamodule.py | 25 +++++++++++++++---- .../test_tiny_imagenet_datamodule.py | 13 ++++++---- 4 files changed, 49 insertions(+), 16 deletions(-) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index a2029af9..44db5dfc 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -69,7 +69,7 @@ def __init__( num_images // (num_classes) + 1 )[:num_images] - self.samples = self.data # for compatibility with TinyImagenet + self.samples = self.data # for compatibility with TinyImagenet self.label_data = self.targets def __getitem__(self, index: int) -> Tuple[Any, Any]: diff --git a/tests/datamodules/test_cifar100_datamodule.py b/tests/datamodules/test_cifar100_datamodule.py index 4b208d55..0fa48f0c 100644 --- a/tests/datamodules/test_cifar100_datamodule.py +++ b/tests/datamodules/test_cifar100_datamodule.py @@ -84,11 +84,26 @@ def test_cifar100_cv(self): args = parser.parse_args("") dm = CIFAR100DataModule(**vars(args)) - dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) - dm.make_cross_val_splits(2,1) - + dm.dataset = ( + lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + dm.make_cross_val_splits(2, 1) args.val_split = 0.1 dm = CIFAR100DataModule(**vars(args)) - dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) - dm.make_cross_val_splits(2,1) + dm.dataset = ( + lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + dm.make_cross_val_splits(2, 1) diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/test_cifar10_datamodule.py index a5f4d2a9..62a5c336 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/test_cifar10_datamodule.py @@ -79,11 +79,26 @@ def test_cifar10_cv(self): args = parser.parse_args("") dm = 
CIFAR10DataModule(**vars(args)) - dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) - dm.make_cross_val_splits(2,1) - + dm.dataset = ( + lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + dm.make_cross_val_splits(2, 1) args.val_split = 0.1 dm = CIFAR10DataModule(**vars(args)) - dm.dataset = lambda root, train, download, transform: DummyClassificationDataset(root, train=train, download=download, transform=transform, num_images=20) - dm.make_cross_val_splits(2,1) + dm.dataset = ( + lambda root, train, download, transform: DummyClassificationDataset( + root, + train=train, + download=download, + transform=transform, + num_images=20, + ) + ) + dm.make_cross_val_splits(2, 1) diff --git a/tests/datamodules/test_tiny_imagenet_datamodule.py b/tests/datamodules/test_tiny_imagenet_datamodule.py index 2306b6e7..aa36fdbf 100644 --- a/tests/datamodules/test_tiny_imagenet_datamodule.py +++ b/tests/datamodules/test_tiny_imagenet_datamodule.py @@ -55,11 +55,14 @@ def test_tiny_imagenet_cv(self): args = parser.parse_args("") dm = TinyImageNetDataModule(**vars(args)) - dm.dataset = lambda root, split, transform: DummyClassificationDataset(root, split=split, transform=transform, num_images=20) - dm.make_cross_val_splits(2,1) - + dm.dataset = lambda root, split, transform: DummyClassificationDataset( + root, split=split, transform=transform, num_images=20 + ) + dm.make_cross_val_splits(2, 1) args.val_split = 0.1 dm = TinyImageNetDataModule(**vars(args)) - dm.dataset = lambda root, split, transform: DummyClassificationDataset(root, split=split, transform=transform, num_images=20) - dm.make_cross_val_splits(2,1) + dm.dataset = lambda root, split, transform: DummyClassificationDataset( + root, split=split, transform=transform, num_images=20 + ) + dm.make_cross_val_splits(2, 1) From 2de5c243e6ba2b6e87f7b79051e3d2c2c7c9dbc4 Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 1 Nov 2023 13:10:58 +0100 Subject: [PATCH 16/27] :white_check_mark: Improve cvg., fix bug, and sort imports --- tests/_dummies/dataset.py | 3 +- tests/routines/test_regression.py | 4 +- tests/test_cli.py | 8 +- tests/test_losses.py | 2 +- tests/test_optimization_procedures.py | 4 + tests/transforms/test_transforms.py | 2 +- torch_uncertainty/__init__.py | 3 +- torch_uncertainty/datamodules/abstract.py | 3 +- torch_uncertainty/datamodules/cifar10.py | 5 +- torch_uncertainty/datamodules/cifar100.py | 5 +- .../datamodules/tiny_imagenet.py | 3 +- .../datasets/classification/cifar/cifar_c.py | 3 +- .../datasets/classification/cifar/cifar_h.py | 3 +- .../classification/imagenet/tiny_imagenet.py | 3 +- .../datasets/classification/mnist_c.py | 3 +- torch_uncertainty/layers/bayesian/sampler.py | 3 +- torch_uncertainty/layers/masksembles.py | 3 +- torch_uncertainty/metrics/fpr95.py | 5 +- torch_uncertainty/models/wideresnet/packed.py | 1 - torch_uncertainty/transforms/cutout.py | 3 +- torch_uncertainty/transforms/mixup.py | 96 ++++++++----------- torch_uncertainty/transforms/pixmix.py | 2 +- torch_uncertainty/transforms/transforms.py | 3 +- 23 files changed, 73 insertions(+), 97 deletions(-) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 44db5dfc..fa26669a 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -1,12 +1,11 @@ from pathlib import Path from typing import Any, Callable, Tuple +import 
numpy as np import torch import torch.utils.data as data from PIL import Image -import numpy as np - class DummyClassificationDataset(data.Dataset): def __init__( diff --git a/tests/routines/test_regression.py b/tests/routines/test_regression.py index 49d9120e..64c6e354 100644 --- a/tests/routines/test_regression.py +++ b/tests/routines/test_regression.py @@ -6,7 +6,7 @@ from torch import nn from torch_uncertainty import cli_main, init_args -from torch_uncertainty.losses import NIGLoss, BetaNLL +from torch_uncertainty.losses import BetaNLL, NIGLoss from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18 from .._dummies import DummyRegressionBaseline, DummyRegressionDataModule @@ -86,7 +86,7 @@ def test_cli_main_dummy_dist_betanll(self): **vars(args), ) - cli_main(model, dm, root, "dummy_betanll", args) + cli_main(model, dm, root, "logs/dummy_betanll", args) def test_cli_main_dummy(self): root = Path(__file__).parent.absolute().parents[0] diff --git a/tests/test_cli.py b/tests/test_cli.py index 185007b2..dcdf071b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -15,6 +15,7 @@ optim_cifar10_wideresnet, optim_regression, ) +from torch_uncertainty.utils.misc import csv_writer from ._dummies.dataset import DummyClassificationDataset @@ -43,7 +44,12 @@ def test_cli_main_resnet(self): **vars(args), ) - cli_main(model, dm, root, "std", args) + results = cli_main(model, dm, root, "std", args) + for dict_result in results: + csv_writer( + Path("tests/logs/results.csv"), + dict_result, + ) def test_cli_main_other_arguments(self): root = Path(__file__).parent.absolute().parents[0] diff --git a/tests/test_losses.py b/tests/test_losses.py index 96412e30..89a110ad 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -5,7 +5,7 @@ from torch import nn from torch_uncertainty.layers.bayesian import BayesLinear -from torch_uncertainty.losses import DECLoss, ELBOLoss, NIGLoss, BetaNLL +from torch_uncertainty.losses import BetaNLL, DECLoss, ELBOLoss, NIGLoss class TestELBOLoss: diff --git a/tests/test_optimization_procedures.py b/tests/test_optimization_procedures.py index 3ecd3506..e250b547 100644 --- a/tests/test_optimization_procedures.py +++ b/tests/test_optimization_procedures.py @@ -16,6 +16,10 @@ def test_optim_cifar10(self): model = resnet18(in_channels=3, num_classes=10) procedure(model) + procedure = get_procedure("resnet34", "cifar10", "masked") + model = resnet34(in_channels=3, num_classes=100) + procedure(model) + procedure = get_procedure("resnet50", "cifar10", "packed") model = resnet50(in_channels=3, num_classes=10) procedure(model) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index c6e30d34..be18de67 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -1,10 +1,10 @@ from typing import Tuple +import numpy import pytest import torch from PIL import Image -import numpy from torch_uncertainty.transforms import ( AutoContrast, Brightness, diff --git a/torch_uncertainty/__init__.py b/torch_uncertainty/__init__.py index 9b7874fb..e3612ae6 100644 --- a/torch_uncertainty/__init__.py +++ b/torch_uncertainty/__init__.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Type, Union +import numpy as np import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import LearningRateMonitor @@ -12,8 +13,6 @@ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from torchinfo import summary -import numpy as np - from 
.datamodules.abstract import AbstractDataModule from .utils import get_version diff --git a/torch_uncertainty/datamodules/abstract.py b/torch_uncertainty/datamodules/abstract.py index 98356f23..ad414ef5 100644 --- a/torch_uncertainty/datamodules/abstract.py +++ b/torch_uncertainty/datamodules/abstract.py @@ -2,13 +2,12 @@ from pathlib import Path from typing import Any, List, Optional, Union +from numpy.typing import ArrayLike from pytorch_lightning import LightningDataModule from sklearn.model_selection import StratifiedKFold from torch.utils.data import DataLoader, Dataset from torch.utils.data.sampler import SubsetRandomSampler -from numpy.typing import ArrayLike - class AbstractDataModule(LightningDataModule): training_task: str diff --git a/torch_uncertainty/datamodules/cifar10.py b/torch_uncertainty/datamodules/cifar10.py index 09c978b2..a4b5990e 100644 --- a/torch_uncertainty/datamodules/cifar10.py +++ b/torch_uncertainty/datamodules/cifar10.py @@ -2,15 +2,14 @@ from pathlib import Path from typing import Any, List, Literal, Optional, Union +import numpy as np import torchvision.transforms as T +from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR10, SVHN -import numpy as np -from numpy.typing import ArrayLike - from ..datasets import AggregatedDataset from ..datasets.classification import CIFAR10C, CIFAR10H from ..transforms import Cutout diff --git a/torch_uncertainty/datamodules/cifar100.py b/torch_uncertainty/datamodules/cifar100.py index b8314976..4a596c41 100644 --- a/torch_uncertainty/datamodules/cifar100.py +++ b/torch_uncertainty/datamodules/cifar100.py @@ -2,16 +2,15 @@ from pathlib import Path from typing import Any, List, Literal, Optional, Union +import numpy as np import torch import torchvision.transforms as T +from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR100, SVHN -import numpy as np -from numpy.typing import ArrayLike - from ..datasets import AggregatedDataset from ..datasets.classification import CIFAR100C from ..transforms import Cutout diff --git a/torch_uncertainty/datamodules/tiny_imagenet.py b/torch_uncertainty/datamodules/tiny_imagenet.py index 7f86a6c5..e10b8d5d 100644 --- a/torch_uncertainty/datamodules/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/tiny_imagenet.py @@ -3,13 +3,12 @@ from typing import Any, List, Optional, Union import torchvision.transforms as T +from numpy.typing import ArrayLike from timm.data.auto_augment import rand_augment_transform from torch import nn from torch.utils.data import ConcatDataset, DataLoader from torchvision.datasets import DTD, SVHN -from numpy.typing import ArrayLike - from ..datasets.classification import ImageNetO, TinyImageNet from .abstract import AbstractDataModule diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_c.py b/torch_uncertainty/datasets/classification/cifar/cifar_c.py index 284435fb..f670da50 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_c.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_c.py @@ -2,14 +2,13 @@ from pathlib import Path from typing import Any, Callable, Optional, Tuple +import numpy as np from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( check_integrity, download_and_extract_archive, 
) -import numpy as np - class CIFAR10C(VisionDataset): """The corrupted CIFAR-10-C Dataset. diff --git a/torch_uncertainty/datasets/classification/cifar/cifar_h.py b/torch_uncertainty/datasets/classification/cifar/cifar_h.py index a3f17435..f4127055 100644 --- a/torch_uncertainty/datasets/classification/cifar/cifar_h.py +++ b/torch_uncertainty/datasets/classification/cifar/cifar_h.py @@ -1,12 +1,11 @@ import os from typing import Any, Callable, Optional +import numpy as np import torch from torchvision.datasets import CIFAR10 from torchvision.datasets.utils import check_integrity, download_url -import numpy as np - class CIFAR10H(CIFAR10): """`CIFAR-10H `_ Dataset. diff --git a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py index 6e67f174..a37f43f2 100644 --- a/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py +++ b/torch_uncertainty/datasets/classification/imagenet/tiny_imagenet.py @@ -3,12 +3,11 @@ from pathlib import Path from typing import Callable, Literal, Optional +import numpy as np import torch from PIL import Image from torch.utils.data import Dataset -import numpy as np - class TinyImageNet(Dataset): """Inspired by diff --git a/torch_uncertainty/datasets/classification/mnist_c.py b/torch_uncertainty/datasets/classification/mnist_c.py index d07373af..b6ce0d04 100644 --- a/torch_uncertainty/datasets/classification/mnist_c.py +++ b/torch_uncertainty/datasets/classification/mnist_c.py @@ -2,14 +2,13 @@ from pathlib import Path from typing import Any, Callable, Literal, Optional, Tuple +import numpy as np from torchvision.datasets import VisionDataset from torchvision.datasets.utils import ( check_integrity, download_and_extract_archive, ) -import numpy as np - class MNISTC(VisionDataset): """The corrupted MNIST-C Dataset. diff --git a/torch_uncertainty/layers/bayesian/sampler.py b/torch_uncertainty/layers/bayesian/sampler.py index d15f2f4f..01ae5791 100644 --- a/torch_uncertainty/layers/bayesian/sampler.py +++ b/torch_uncertainty/layers/bayesian/sampler.py @@ -1,10 +1,9 @@ from typing import Optional +import numpy as np import torch from torch import Tensor, distributions, nn -import numpy as np - class TrainableDistribution(nn.Module): lsqrt2pi = torch.tensor(np.log(np.sqrt(2 * np.pi))) diff --git a/torch_uncertainty/layers/masksembles.py b/torch_uncertainty/layers/masksembles.py index 834d1d4f..ade2737b 100644 --- a/torch_uncertainty/layers/masksembles.py +++ b/torch_uncertainty/layers/masksembles.py @@ -2,12 +2,11 @@ from typing import Any, Union +import numpy as np import torch from torch import Tensor, nn from torch.nn.common_types import _size_2_t -import numpy as np - def _generate_masks(m: int, n: int, s: float) -> np.ndarray: """Generates set of binary masks with properties defined by n, m, s params. 
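Note: the transforms/mixup.py hunk later in this same commit retypes the kernel-warping mixup added earlier in the series, which warps the Beta-sampled mixing coefficient using a Gaussian kernel over sample distances. The following is a minimal usage sketch; it is not part of the patch, relies only on the public WarpingMixup transform and the call pattern used in tests/transforms/test_mixup.py, and uses illustrative tensor shapes and hyperparameters.

    import torch

    from torch_uncertainty.transforms import WarpingMixup

    # Toy batch: 8 RGB images with integer class labels.
    x = torch.rand(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))

    mixup = WarpingMixup(
        alpha=1.0,
        mode="elem",
        num_classes=10,
        apply_kernel=True,
        tau_max=1.0,
        tau_std=0.5,
    )

    # The third argument is the similarity space used to warp the mixing
    # coefficients; the raw inputs are passed here (the "inp" option of the
    # classification routine), but feature embeddings could be used instead.
    mixed_x, mixed_y = mixup(x, y, x)
    # mixed_x keeps the input shape; mixed_y holds soft one-hot targets.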
diff --git a/torch_uncertainty/metrics/fpr95.py b/torch_uncertainty/metrics/fpr95.py index 5b870b3b..7064c050 100644 --- a/torch_uncertainty/metrics/fpr95.py +++ b/torch_uncertainty/metrics/fpr95.py @@ -1,14 +1,13 @@ from typing import List +import numpy as np import torch +from numpy.typing import ArrayLike from torch import Tensor from torchmetrics import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -import numpy as np -from numpy.typing import ArrayLike - def stable_cumsum(arr: ArrayLike, rtol: float = 1e-05, atol: float = 1e-08): """ diff --git a/torch_uncertainty/models/wideresnet/packed.py b/torch_uncertainty/models/wideresnet/packed.py index bd4372ab..242e8116 100644 --- a/torch_uncertainty/models/wideresnet/packed.py +++ b/torch_uncertainty/models/wideresnet/packed.py @@ -6,7 +6,6 @@ from ...layers import PackedConv2d, PackedLinear - __all__ = [ "packed_wideresnet28x10", ] diff --git a/torch_uncertainty/transforms/cutout.py b/torch_uncertainty/transforms/cutout.py index f84243b6..98865547 100644 --- a/torch_uncertainty/transforms/cutout.py +++ b/torch_uncertainty/transforms/cutout.py @@ -1,8 +1,7 @@ +import numpy as np import torch from torch import nn -import numpy as np - class Cutout(nn.Module): """Cutout augmentation class. diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py index 5cb0e82c..83cf2b3a 100644 --- a/torch_uncertainty/transforms/mixup.py +++ b/torch_uncertainty/transforms/mixup.py @@ -1,16 +1,17 @@ from typing import Tuple + +import numpy as np import scipy import torch import torch.nn.functional as F - -import numpy as np +from torch import Tensor -def beta_warping(x, alpha_cdf=1.0, eps=1e-12): +def beta_warping(x, alpha_cdf: float = 1.0, eps: float = 1e-12) -> float: return scipy.stats.beta.cdf(x, a=alpha_cdf + eps, b=alpha_cdf + eps) -def sim_gauss_kernel(dist, tau_max=1.0, tau_std=0.5): +def sim_gauss_kernel(dist, tau_max: float = 1.0, tau_std: float = 0.5) -> float: dist_rate = tau_max * np.exp( -(dist - 1) / (np.mean(dist) * 2 * tau_std * tau_std) ) @@ -48,16 +49,16 @@ def sim_gauss_kernel(dist, tau_max=1.0, tau_std=0.5): # ): # if isinstance(c1, float): # if c1 == c2: -# c1 = torch.tensor([c1], device=x.device) +# c1 = Tensor([c1], device=x.device) # c2 = c1 # else: -# c1 = torch.tensor([c1], device=x.device) +# c1 = Tensor([c1], device=x.device) # if isinstance(c2, float): -# c2 = torch.tensor([c2], device=x.device) +# c2 = Tensor([c2], device=x.device) # bt = torch.distributions.Beta(c1, c2) # if isinstance(x, float): -# x = torch.tensor(x) +# x = Tensor(x) # X = tensor_linspace(torch.zeros_like(x) + eps, x, npts) # return torch.trapezoid(bt.log_prob(X).exp(), X, dim=0) @@ -80,7 +81,9 @@ def sim_gauss_kernel(dist, tau_max=1.0, tau_std=0.5): class AbstractMixup: - def __init__(self, alpha=1.0, mode="batch", num_classes=1000) -> None: + def __init__( + self, alpha: float = 1.0, mode: str = "batch", num_classes: int = 1000 + ) -> None: self.alpha = alpha self.num_classes = num_classes self.mode = mode @@ -89,65 +92,54 @@ def _get_params(self, batch_size: int, device: torch.device): if self.mode == "batch": lam = np.random.beta(self.alpha, self.alpha) else: - lam = torch.tensor( + lam = Tensor( np.random.beta(self.alpha, self.alpha, batch_size), device=device, ) - index = torch.randperm(batch_size, device=device) - return lam, index def _linear_mixing( self, - lam: torch.Tensor | float, - inp: torch.Tensor, - index: torch.Tensor, - ) -> torch.Tensor: - if 
isinstance(lam, torch.Tensor): + lam: Tensor | float, + inp: Tensor, + index: Tensor, + ) -> Tensor: + if isinstance(lam, Tensor): lam = lam.view(-1, *[1 for _ in range(inp.ndim - 1)]).float() return lam * inp + (1 - lam) * inp[index, :] def _mix_target( self, - lam: torch.Tensor | float, - target: torch.Tensor, - index: torch.Tensor, - ) -> torch.Tensor: + lam: Tensor | float, + target: Tensor, + index: Tensor, + ) -> Tensor: y1 = F.one_hot(target, self.num_classes) y2 = F.one_hot(target[index], self.num_classes) - if isinstance(lam, torch.Tensor): + if isinstance(lam, Tensor): lam = lam.view(-1, *[1 for _ in range(y1.ndim - 1)]).float() - if isinstance(lam, torch.Tensor) and lam.dtype == torch.bool: + if isinstance(lam, Tensor) and lam.dtype == torch.bool: return lam * y1 + (~lam) * y2 else: return lam * y1 + (1 - lam) * y2 - def __call__( - self, x: torch.Tensor, y: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: return x, y class Mixup(AbstractMixup): - def __call__( - self, x: torch.Tensor, y: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: lam, index = self._get_params(x.size()[0], x.device) - mixed_x = self._linear_mixing(lam, x, index) - mixed_y = self._mix_target(lam, y, index) - return mixed_x, mixed_y class MixupIO(AbstractMixup): - def __call__( - self, x: torch.Tensor, y: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: lam, index = self._get_params(x.size()[0], x.device) mixed_x = self._linear_mixing(lam, x, index) @@ -158,30 +150,24 @@ def __call__( class RegMixup(AbstractMixup): - def __call__( - self, x: torch.Tensor, y: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: lam, index = self._get_params(x.size()[0], x.device) - part_x = self._linear_mixing(lam, x, index) - part_y = self._mix_target(lam, y, index) - mixed_x = torch.cat([x, part_x], dim=0) mixed_y = torch.cat([F.one_hot(y, self.num_classes), part_y], dim=0) - return mixed_x, mixed_y class WarpingMixup(AbstractMixup): def __init__( self, - alpha=1.0, - mode="batch", - num_classes=1000, - apply_kernel=True, - tau_max=1.0, - tau_std=0.5, + alpha: float = 1.0, + mode: str = "batch", + num_classes: int = 1000, + apply_kernel: bool = True, + tau_max: float = 1.0, + tau_std: float = 0.5, ) -> None: super().__init__(alpha, mode, num_classes) self.apply_kernel = apply_kernel @@ -195,16 +181,15 @@ def _get_params(self, batch_size: int, device: torch.device): lam = np.random.beta(self.alpha, self.alpha, batch_size) index = torch.randperm(batch_size, device=device) - return lam, index def __call__( self, - x: torch.Tensor, - y: torch.Tensor, - feats: torch.Tensor, + x: Tensor, + y: Tensor, + feats: Tensor, warp_param=1.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[Tensor, Tensor]: lam, index = self._get_params(x.size()[0], x.device) if self.apply_kernel: @@ -217,10 +202,7 @@ def __call__( ) warp_param = sim_gauss_kernel(l2_dist, self.tau_max, self.tau_std) - k_lam = torch.tensor(beta_warping(lam, warp_param), device=x.device) - + k_lam = Tensor(beta_warping(lam, warp_param), device=x.device) mixed_x = self._linear_mixing(k_lam, x, index) - mixed_y = self._mix_target(k_lam, y, index) - return mixed_x, mixed_y diff --git a/torch_uncertainty/transforms/pixmix.py b/torch_uncertainty/transforms/pixmix.py index 
d9d96941..56a0e7b3 100644 --- a/torch_uncertainty/transforms/pixmix.py +++ b/torch_uncertainty/transforms/pixmix.py @@ -1,9 +1,9 @@ """ Code adapted from PixMix' paper. """ +import numpy as np from PIL import Image from torch import nn -import numpy as np from torch_uncertainty.transforms import Shear, Translate, augmentations diff --git a/torch_uncertainty/transforms/transforms.py b/torch_uncertainty/transforms/transforms.py index 950ff94a..03550328 100644 --- a/torch_uncertainty/transforms/transforms.py +++ b/torch_uncertainty/transforms/transforms.py @@ -1,13 +1,12 @@ from typing import List, Optional, Tuple, Union +import numpy as np import torch import torchvision.transforms.functional as F from einops import rearrange from PIL import Image, ImageEnhance from torch import Tensor, nn -import numpy as np - class AutoContrast(nn.Module): pixmix_max_level = None From aaf57abe67fdd6adb995dd21d92a3c59594fa0eb Mon Sep 17 00:00:00 2001 From: Olivier Date: Wed, 1 Nov 2023 13:40:34 +0100 Subject: [PATCH 17/27] :bug: Resolve test issue --- tests/test_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index dcdf071b..fe25090d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -47,7 +47,7 @@ def test_cli_main_resnet(self): results = cli_main(model, dm, root, "std", args) for dict_result in results: csv_writer( - Path("tests/logs/results.csv"), + Path("results.csv"), dict_result, ) From 2fb672176b422ffe0091edf12728d21e443ad579 Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Thu, 2 Nov 2023 10:43:05 +0100 Subject: [PATCH 18/27] add tests on cv datamodule + resnet feats forward + ts in classification routines + more on cli --- tests/datamodules/test_abstract_datamodule.py | 23 +++++++++++++- tests/models/test_resnets.py | 1 + tests/routines/test_classification.py | 30 +++++++++++++++++++ tests/test_cli.py | 10 +++++-- 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 62315bd7..7bd16b20 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -1,3 +1,4 @@ +from pathlib import Path import pytest from torch_uncertainty.datamodules.abstract import ( @@ -22,9 +23,29 @@ def test_errors(self): class TestCrossValDataModule: """Testing the CrossValDataModule class.""" + def test_cv_main(self): + dm = AbstractDataModule("root", 128, 4, True, True) + ds = DummyClassificationDataset(Path("root")) + dm.train = ds + dm.val = ds + dm.test = ds + cv_dm = CrossValDataModule("root", [0], [1], dm, 128, 4, True, True) + + cv_dm.setup() + cv_dm.setup("test") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + def test_errors(self): dm = AbstractDataModule("root", 128, 4, True, True) - ds = DummyClassificationDataset("root") + ds = DummyClassificationDataset(Path("root")) dm.train = ds dm.val = ds dm.test = ds diff --git a/tests/models/test_resnets.py b/tests/models/test_resnets.py index 92f8390a..49a8c462 100644 --- a/tests/models/test_resnets.py +++ b/tests/models/test_resnets.py @@ -38,6 +38,7 @@ def test_main(self): model = resnet50(1, 10, 1) with torch.no_grad(): model(torch.randn(2, 1, 32, 32)) + model.feats_forward(torch.randn(2, 1, 32, 32)) def test_mc_dropout(self): resnet34(1, 10, 1, num_estimators=5) diff --git a/tests/routines/test_classification.py 
b/tests/routines/test_classification.py index 7b34b21f..60315d15 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -134,6 +134,36 @@ def test_cli_main_dummy_ood(self): with pytest.raises(NotImplementedError): cli_main(model, dm, root, "logs/dummy", args) + def test_cli_main_dummy_mixup_ts(self): + root = Path(__file__).parent.absolute().parents[0] + with ArgvContext( + "file.py", + "--mixtype", + "kernel_warping", + "--mixup_alpha", + "1.", + "--dist_sim", + "inp", + "--opt_temp_scaling", + ): + args = init_args( + DummyClassificationBaseline, DummyClassificationDataModule + ) + + args.root = str(root / "data") + dm = DummyClassificationDataModule(num_classes=10, **vars(args)) + + model = DummyClassificationBaseline( + num_classes=dm.num_classes, + in_channels=dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + calibration_set=dm.get_test_set, + **vars(args), + ) + cli_main(model, dm, root, "logs/dummy", args) + def test_classification_failures(self): with pytest.raises(ValueError): ClassificationSingle( diff --git a/tests/test_cli.py b/tests/test_cli.py index fe25090d..5020b434 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -47,7 +47,13 @@ def test_cli_main_resnet(self): results = cli_main(model, dm, root, "std", args) for dict_result in results: csv_writer( - Path("results.csv"), + root / Path("tests/logs/results.csv"), + dict_result, + ) + # Test if file already exists + for dict_result in results: + csv_writer( + root / Path("tests/logs/results.csv"), dict_result, ) @@ -173,7 +179,7 @@ def test_cli_other_training_task(self): def test_cli_cv_ts(self): root = Path(__file__).parent.absolute().parents[0] - with ArgvContext("file.py", "--use_cv"): + with ArgvContext("file.py", "--use_cv", "--channels_last"): args = init_args(ResNet, CIFAR10DataModule) # datamodule From 6b2b4b94457a42b15a679bb93e1064907fc34972 Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Thu, 2 Nov 2023 11:11:23 +0100 Subject: [PATCH 19/27] fix test results writer --- tests/test_cli.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5020b434..12e8575e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import sys +import os from pathlib import Path import pytest @@ -45,15 +46,18 @@ def test_cli_main_resnet(self): ) results = cli_main(model, dm, root, "std", args) + results_path = root / "tests" / "logs" + if not os.path.exists(results_path): + os.makedirs(results_path) for dict_result in results: csv_writer( - root / Path("tests/logs/results.csv"), + results_path / "results.csv", dict_result, ) # Test if file already exists for dict_result in results: csv_writer( - root / Path("tests/logs/results.csv"), + results_path / "results.csv", dict_result, ) From 7ba4078e6df946f04f6f937d612f9778fcf80b7f Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Thu, 2 Nov 2023 12:36:23 +0100 Subject: [PATCH 20/27] fix error in test cv datamodule --- tests/datamodules/test_abstract_datamodule.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/datamodules/test_abstract_datamodule.py b/tests/datamodules/test_abstract_datamodule.py index 7bd16b20..ea6f3807 100644 --- a/tests/datamodules/test_abstract_datamodule.py +++ b/tests/datamodules/test_abstract_datamodule.py @@ -35,13 +35,13 @@ def test_cv_main(self): cv_dm.setup("test") # test abstract methods - dm.get_train_set() - 
dm.get_val_set() - dm.get_test_set() + cv_dm.get_train_set() + cv_dm.get_val_set() + cv_dm.get_test_set() - dm.train_dataloader() - dm.val_dataloader() - dm.test_dataloader() + cv_dm.train_dataloader() + cv_dm.val_dataloader() + cv_dm.test_dataloader() def test_errors(self): dm = AbstractDataModule("root", 128, 4, True, True) From 64d591db34acb23ee85076168a943674d3a3472c Mon Sep 17 00:00:00 2001 From: Quentin Bouniot Date: Thu, 2 Nov 2023 13:21:13 +0100 Subject: [PATCH 21/27] more tests, cifar10 + fit with cv --- tests/_dummies/datamodule.py | 8 ++++ tests/datamodules/test_cifar10_datamodule.py | 8 ++++ tests/routines/test_classification.py | 46 +++++++++++++++----- 3 files changed, 50 insertions(+), 12 deletions(-) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index a8e3afc8..99a5313c 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -1,6 +1,8 @@ from argparse import ArgumentParser from pathlib import Path from typing import Any, List, Optional, Union +from numpy.typing import ArrayLike +import numpy as np import torchvision.transforms as T from torch.utils.data import DataLoader @@ -84,6 +86,12 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: dataloader.append(self._data_loader(self.ood)) return dataloader + def _get_train_data(self) -> ArrayLike: + return self.train.data + + def _get_train_targets(self) -> ArrayLike: + return np.array(self.train.targets) + @classmethod def add_argparse_args( cls, diff --git a/tests/datamodules/test_cifar10_datamodule.py b/tests/datamodules/test_cifar10_datamodule.py index 62a5c336..53376dea 100644 --- a/tests/datamodules/test_cifar10_datamodule.py +++ b/tests/datamodules/test_cifar10_datamodule.py @@ -32,6 +32,9 @@ def test_CIFAR10_main(self): dm.setup() dm.setup("test") + with pytest.raises(ValueError): + dm.setup("xxx") + # test abstract methods dm.get_train_set() dm.get_val_set() @@ -52,6 +55,11 @@ def test_CIFAR10_main(self): with pytest.raises(ValueError): dm.setup() + args.test_alt = "h" + dm = CIFAR10DataModule(**vars(args)) + dm.dataset = DummyClassificationDataset + dm.setup("test") + args.test_alt = None args.num_dataloaders = 2 args.val_split = 0.1 diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 60315d15..39c59f85 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -16,6 +16,7 @@ from .._dummies import ( DummyClassificationBaseline, DummyClassificationDataModule, + DummyClassificationDataset, ) @@ -134,7 +135,7 @@ def test_cli_main_dummy_ood(self): with pytest.raises(NotImplementedError): cli_main(model, dm, root, "logs/dummy", args) - def test_cli_main_dummy_mixup_ts(self): + def test_cli_main_dummy_mixup_ts_cv(self): root = Path(__file__).parent.absolute().parents[0] with ArgvContext( "file.py", @@ -144,7 +145,8 @@ def test_cli_main_dummy_mixup_ts(self): "1.", "--dist_sim", "inp", - "--opt_temp_scaling", + "--val_temp_scaling", + "--use_cv", ): args = init_args( DummyClassificationBaseline, DummyClassificationDataModule @@ -152,17 +154,37 @@ def test_cli_main_dummy_mixup_ts(self): args.root = str(root / "data") dm = DummyClassificationDataModule(num_classes=10, **vars(args)) - - model = DummyClassificationBaseline( - num_classes=dm.num_classes, - in_channels=dm.num_channels, - loss=nn.CrossEntropyLoss, - optimization_procedure=optim_cifar10_resnet18, - baseline_type="single", - calibration_set=dm.get_test_set, - **vars(args), + dm.dataset = ( + lambda root, + 
num_channels, + num_classes, + image_size, + transform: DummyClassificationDataset( + root, + num_channels=num_channels, + num_classes=num_classes, + image_size=image_size, + transform=transform, + num_images=20, + ) ) - cli_main(model, dm, root, "logs/dummy", args) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + DummyClassificationBaseline( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + calibration_set=dm.get_val_set, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "logs/dummy", args) def test_classification_failures(self): with pytest.raises(ValueError): From 8b44a86cb5b1d91e93f44482ce3d9b0751b90403 Mon Sep 17 00:00:00 2001 From: Olivier Date: Thu, 2 Nov 2023 17:11:33 +0100 Subject: [PATCH 22/27] :hammer: Refactor mixup init --- torch_uncertainty/routines/classification.py | 82 +++++++++++--------- 1 file changed, 46 insertions(+), 36 deletions(-) diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 8419d816..f1f6cf6b 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -159,42 +159,9 @@ def __init__( self.mixmode = mixmode self.dist_sim = dist_sim - if self.mixtype == "timm": - self.mixup = timm_Mixup( - mixup_alpha=mixup_alpha, - cutmix_alpha=cutmix_alpha, - mode=self.mixmode, - num_classes=self.num_classes, - ) - elif self.mixtype == "mixup": - self.mixup = Mixup( - alpha=mixup_alpha, - mode=self.mixmode, - num_classes=self.num_classes, - ) - elif self.mixtype == "mixup_io": - self.mixup = MixupIO( - alpha=mixup_alpha, - mode=self.mixmode, - num_classes=self.num_classes, - ) - elif self.mixtype == "regmixup": - self.mixup = RegMixup( - alpha=mixup_alpha, - mode=self.mixmode, - num_classes=self.num_classes, - ) - elif self.mixtype == "kernel_warping": - self.mixup = WarpingMixup( - alpha=mixup_alpha, - mode=self.mixmode, - num_classes=self.num_classes, - apply_kernel=True, - tau_max=kernel_tau_max, - tau_std=kernel_tau_std, - ) - else: - self.mixup = lambda x, y: (x, y) + self.mixup = self.init_mixup( + mixup_alpha, cutmix_alpha, kernel_tau_max, kernel_tau_std + ) self.cal_plot = CalibrationPlot() @@ -412,6 +379,49 @@ def test_epoch_end( "Likelihood Histogram", probs_fig ) + def init_mixup( + self, + mixup_alpha: float, + cutmix_alpha: float, + kernel_tau_max: float, + kernel_tau_std: float, + ) -> nn.Module: + if self.mixtype == "timm": + return timm_Mixup( + mixup_alpha=mixup_alpha, + cutmix_alpha=cutmix_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "mixup": + return Mixup( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "mixup_io": + return MixupIO( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "regmixup": + return RegMixup( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + ) + elif self.mixtype == "kernel_warping": + return WarpingMixup( + alpha=mixup_alpha, + mode=self.mixmode, + num_classes=self.num_classes, + apply_kernel=True, + tau_max=kernel_tau_max, + tau_std=kernel_tau_std, + ) + return lambda x, y: (x, y) + @staticmethod def add_model_specific_args( parent_parser: ArgumentParser, From c288d22bb57afd940eb575c021710b1d672a59e0 Mon Sep 17 00:00:00 2001 From: Olivier 
Date: Thu, 2 Nov 2023 17:14:21 +0100 Subject: [PATCH 23/27] :white_check_mark: Continue improving cvg. --- tests/_dummies/model.py | 3 ++ tests/routines/test_classification.py | 49 +++++++++++++++++++++++++++ tests/test_cli.py | 5 ++- tests/transforms/test_mixup.py | 9 +++++ torch_uncertainty/transforms/mixup.py | 2 +- 5 files changed, 66 insertions(+), 2 deletions(-) diff --git a/tests/_dummies/model.py b/tests/_dummies/model.py index e048438c..07900527 100644 --- a/tests/_dummies/model.py +++ b/tests/_dummies/model.py @@ -22,6 +22,9 @@ def __init__( self.num_estimators = num_estimators + def feats_forward(self, x: Tensor) -> Tensor: + return self.forward(x) + def forward(self, x: Tensor) -> Tensor: out = self.linear( torch.ones( diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 39c59f85..62ed2968 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -186,6 +186,55 @@ def test_cli_main_dummy_mixup_ts_cv(self): cli_main(list_model, list_dm, root, "logs/dummy", args) + with ArgvContext( + "file.py", + "--mixtype", + "kernel_warping", + "--mixup_alpha", + "1.", + "--dist_sim", + "emb", + "--val_temp_scaling", + "--use_cv", + ): + args = init_args( + DummyClassificationBaseline, DummyClassificationDataModule + ) + + args.root = str(root / "data") + dm = DummyClassificationDataModule(num_classes=10, **vars(args)) + dm.dataset = ( + lambda root, + num_channels, + num_classes, + image_size, + transform: DummyClassificationDataset( + root, + num_channels=num_channels, + num_classes=num_classes, + image_size=image_size, + transform=transform, + num_images=20, + ) + ) + + list_dm = dm.make_cross_val_splits(2, 1) + list_model = [] + for i in range(len(list_dm)): + list_model.append( + DummyClassificationBaseline( + num_classes=list_dm[i].dm.num_classes, + in_channels=list_dm[i].dm.num_channels, + loss=nn.CrossEntropyLoss, + optimization_procedure=optim_cifar10_resnet18, + baseline_type="single", + calibration_set=dm.get_val_set, + **vars(args), + ) + ) + + cli_main(list_model, list_dm, root, "logs/dummy", args) + def test_classification_failures(self): with pytest.raises(ValueError): ClassificationSingle( diff --git a/tests/test_cli.py b/tests/test_cli.py index 12e8575e..9fd24dea 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -154,13 +154,16 @@ def test_cli_main_mlp(self): cli_main(model, dm, root, "std", args) + args.test = True + cli_main(model, dm, root, "std", args) + def test_cli_other_training_task(self): root = Path(__file__).parent.absolute().parents[0] with ArgvContext("file.py"): args = init_args(MLP, UCIDataModule) # datamodule - args.root = root / "/data" + args.root = root / "data" dm = UCIDataModule( dataset_name="kin8nm", input_shape=(1, 5), **vars(args) ) diff --git a/tests/transforms/test_mixup.py b/tests/transforms/test_mixup.py index 7873f52f..d71dbc99 100644 --- a/tests/transforms/test_mixup.py +++ b/tests/transforms/test_mixup.py @@ -4,6 +4,7 @@ import torch from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup +from torch_uncertainty.transforms.mixup import AbstractMixup @pytest.fixture @@ -12,6 +13,14 @@ def batch_input() -> Tuple[torch.Tensor, torch.Tensor]: return imgs, torch.tensor([0, 1]) +class TestAbstractMixup: + """Testing AbstractMixup augmentation""" + + def test_abstract_mixup(self, batch_input): + with pytest.raises(NotImplementedError): + AbstractMixup()(*batch_input) + + class TestMixup: """Testing Mixup augmentation""" diff --git 
a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py
index 83cf2b3a..04c6b4d3 100644
--- a/torch_uncertainty/transforms/mixup.py
+++ b/torch_uncertainty/transforms/mixup.py
@@ -127,7 +127,7 @@ def _mix_target(
         return lam * y1 + (1 - lam) * y2
 
     def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
-        return x, y
+        raise NotImplementedError
 
 
 class Mixup(AbstractMixup):

From 611b4e10040a1a07b772f1c091ff3bd96a41affb Mon Sep 17 00:00:00 2001
From: Quentin Bouniot
Date: Thu, 2 Nov 2023 19:56:38 +0100
Subject: [PATCH 24/27] fix mixupio + tensor constructor for gpu

---
 torch_uncertainty/routines/classification.py |  2 +-
 torch_uncertainty/transforms/mixup.py        | 14 +++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py
index f1f6cf6b..4b86dc8e 100644
--- a/torch_uncertainty/routines/classification.py
+++ b/torch_uncertainty/routines/classification.py
@@ -385,7 +385,7 @@ def init_mixup(
         cutmix_alpha: float,
         kernel_tau_max: float,
         kernel_tau_std: float,
-    ) -> nn.Module:
+    ) -> Callable:
         if self.mixtype == "timm":
             return timm_Mixup(
                 mixup_alpha=mixup_alpha,
diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py
index 04c6b4d3..f8c0b4d2 100644
--- a/torch_uncertainty/transforms/mixup.py
+++ b/torch_uncertainty/transforms/mixup.py
@@ -92,7 +92,7 @@ def _get_params(self, batch_size: int, device: torch.device):
         if self.mode == "batch":
             lam = np.random.beta(self.alpha, self.alpha)
         else:
-            lam = Tensor(
+            lam = torch.as_tensor(
                 np.random.beta(self.alpha, self.alpha, batch_size),
                 device=device,
             )
@@ -121,10 +121,7 @@ def _mix_target(
         if isinstance(lam, Tensor):
             lam = lam.view(-1, *[1 for _ in range(y1.ndim - 1)]).float()
 
-        if isinstance(lam, Tensor) and lam.dtype == torch.bool:
-            return lam * y1 + (~lam) * y2
-        else:
-            return lam * y1 + (1 - lam) * y2
+        return lam * y1 + (1 - lam) * y2
 
     def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
         raise NotImplementedError
@@ -144,7 +141,10 @@ def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
         mixed_x = self._linear_mixing(lam, x, index)
 
-        mixed_y = self._mix_target((lam > 0.5), y, index)
+        if self.mode == "batch":
+            mixed_y = self._mix_target(float(lam > 0.5), y, index)
+        else:
+            mixed_y = self._mix_target((lam > 0.5).float(), y, index)
 
         return mixed_x, mixed_y
@@ -202,7 +202,7 @@ def __call__(
             )
             warp_param = sim_gauss_kernel(l2_dist, self.tau_max, self.tau_std)
-        k_lam = Tensor(beta_warping(lam, warp_param), device=x.device)
+        k_lam = torch.as_tensor(beta_warping(lam, warp_param), device=x.device)
         mixed_x = self._linear_mixing(k_lam, x, index)
         mixed_y = self._mix_target(k_lam, y, index)
         return mixed_x, mixed_y

From e9fc53cd295278f35c0e9b9c7b5de76aa3965cef Mon Sep 17 00:00:00 2001
From: Quentin Bouniot
Date: Fri, 3 Nov 2023 10:11:30 +0100
Subject: [PATCH 25/27] add refs for mixup methods

---
 torch_uncertainty/transforms/mixup.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/torch_uncertainty/transforms/mixup.py b/torch_uncertainty/transforms/mixup.py
index f8c0b4d2..e39d6190 100644
--- a/torch_uncertainty/transforms/mixup.py
+++ b/torch_uncertainty/transforms/mixup.py
@@ -128,6 +128,11 @@ def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 class Mixup(AbstractMixup):
+    """Original Mixup method from Zhang et al.,
+    "mixup: Beyond Empirical Risk Minimization" (ICLR 2018)
+    http://arxiv.org/abs/1710.09412
+    """
+
     def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
         lam, index = self._get_params(x.size()[0], x.device)
         mixed_x = self._linear_mixing(lam, x, index)
@@ -136,6 +141,11 @@ def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 class MixupIO(AbstractMixup):
+    """Mixup on inputs only with targets unchanged, from Wang et al.,
+    "On the Pitfall of Mixup for Uncertainty Calibration" (CVPR 2023)
+    https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_On_the_Pitfall_of_Mixup_for_Uncertainty_Calibration_CVPR_2023_paper.pdf
+    """
+
     def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
         lam, index = self._get_params(x.size()[0], x.device)
 
@@ -150,6 +160,11 @@ def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 class RegMixup(AbstractMixup):
+    """RegMixup method from Pinto et al.,
+    "RegMixup: Mixup as a Regularizer Can Surprisingly Improve Accuracy and Out-of-Distribution Robustness" (NeurIPS 2022)
+    https://arxiv.org/abs/2206.14502
+    """
+
     def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
         lam, index = self._get_params(x.size()[0], x.device)
         part_x = self._linear_mixing(lam, x, index)
@@ -160,6 +175,11 @@ def __call__(self, x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
 
 
 class WarpingMixup(AbstractMixup):
+    """Kernel Warping Mixup method from Bouniot et al.,
+    "Tailoring Mixup to Data using Kernel Warping functions" (2023)
+    https://arxiv.org/abs/2311.01434
+    """
+
     def __init__(
         self,
         alpha: float = 1.0,

From c1d5711f90de94fa8bb73ab79ea505cfb34cefcb Mon Sep 17 00:00:00 2001
From: Olivier
Date: Fri, 3 Nov 2023 11:26:06 +0100
Subject: [PATCH 26/27] :books: Add references to docs & improve rdme.

---
 README.md                  | 16 +++++++++++-----
 docs/source/references.rst | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index 11c5738d..4ffc9274 100644
--- a/README.md
+++ b/README.md
@@ -51,20 +51,26 @@ A quickstart is available at [torch-uncertainty.github.io/quickstart](https://to
 To date, the following deep learning baselines have been implemented:
 
 - Deep Ensembles
-- MC-Dropout
+- MC-Dropout - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html)
 - BatchEnsemble
 - Masksembles
 - MIMO
-- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873))
-- Bayesian Neural Networks :construction: Work in progress :construction:
+- Packed-Ensembles (see [blog post](https://medium.com/@adrien.lafage/make-your-neural-networks-more-reliable-with-packed-ensembles-7ad0b737a873)) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html)
+- Bayesian Neural Networks :construction: Work in progress :construction: - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_bayesian.html)
 - Regression with Beta Gaussian NLL Loss
-- Deep Evidential Classification & Regression
+- Deep Evidential Classification & Regression - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html)
+
+### Augmentation methods
+
+The following data augmentation methods have been implemented:
+
+- Mixup, MixupIO, RegMixup, WarpingMixup
 
 ### Post-processing methods
 
 To date, the following post-processing methods have been implemented:
 
-- Temperature, Vector, & Matrix scaling
+- Temperature, Vector, & Matrix scaling - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_scaler.html)
 
 ## Tutorials
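Editor's note: the following is an illustrative, standalone sketch, not part of the patch series, showing how the augmentation classes listed in the README hunk above might be exercised. It mirrors the constructor arguments visible in `init_mixup` and the two-tensor `__call__` signature shown in the diffs; the tensor shapes, the `mode="elem"` value, and the three-argument call to `WarpingMixup` (raw inputs standing in for features, as with `--dist_sim inp` in the tests) are assumptions made for the example.

```python
import torch

from torch_uncertainty.transforms import Mixup, MixupIO, RegMixup, WarpingMixup

# Dummy batch: 2 images with integer class labels (shapes are illustrative).
imgs = torch.rand(2, 3, 32, 32)
targets = torch.tensor([0, 1])

# The simple variants take a batch and return a mixed batch.
for mixup_cls in (Mixup, MixupIO, RegMixup):
    transform = mixup_cls(alpha=1.0, mode="elem", num_classes=2)
    mixed_imgs, mixed_targets = transform(imgs, targets)

# WarpingMixup additionally receives per-sample features used to warp the
# Beta-sampled mixing coefficients; here the inputs themselves stand in for
# the features (assumed call signature).
warping = WarpingMixup(
    alpha=1.0,
    mode="elem",
    num_classes=2,
    apply_kernel=True,
    tau_max=1.0,
    tau_std=0.5,
)
mixed_imgs, mixed_targets = warping(imgs, targets, imgs)
```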
diff --git a/docs/source/references.rst b/docs/source/references.rst index 80432303..456a9627 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -114,6 +114,39 @@ For Monte-Carlo Dropout, consider citing: * Authors: *Yarin Gal and Zoubin Ghahramani* * Paper: `ICML 2016 `__. +Data Augmentation Methods +------------------------- + +Mixup +^^^^^ + +For Mixup, consider citing: + +**mixup: Beyond Empirical Risk Minimization** + +* Authors: *Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz* +* Paper: `ICLR 2018 `__. + +MixupIO +^^^^^^^ + +For MixupIO, consider citing: + +**On the Pitfall of Mixup for Uncertainty Calibration** + +* Authors: *Deng-Bao Wang, Lanqing Li, Peilin Zhao, Pheng-Ann Heng, and Min-Ling Zhang* +* Paper: `CVPR 2023 ` + +Warping Mixup +^^^^^^^^^^^^^ + +For Warping Mixup, consider citing: + +**Tailoring Mixup to Data using Kernel Warping functions** + +* Authors: *Quentin Bouniot, Pavlo Mozharovskyi, and Florence d'Alché-Buc* +* Paper: `ArXiv 2023 `__. + Post-Processing Methods ----------------------- From ffa2a3793ba49e8badec8768f98e8a43aa2cc68e Mon Sep 17 00:00:00 2001 From: Olivier Date: Fri, 3 Nov 2023 11:31:19 +0100 Subject: [PATCH 27/27] :fire: Remove pytest checks from pre-commits --- .pre-commit-config.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e9caef5a..414f1910 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,8 +38,8 @@ repos: language: python types_or: [python, pyi] exclude: ^auto_tutorials_source/ - - id: pytest-check - name: pytest-check - entry: pytest - language: system - pass_filenames: false + # - id: pytest-check + # name: pytest-check + # entry: pytest + # language: system + # pass_filenames: false
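Editor's note: as a closing illustration of the kernel-warping mechanism touched by PATCH 24 (`sim_gauss_kernel` and `beta_warping`), here is a rough, self-contained sketch of the underlying idea. It is not the library's implementation; the helper names and formulas below are this sketch's assumptions. The assumed reading is that each pair's Beta-sampled mixing coefficient is pushed through a symmetric Beta CDF whose concentration comes from a Gaussian kernel of the feature distance, so that dissimilar pairs are mixed more conservatively.

```python
import numpy as np
from scipy.stats import beta

# Hypothetical helpers: names and formulas are assumptions for illustration,
# not the functions defined in torch_uncertainty/transforms/mixup.py.


def gaussian_warp_param(l2_dist, tau_max=1.0, tau_std=0.5):
    """Turn pairwise feature distances into a warping concentration via a Gaussian kernel."""
    normalized = l2_dist / (l2_dist.mean() + 1e-12)
    return tau_max * np.exp(-((normalized - 1.0) ** 2) / (2 * tau_std**2))


def warp_lambda(lam, warp_param):
    """Warp Beta-sampled coefficients through a symmetric Beta CDF."""
    return beta.cdf(lam, a=warp_param, b=warp_param)


rng = np.random.default_rng(0)
lam = rng.beta(1.0, 1.0, size=8)       # plain per-sample mixup coefficients
l2_dist = rng.random(8) + 0.1          # stand-in for distances between mixed pairs
k_lam = warp_lambda(lam, gaussian_warp_param(l2_dist))
print(np.round(k_lam, 3))
```

Here `tau_max` and `tau_std` correspond to the `kernel_tau_max`/`kernel_tau_std` parameters threaded through `init_mixup`, and the distance can be computed on raw inputs or embeddings, which is the `dist_sim` choice exercised by the tests added earlier in the series.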