diff --git a/.gitignore b/.gitignore index 4aa0c186..e8f10e61 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,9 @@ docs/*/auto_tutorials/ *.ckpt *.out docs/source/sg_execution_times.rst -test**/*.csv +test +**/*.csv +pyrightconfig.json # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/source/conf.py b/docs/source/conf.py index c5c3b996..8832766d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ f"{datetime.now().year!s}, Adrien Lafage and Olivier Laurent" ) author = "Adrien Lafage and Olivier Laurent" -release = "0.2.2.post0" +release = "0.2.2.post1" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index c82abe13..ab2fa77f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "torch_uncertainty" -version = "0.2.2.post0" +version = "0.2.2.post1" authors = [ { name = "ENSTA U2IS", email = "olivier.laurent@ensta-paris.fr" }, { name = "Adrien Lafage", email = "adrienlafage@outlook.com" }, diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py index 8ab62d00..33298abe 100644 --- a/torch_uncertainty/metrics/classification/risk_coverage.py +++ b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -147,9 +147,7 @@ def plot( ax.set_ylabel("Risk - Error Rate (%)", fontsize=16) ax.set_xlim(0, 100) ax.set_ylim(0, min(100, np.ceil(error_rates.max() * 100))) - ax.set_aspect("equal", "box") ax.legend(loc="upper right") - fig.tight_layout() return fig, ax @@ -270,9 +268,7 @@ def plot( ax.set_ylabel("Generalized Risk (%)", fontsize=16) ax.set_xlim(0, 100) ax.set_ylim(0, min(100, np.ceil(error_rates.max() * 100))) - ax.set_aspect("equal", "box") ax.legend(loc="upper right") - fig.tight_layout() return fig, ax diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index cd89e63c..933b0749 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -248,7 +248,7 @@ def __init__( self.layer4 = nn.Identity() linear_multiplier = 4 - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -297,7 +297,7 @@ def forward(self, x: Tensor) -> Tensor: out = self.layer3(out) out = self.layer4(out) out = self.pool(out) - out = self.dropout(self.flatten(out)) + out = self.final_dropout(self.flatten(out)) return self.linear(out) diff --git a/torch_uncertainty/models/resnet/lpbnn.py b/torch_uncertainty/models/resnet/lpbnn.py index b79de57c..6d22720f 100644 --- a/torch_uncertainty/models/resnet/lpbnn.py +++ b/torch_uncertainty/models/resnet/lpbnn.py @@ -258,7 +258,7 @@ def __init__( self.layer4 = nn.Identity() linear_multiplier = 4 - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -309,7 +309,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer3(out) out = self.layer4(out) out = self.pool(out) - return self.dropout(self.flatten(out)) + return self.final_dropout(self.flatten(out)) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index 04117e67..4af8ab2a 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -262,7 +262,7 @@ def __init__( self.layer4 = nn.Identity() linear_multiplier = 4 - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -315,7 +315,7 @@ def forward(self, x: Tensor) -> Tensor: out = self.layer4(out) out = self.pool(out) - out = self.dropout(self.flatten(out)) + out = self.final_dropout(self.flatten(out)) return self.linear(out) diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/resnet/packed.py index 4bf170d8..217a07a3 100644 --- a/torch_uncertainty/models/resnet/packed.py +++ b/torch_uncertainty/models/resnet/packed.py @@ -315,7 +315,7 @@ def __init__( self.layer4 = nn.Identity() linear_multiplier = 4 - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -374,7 +374,7 @@ def forward(self, x: Tensor) -> Tensor: ) out = self.pool(out) - out = self.dropout(self.flatten(out)) + out = self.final_dropout(self.flatten(out)) return self.linear(out) def check_config(self, config: dict[str, Any]) -> bool: diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index b07e7fc6..1b9ddccd 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -293,7 +293,7 @@ def __init__( self.layer4 = nn.Identity() linear_multiplier = 4 - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -340,7 +340,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer3(out) out = self.layer4(out) out = self.pool(out) - return self.dropout(self.flatten(out)) + return self.final_dropout(self.flatten(out)) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) @@ -374,6 +374,7 @@ def resnet( activation_fn (Callable, optional): Activation function. Defaults to ``torch.nn.functional.relu``. normalization_layer (nn.Module, optional): Normalization layer. + Defaults to ``torch.nn.BatchNorm2d``. Returns: _ResNet: The ResNet model. diff --git a/torch_uncertainty/models/wideresnet/batched.py b/torch_uncertainty/models/wideresnet/batched.py index 792c0e46..120dc267 100644 --- a/torch_uncertainty/models/wideresnet/batched.py +++ b/torch_uncertainty/models/wideresnet/batched.py @@ -22,6 +22,7 @@ def __init__( groups: int, conv_bias: bool, activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> None: super().__init__() self.activation_fn = activation_fn @@ -35,7 +36,7 @@ def __init__( bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) - self.bn1 = nn.BatchNorm2d(planes) + self.bn1 = normalization_layer(planes) self.conv2 = BatchConv2d( planes, planes, @@ -46,7 +47,7 @@ def __init__( groups=groups, bias=conv_bias, ) - self.bn2 = nn.BatchNorm2d(planes) + self.bn2 = normalization_layer(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != planes: @@ -82,6 +83,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__() self.num_estimators = num_estimators @@ -123,7 +125,7 @@ def __init__( else: raise ValueError(f"Unknown WideResNet style: {style}. ") - self.bn1 = nn.BatchNorm2d(num_stages[0]) + self.bn1 = normalization_layer(num_stages[0]) if style == "imagenet": self.optional_pool = nn.MaxPool2d( @@ -142,6 +144,7 @@ def __init__( groups=groups, conv_bias=conv_bias, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) self.layer2 = self._wide_layer( _WideBasicBlock, @@ -153,6 +156,7 @@ def __init__( groups=groups, conv_bias=conv_bias, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) self.layer3 = self._wide_layer( _WideBasicBlock, @@ -164,9 +168,10 @@ def __init__( groups=groups, conv_bias=conv_bias, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) self.linear = BatchLinear( @@ -186,6 +191,7 @@ def _wide_layer( groups: int, conv_bias: bool, activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -201,6 +207,7 @@ def _wide_layer( num_estimators=num_estimators, groups=groups, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) ) self.in_planes = planes @@ -214,7 +221,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer2(out) out = self.layer3(out) out = self.pool(out) - return self.dropout(self.flatten(out)) + return self.final_dropout(self.flatten(out)) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) @@ -228,6 +235,8 @@ def batched_wideresnet28x10( dropout_rate: float = 0.3, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _BatchWideResNet: """BatchEnsemble of Wide-ResNet-28x10. @@ -241,6 +250,10 @@ def batched_wideresnet28x10( groups (int): Number of groups in the convolutions. Defaults to ``1``. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. + activation_fn (Callable, optional): Activation function. Defaults to + ``torch.nn.functional.relu``. + normalization_layer (nn.Module, optional): Normalization layer. + Defaults to ``torch.nn.BatchNorm2d``. Returns: _BatchWideResNet: A BatchEnsemble-style Wide-ResNet-28x10. @@ -255,4 +268,6 @@ def batched_wideresnet28x10( num_estimators=num_estimators, groups=groups, style=style, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/wideresnet/masked.py b/torch_uncertainty/models/wideresnet/masked.py index 3a90be81..230441dd 100644 --- a/torch_uncertainty/models/wideresnet/masked.py +++ b/torch_uncertainty/models/wideresnet/masked.py @@ -23,6 +23,7 @@ def __init__( scale: float, groups: int, activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> None: super().__init__() self.activation_fn = activation_fn @@ -37,7 +38,7 @@ def __init__( bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) - self.bn1 = nn.BatchNorm2d(planes) + self.bn1 = normalization_layer(planes) self.conv2 = MaskedConv2d( planes, planes, @@ -49,7 +50,7 @@ def __init__( groups=groups, bias=conv_bias, ) - self.bn2 = nn.BatchNorm2d(planes) + self.bn2 = normalization_layer(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != planes: @@ -87,6 +88,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__() self.num_estimators = num_estimators @@ -126,7 +128,7 @@ def __init__( else: raise ValueError(f"Unknown WideResNet style: {style}. ") - self.bn1 = nn.BatchNorm2d(num_stages[0]) + self.bn1 = normalization_layer(num_stages[0]) if style == "imagenet": self.optional_pool = nn.MaxPool2d( @@ -146,6 +148,7 @@ def __init__( scale=scale, groups=groups, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) self.layer2 = self._wide_layer( _WideBasicBlock, @@ -158,6 +161,7 @@ def __init__( scale=scale, groups=groups, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) self.layer3 = self._wide_layer( _WideBasicBlock, @@ -170,9 +174,10 @@ def __init__( scale=scale, groups=groups, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -189,9 +194,10 @@ def _wide_layer( dropout_rate: float, stride: int, num_estimators: int, - scale: float = 2.0, - groups: int = 1, - activation_fn: Callable = relu, + scale: float, + groups: int, + activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -208,6 +214,7 @@ def _wide_layer( scale=scale, groups=groups, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) ) self.in_planes = planes @@ -221,7 +228,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer2(out) out = self.layer3(out) out = self.pool(out) - return self.dropout(self.flatten(out)) + return self.final_dropout(self.flatten(out)) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) @@ -236,6 +243,8 @@ def masked_wideresnet28x10( dropout_rate: float = 0.3, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _MaskedWideResNet: """Masksembles of Wide-ResNet-28x10. @@ -251,6 +260,10 @@ def masked_wideresnet28x10( ``1``. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. + activation_fn (Callable, optional): Activation function. Defaults to + ``torch.nn.functional.relu``. + normalization_layer (nn.Module, optional): Normalization layer. + Defaults to ``torch.nn.BatchNorm2d``. Returns: _MaskedWideResNet: A Masksembles-style Wide-ResNet-28x10. @@ -266,4 +279,6 @@ def masked_wideresnet28x10( scale=scale, groups=groups, style=style, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/wideresnet/mimo.py b/torch_uncertainty/models/wideresnet/mimo.py index edb9a588..3e6d9991 100644 --- a/torch_uncertainty/models/wideresnet/mimo.py +++ b/torch_uncertainty/models/wideresnet/mimo.py @@ -3,6 +3,7 @@ import torch from einops import rearrange +from torch import nn from torch.nn.functional import relu from .std import _WideResNet @@ -25,6 +26,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__( depth, @@ -36,6 +38,7 @@ def __init__( groups=groups, style=style, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) self.num_estimators = num_estimators @@ -56,6 +59,8 @@ def mimo_wideresnet28x10( dropout_rate: float = 0.3, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _MIMOWideResNet: return _MIMOWideResNet( depth=28, @@ -67,4 +72,6 @@ def mimo_wideresnet28x10( dropout_rate=dropout_rate, groups=groups, style=style, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/wideresnet/packed.py b/torch_uncertainty/models/wideresnet/packed.py index 60fcc7cf..ae9970d5 100644 --- a/torch_uncertainty/models/wideresnet/packed.py +++ b/torch_uncertainty/models/wideresnet/packed.py @@ -1,7 +1,6 @@ from collections.abc import Callable from typing import Literal -import torch.nn.functional as F from einops import rearrange from torch import Tensor, nn from torch.nn.functional import relu @@ -25,8 +24,11 @@ def __init__( num_estimators: int, gamma: int, groups: int, + activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> None: super().__init__() + self.activation_fn = activation_fn self.conv1 = PackedConv2d( in_planes, planes, @@ -39,7 +41,7 @@ def __init__( bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) - self.bn1 = nn.BatchNorm2d(alpha * planes) + self.bn1 = normalization_layer(alpha * planes) self.conv2 = PackedConv2d( planes, planes, @@ -67,13 +69,13 @@ def __init__( bias=conv_bias, ), ) - self.bn2 = nn.BatchNorm2d(alpha * planes) + self.bn2 = normalization_layer(alpha * planes) def forward(self, x: Tensor) -> Tensor: - out = F.relu(self.bn1(self.dropout(self.conv1(x)))) + out = self.activation_fn(self.bn1(self.dropout(self.conv1(x)))) out = self.conv2(out) out += self.shortcut(x) - return F.relu(self.bn2(out)) + return self.activation_fn(self.bn2(out)) class _PackedWideResNet(nn.Module): @@ -91,6 +93,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__() self.num_estimators = num_estimators @@ -138,7 +141,7 @@ def __init__( else: raise ValueError(f"Unknown WideResNet style: {style}. ") - self.bn1 = nn.BatchNorm2d(num_stages[0] * alpha) + self.bn1 = normalization_layer(num_stages[0] * alpha) if style == "imagenet": self.optional_pool = nn.MaxPool2d( @@ -158,6 +161,8 @@ def __init__( num_estimators=self.num_estimators, gamma=gamma, groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) self.layer2 = self._wide_layer( _WideBasicBlock, @@ -170,6 +175,8 @@ def __init__( num_estimators=self.num_estimators, gamma=gamma, groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) self.layer3 = self._wide_layer( _WideBasicBlock, @@ -182,9 +189,11 @@ def __init__( num_estimators=self.num_estimators, gamma=gamma, groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) @@ -208,6 +217,8 @@ def _wide_layer( num_estimators: int, gamma: int, groups: int, + activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -224,6 +235,8 @@ def _wide_layer( num_estimators=num_estimators, gamma=gamma, groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) ) self.in_planes = planes @@ -239,7 +252,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out, "e (m c) h w -> (m e) c h w", m=self.num_estimators ) out = self.pool(out) - return self.dropout(self.flatten(out)) + return self.final_dropout(self.flatten(out)) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) @@ -255,6 +268,8 @@ def packed_wideresnet28x10( dropout_rate: float = 0.3, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", + activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _PackedWideResNet: """Packed-Ensembles of Wide-ResNet-28x10. @@ -270,6 +285,10 @@ def packed_wideresnet28x10( dropout_rate (float, optional): Dropout rate. Defaults to ``0.3``. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. + activation_fn (Callable, optional): Activation function. Defaults to + ``torch.nn.functional.relu``. + normalization_layer (nn.Module, optional): Normalization layer. + Defaults to ``torch.nn.BatchNorm2d``. Returns: _PackedWideResNet: A Packed-Ensembles Wide-ResNet-28x10. @@ -286,4 +305,6 @@ def packed_wideresnet28x10( gamma=gamma, groups=groups, style=style, + activation_fn=activation_fn, + normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/models/wideresnet/std.py b/torch_uncertainty/models/wideresnet/std.py index 963b4d60..be4ca9b6 100644 --- a/torch_uncertainty/models/wideresnet/std.py +++ b/torch_uncertainty/models/wideresnet/std.py @@ -19,6 +19,7 @@ def __init__( groups: int, conv_bias: bool, activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> None: super().__init__() self.activation_fn = activation_fn @@ -31,7 +32,7 @@ def __init__( bias=conv_bias, ) self.dropout = nn.Dropout2d(p=dropout_rate) - self.bn1 = nn.BatchNorm2d(planes) + self.bn1 = normalization_layer(planes) self.conv2 = nn.Conv2d( planes, planes, @@ -41,7 +42,7 @@ def __init__( groups=groups, bias=conv_bias, ) - self.bn2 = nn.BatchNorm2d(planes) + self.bn2 = normalization_layer(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != planes: @@ -75,6 +76,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__() self.dropout_rate = dropout_rate @@ -114,7 +116,7 @@ def __init__( else: raise ValueError(f"Unknown WideResNet style: {style}. ") - self.bn1 = nn.BatchNorm2d(num_stages[0]) + self.bn1 = normalization_layer(num_stages[0]) if style == "imagenet": self.optional_pool = nn.MaxPool2d( @@ -132,6 +134,7 @@ def __init__( groups=groups, activation_fn=activation_fn, conv_bias=conv_bias, + normalization_layer=normalization_layer, ) self.layer2 = self._wide_layer( WideBasicBlock, @@ -142,6 +145,7 @@ def __init__( groups=groups, activation_fn=activation_fn, conv_bias=conv_bias, + normalization_layer=normalization_layer, ) self.layer3 = self._wide_layer( WideBasicBlock, @@ -152,8 +156,9 @@ def __init__( groups=groups, activation_fn=activation_fn, conv_bias=conv_bias, + normalization_layer=normalization_layer, ) - self.dropout = nn.Dropout(p=dropout_rate) + self.final_dropout = nn.Dropout(p=dropout_rate) self.pool = nn.AdaptiveAvgPool2d(output_size=1) self.flatten = nn.Flatten(1) self.linear = nn.Linear( @@ -171,6 +176,7 @@ def _wide_layer( groups: int, conv_bias: bool, activation_fn: Callable, + normalization_layer: type[nn.Module], ) -> nn.Module: strides = [stride] + [1] * (int(num_blocks) - 1) layers = [] @@ -185,6 +191,7 @@ def _wide_layer( groups=groups, conv_bias=conv_bias, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) ) self.in_planes = planes @@ -197,7 +204,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer2(out) out = self.layer3(out) out = self.pool(out) - return self.dropout(self.flatten(out)) + return self.final_dropout(self.flatten(out)) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) @@ -211,6 +218,7 @@ def wideresnet28x10( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _WideResNet: """Wide-ResNet-28x10 from `Wide Residual Networks `_. @@ -227,6 +235,8 @@ def wideresnet28x10( structure. Defaults to ``True``. activation_fn (Callable, optional): Activation function. Defaults to ``torch.nn.functional.relu``. + normalization_layer (nn.Module, optional): Normalization layer. + Defaults to ``torch.nn.BatchNorm2d``. Returns: _Wide: A Wide-ResNet-28x10. @@ -241,4 +251,5 @@ def wideresnet28x10( groups=groups, style=style, activation_fn=activation_fn, + normalization_layer=normalization_layer, ) diff --git a/torch_uncertainty/post_processing/laplace.py b/torch_uncertainty/post_processing/laplace.py index dac1b0ba..89180b4e 100644 --- a/torch_uncertainty/post_processing/laplace.py +++ b/torch_uncertainty/post_processing/laplace.py @@ -26,6 +26,7 @@ def __init__( "mc", "probit", "bridge", "bridge_norm" ] = "probit", batch_size: int = 256, + optimize_prior_precision: bool = True, ) -> None: """Laplace approximation for uncertainty estimation. @@ -45,6 +46,8 @@ def __init__( See the Laplace library for more details. Defaults to "probit". batch_size (int, optional): batch size for the Laplace approximation. Defaults to 256. + optimize_prior_precision (bool, optional): whether to optimize the prior + precision. Defaults to True. Reference: Daxberger et al. Laplace Redux - Effortless Bayesian Deep Learning. In NeurIPS 2021. @@ -77,6 +80,7 @@ def set_model(self, model: nn.Module) -> None: def fit(self, dataset: Dataset) -> None: dl = DataLoader(dataset, batch_size=self.batch_size) self.la.fit(train_loader=dl) + self.la.optimize_prior_precision(method="marglik") def forward( self, diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 23da6188..c6e92755 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -38,7 +38,7 @@ EPOCH_UPDATE_MODEL, STEP_UPDATE_MODEL, ) -from torch_uncertainty.post_processing import PostProcessing +from torch_uncertainty.post_processing import LaplaceApprox, PostProcessing from torch_uncertainty.transforms import ( Mixup, MixupIO, @@ -449,11 +449,6 @@ def test_step( else: ood_scores = -confs - if self.post_processing is not None: - pp_logits = self.post_processing(inputs) - pp_probs = F.softmax(pp_logits, dim=-1) - self.ts_cls_metrics.update(pp_probs, targets) - if dataloader_idx == 0: # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( @@ -485,6 +480,14 @@ def test_step( if self.id_logit_storage is not None: self.id_logit_storage.append(logits.detach().cpu()) + if self.post_processing is not None: + pp_logits = self.post_processing(inputs) + if not isinstance(self.post_processing, LaplaceApprox): + pp_probs = F.softmax(pp_logits, dim=-1) + else: + pp_probs = pp_logits + self.ts_cls_metrics.update(pp_probs, targets) + elif self.eval_ood and dataloader_idx == 1: self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) self.test_ood_entropy(probs)