From a09fdd9cb0c00fd6d02bffd87f12ebee9bf44ad8 Mon Sep 17 00:00:00 2001 From: alafage Date: Mon, 6 Jan 2025 15:41:41 +0100 Subject: [PATCH] :books: Add documentation for Packed Transformer Layers --- docs/source/api.rst | 4 + torch_uncertainty/layers/__init__.py | 11 +- torch_uncertainty/layers/packed.py | 248 ++++++++++++++++-- .../metrics/classification/mean_iou.py | 37 ++- 4 files changed, 275 insertions(+), 25 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index bbc6b8df..16f1f1c2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -125,6 +125,10 @@ Ensemble layers PackedLinear PackedConv2d + PackedMultiheadAttention + PackedLayerNorm + PackedTransformerEncoderLayer + PackedTransformerDecoderLayer BatchLinear BatchConv2d MaskedLinear diff --git a/torch_uncertainty/layers/__init__.py b/torch_uncertainty/layers/__init__.py index 210e0bea..689943ff 100644 --- a/torch_uncertainty/layers/__init__.py +++ b/torch_uncertainty/layers/__init__.py @@ -4,4 +4,13 @@ from .channel_layer_norm import ChannelLayerNorm from .masksembles import MaskedConv2d, MaskedLinear from .modules import Identity -from .packed import PackedConv1d, PackedConv2d, PackedConv3d, PackedLinear +from .packed import ( + PackedConv1d, + PackedConv2d, + PackedConv3d, + PackedLayerNorm, + PackedLinear, + PackedMultiheadAttention, + PackedTransformerDecoderLayer, + PackedTransformerEncoderLayer, +) diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index a4eacab4..71537b82 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -566,23 +566,6 @@ def bias(self) -> Tensor | None: class PackedLayerNorm(nn.GroupNorm): - """Packed-Ensembles-style LayerNorm layer. - - Args: - embed_dim (int): the number of features in the input tensor. - num_estimators (int): the number of estimators in the ensemble. - alpha (float): the width multiplier of the layer. - eps (float, optional): a value added to the denominator for numerical stability. Defaults - to 1e-5. - affine (bool, optional): a boolean value that when set to ``True``, this module has - learnable per_channel affine parameters initialized to ones (for weights) and zeros - (for biases). Defaults to ``True``. - - Shape: - - Input: :math:`(N, *)` where :math:`*` means any number of additional dimensions. - - Output: :math:`(N, *)` (same shape as input) - """ - def __init__( self, embed_dim: int, @@ -593,6 +576,26 @@ def __init__( device=None, dtype=None, ) -> None: + r"""Packed-Ensembles-style LayerNorm layer. + + Args: + embed_dim (int): the number of features in the input tensor. + num_estimators (int): the number of estimators in the ensemble. + alpha (float): the width multiplier of the layer. + eps (float, optional): a value added to the denominator for numerical stability. Defaults + to 1e-5. + affine (bool, optional): a boolean value that when set to ``True``, this module has + learnable per_channel affine parameters initialized to ones (for weights) and zeros + (for biases). Defaults to ``True``. + device (torch.device, optional): the device to use for the layer's parameters. Defaults + to ``None``. + dtype (torch.dtype, optional): the dtype to use for the layer's parameters. Defaults to + ``None``. + + Shape: + - Input: :math:`(B, *)` where :math:`*` means any number of additional dimensions. 
+        - Output: :math:`(B, *)` (same shape as input)
+        """
         super().__init__(
             num_groups=num_estimators,
             num_channels=int(embed_dim * alpha),
@@ -638,6 +641,42 @@ def __init__(
         device=None,
         dtype=None,
     ) -> None:
+        r"""Packed-Ensembles-style MultiheadAttention layer.
+
+        Args:
+            embed_dim (int): Size of the embedding dimension.
+            num_heads (int): Number of parallel attention heads.
+            alpha (float): The width multiplier of the embedding dimension.
+            num_estimators (int): The number of estimators packed in the layer.
+            gamma (int, optional): Number of groups within each estimator. Defaults to ``1``.
+            dropout (float, optional): Dropout probability on ``attn_output_weights``. Defaults
+                to ``0.0`` (no dropout).
+            bias (bool, optional): If specified, adds bias to input / output projection layers.
+                Defaults to ``True``.
+            add_bias_kv (bool, optional): If specified, adds bias to the key and value sequences at
+                ``dim=0``. Defaults to ``False``.
+            add_zero_attn (bool, optional): If specified, adds a new batch of zeros to the key and
+                value sequences at ``dim=1``. Defaults to ``False``.
+            kdim (int | None, optional): Total number of features for keys. Defaults to ``None``
+                (uses ``kdim=embed_dim``).
+            vdim (int | None, optional): Total number of features for values. Defaults to ``None``
+                (uses ``vdim=embed_dim``).
+            batch_first (bool, optional): If ``True``, then the input and output tensors are provided
+                as (batch, seq, feature). Defaults to ``False`` (seq, batch, feature).
+            first (bool, optional): Whether this is the first layer of the network. Defaults to
+                ``False``.
+            last (bool, optional): Whether this is the last layer of the network. Defaults to
+                ``False``.
+            device (torch.device, optional): The device to use for the layer's parameters. Defaults
+                to ``None``.
+            dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to
+                ``None``.
+
+        Reference:
+            - `Attention Is All You Need `_: Original Multihead Attention formulation.
+            - `Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting `_:
+              Packed-Ensembles-style Multihead Attention formulation.
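+
+        Example:
+            A minimal, illustrative construction. The sizes below are hypothetical and only
+            follow the Shape conventions documented in :meth:`forward`; the packed width of the
+            projections additionally depends on ``alpha`` and the ``first``/``last`` flags:
+
+            .. code-block:: python
+
+                import torch
+
+                from torch_uncertainty.layers import PackedMultiheadAttention
+
+                mha = PackedMultiheadAttention(
+                    embed_dim=64,
+                    num_heads=4,
+                    alpha=2,
+                    num_estimators=4,
+                    batch_first=True,
+                )
+                query = torch.rand(2, 16, 64)  # (batch, seq, embed_dim)
+                attn_output, _ = mha(query, query, query)  # second element is always None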
+        """
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -765,7 +804,61 @@ def forward(
         attn_mask: Tensor | None = None,
         average_attn_weights: bool = True,
         is_causal: bool = False,
-    ) -> tuple[Tensor, Tensor | None]:
+    ) -> tuple[Tensor, None]:
+        r"""Computes attention outputs given query, key, and value tensors.
+
+        Args:
+            query (Tensor): Query embeddings of shape :math:`(L, E_q)` for unbatched input,
+                :math:`(L, B, E_q)` when ``batch_first=False`` or :math:`(B, L, E_q)` when
+                ``batch_first=True``, where :math:`L` is the target sequence length, :math:`B` is
+                the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
+            key (Tensor): Key embeddings of shape :math:`(S, E_k)` for unbatched input,
+                :math:`(S, B, E_k)` when ``batch_first=False`` or :math:`(B, S, E_k)` when
+                ``batch_first=True``, where :math:`S` is the source sequence length, :math:`B` is
+                the batch size and :math:`E_k` is the key embedding dimension ``kdim``.
+            value (Tensor): Value embeddings of shape :math:`(S, E_v)` for unbatched input,
+                :math:`(S, B, E_v)` when ``batch_first=False`` or :math:`(B, S, E_v)` when
+                ``batch_first=True``, where :math:`S` is the source sequence length, :math:`B` is
+                the batch size and :math:`E_v` is the value embedding dimension ``vdim``.
+            key_padding_mask (Tensor | None, optional): If specified, a mask of shape
+                :math:`(B, S)` indicating which elements within ``key`` to ignore for the purpose
+                of attention (i.e. treat as "padding"). For unbatched ``query``, the shape should
+                be :math:`(S)`. Binary and float masks are supported. For a binary mask, a ``True``
+                value indicates that the corresponding ``key`` value will be ignored for the
+                purpose of attention. For a float mask, it will be directly added to the
+                corresponding ``key`` value. Defaults to ``None``.
+            need_weights (bool, optional): If specified, returns ``attn_output_weights`` in
+                addition to ``attn_outputs``. Set ``need_weights=False`` to use the optimized
+                ``scaled_dot_product_attention`` and achieve the best performance for MHA.
+                Defaults to ``False``.
+            attn_mask (Tensor | None, optional): If specified, a 2D or 3D mask preventing attention
+                to certain positions. Must be of shape :math:`(L, S)` or
+                :math:`(B \times \text{num_heads}, L, S)`, where :math:`B` is the batch size,
+                :math:`L` is the target sequence length, and :math:`S` is the source sequence
+                length. A 2D mask will be broadcast across the batch while a 3D mask allows a
+                different mask for each entry in the batch. Binary and float masks are supported.
+                For a binary mask, a ``True`` value indicates that the corresponding position is
+                not allowed to attend. For a float mask, the mask values will be added to the
+                attention weight. If both ``attn_mask`` and ``key_padding_mask`` are provided,
+                their types should match. Defaults to ``None``.
+            average_attn_weights (bool, optional): If ``True``, indicates that the returned
+                ``attn_weights`` should be averaged across heads. Otherwise, ``attn_weights`` are
+                provided separately per head. Note that this flag only has an effect when
+                ``need_weights=True``. Defaults to ``True``.
+            is_causal (bool, optional): If specified, applies a causal mask as attention mask.
+                Defaults to ``False``.
+
+        Warning:
+            ``need_weights=True`` is not supported yet; as a consequence, ``average_attn_weights``
+            has no effect.
+
+        Returns:
+            tuple[Tensor, None]:
+                - *attn_output* (Tensor): The output tensor of shape :math:`(L, E_q)`,
+                  :math:`(L, B, E_q)` or :math:`(B, L, E_q)`, where :math:`L` is the target
+                  sequence length, :math:`B` is the batch size, and :math:`E_q` is the embedding
+                  dimension ``embed_dim``.
+                - *attn_output_weights* (None): Always ``None``, as ``need_weights=True`` is not
+                  supported yet.
+        """
         is_batched = query.dim() == 3

         key_padding_mask = F._canonical_mask(
@@ -879,6 +972,44 @@ def __init__(
         device=None,
         dtype=None,
     ) -> None:
+        r"""Packed-Ensembles-style TransformerEncoderLayer (made up of self-attention followed by
+        a feedforward network).
+
+        Args:
+            d_model (int): the number of expected features in the input.
+            nhead (int): the number of heads in the multi-head attention models.
+            alpha (float): the width multiplier of the layer.
+            num_estimators (int): the number of estimators packed in the layer.
+            gamma (int, optional): the number of groups within each estimator. Defaults to ``1``.
+            dim_feedforward (int, optional): the dimension of the feedforward network model.
+                Defaults to ``2048``.
+            dropout (float, optional): the dropout value. Defaults to ``0.1``.
+            activation (Callable[[Tensor], Tensor], optional): the activation function of the
+                intermediate layer, given as a unary callable. Defaults to ``F.relu``.
+            layer_norm_eps (float, optional): the eps value in layer normalization components.
+                Defaults to ``1e-5``.
+            bias (bool, optional): If ``False``, ``Linear`` and ``LayerNorm`` layers will not learn
+                an additive bias. Defaults to ``True``.
+            batch_first (bool, optional): If ``True``, then the input and output tensors are
+                provided as :math:`(\text{batch}, \text{seq}, \text{d_model})`. Defaults to
+                ``False`` (:math:`(\text{seq}, \text{batch}, \text{d_model})`).
+            norm_first (bool, optional): If ``True``, the layer norm is done prior to the attention
+                and feedforward operations, respectively. Otherwise, it is done after. Defaults to
+                ``False``.
+            first (bool, optional): Whether this is the first layer of the network. Defaults to
+                ``False``.
+            last (bool, optional): Whether this is the last layer of the network. Defaults to
+                ``False``.
+            device (torch.device, optional): The device to use for the layer's parameters. Defaults
+                to ``None``.
+            dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to
+                ``None``.
+
+        Reference:
+            - `Attention Is All You Need `_: Original Multihead Attention formulation.
+            - `Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting `_:
+              Packed-Ensembles-style Multihead Attention formulation.
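+
+        Example:
+            A minimal, illustrative construction with hypothetical sizes. The call mirrors the
+            usage of a standard transformer encoder layer; per the Shape section of
+            :meth:`forward`, the feature dimension of ``src`` is ``d_model``, while the packed
+            width of the intermediate features is governed by ``alpha``, ``num_estimators``, and
+            the ``first``/``last`` flags:
+
+            .. code-block:: python
+
+                import torch
+
+                from torch_uncertainty.layers import PackedTransformerEncoderLayer
+
+                encoder_layer = PackedTransformerEncoderLayer(
+                    d_model=64,
+                    nhead=4,
+                    alpha=2,
+                    num_estimators=4,
+                    dim_feedforward=256,
+                    batch_first=True,
+                )
+                src = torch.rand(2, 16, 64)  # (batch, seq, d_model)
+                out = encoder_layer(src)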
+        """
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -965,6 +1096,22 @@ def forward(
         src_key_padding_mask: Tensor | None = None,
         is_causal: bool = False,
     ) -> Tensor:
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            src (Tensor): The sequence to the encoder layer. Shape: :math:`(B, L, E)` or
+                :math:`(L, B, E)`.
+            src_mask (Tensor | None, optional): The mask for the ``src`` sequence. Defaults to
+                ``None``.
+            src_key_padding_mask (Tensor | None, optional): The mask for the ``src`` keys per
+                batch. Defaults to ``None``.
+            is_causal (bool, optional): If specified, applies a causal mask as ``src_mask``.
+                Defaults to ``False``. Warning: ``is_causal`` provides a hint that ``src_mask`` is
+                a causal mask. Providing incorrect hints can result in incorrect execution,
+                including forward and backward compatibility.
+
+        Returns:
+            Tensor: The output of the encoder layer. Shape: :math:`(B, L, E)` or :math:`(L, B, E)`.
+        """
         src_key_padding_mask = F._canonical_mask(
             mask=src_key_padding_mask,
             mask_name="src_key_padding_mask",
@@ -1045,6 +1192,44 @@ def __init__(
         device=None,
         dtype=None,
     ) -> None:
+        r"""Packed-Ensembles-style TransformerDecoderLayer (made up of self-attention, multi-head
+        attention, and a feedforward network).
+
+        Args:
+            d_model (int): the number of expected features in the input.
+            nhead (int): the number of heads in the multi-head attention models.
+            alpha (float): the width multiplier of the layer.
+            num_estimators (int): the number of estimators packed in the layer.
+            gamma (int, optional): the number of groups within each estimator. Defaults to ``1``.
+            dim_feedforward (int, optional): the dimension of the feedforward network model.
+                Defaults to ``2048``.
+            dropout (float, optional): the dropout value. Defaults to ``0.1``.
+            activation (Callable[[Tensor], Tensor], optional): the activation function of the
+                intermediate layer, given as a unary callable. Defaults to ``F.relu``.
+            layer_norm_eps (float, optional): the eps value in layer normalization components.
+                Defaults to ``1e-5``.
+            bias (bool, optional): If ``False``, ``Linear`` and ``LayerNorm`` layers will not learn
+                an additive bias. Defaults to ``True``.
+            batch_first (bool, optional): If ``True``, then the input and output tensors are
+                provided as :math:`(\text{batch}, \text{seq}, \text{d_model})`. Defaults to
+                ``False`` (:math:`(\text{seq}, \text{batch}, \text{d_model})`).
+            norm_first (bool, optional): If ``True``, the layer norm is done prior to the attention
+                and feedforward operations, respectively. Otherwise, it is done after. Defaults to
+                ``False``.
+            first (bool, optional): Whether this is the first layer of the network. Defaults to
+                ``False``.
+            last (bool, optional): Whether this is the last layer of the network. Defaults to
+                ``False``.
+            device (torch.device, optional): The device to use for the layer's parameters. Defaults
+                to ``None``.
+            dtype (torch.dtype, optional): The dtype to use for the layer's parameters. Defaults to
+                ``None``.
+
+        Reference:
+            - `Attention Is All You Need `_: Original Multihead Attention formulation.
+            - `Hierarchical Light Transformer Ensembles for Multimodal Trajectory Forecasting `_:
+              Packed-Ensembles-style Multihead Attention formulation.
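+
+        Example:
+            A minimal, illustrative construction with hypothetical sizes; ``memory`` would
+            typically come from a (packed) encoder stack, and the shapes follow the conventions
+            documented in :meth:`forward`:
+
+            .. code-block:: python
+
+                import torch
+
+                from torch_uncertainty.layers import PackedTransformerDecoderLayer
+
+                decoder_layer = PackedTransformerDecoderLayer(
+                    d_model=64,
+                    nhead=4,
+                    alpha=2,
+                    num_estimators=4,
+                    dim_feedforward=256,
+                    batch_first=True,
+                )
+                tgt = torch.rand(2, 10, 64)  # (batch, target seq., d_model)
+                memory = torch.rand(2, 16, 64)  # (batch, source seq., d_model)
+                out = decoder_layer(tgt, memory)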
+        """
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -1156,6 +1341,33 @@ def forward(
         tgt_is_causal: bool = False,
         memory_is_causal: bool = False,
     ) -> Tensor:
+        r"""Pass the input (and mask) through the decoder layer.
+
+        Args:
+            tgt (Tensor): The sequence to the decoder layer. Shape: :math:`(B, L, E)` or
+                :math:`(L, B, E)`.
+            memory (Tensor): The sequence from the last layer of the encoder. Shape:
+                :math:`(B, S, E)` or :math:`(S, B, E)`.
+            tgt_mask (Tensor | None, optional): The mask for the ``tgt`` sequence. Defaults to
+                ``None``.
+            memory_mask (Tensor | None, optional): The mask for the ``memory`` sequence. Defaults
+                to ``None``.
+            tgt_key_padding_mask (Tensor | None, optional): The mask for the ``tgt`` keys per
+                batch. Defaults to ``None``.
+            memory_key_padding_mask (Tensor | None, optional): The mask for the ``memory`` keys per
+                batch. Defaults to ``None``.
+            tgt_is_causal (bool, optional): If specified, applies a causal mask as ``tgt_mask``.
+                Defaults to ``False``. Warning: ``tgt_is_causal`` provides a hint that ``tgt_mask``
+                is a causal mask. Providing incorrect hints can result in incorrect execution,
+                including forward and backward compatibility.
+            memory_is_causal (bool, optional): If specified, applies a causal mask as
+                ``memory_mask``. Defaults to ``False``. Warning: ``memory_is_causal`` provides a
+                hint that ``memory_mask`` is a causal mask. Providing incorrect hints can result in
+                incorrect execution, including forward and backward compatibility.
+
+        Returns:
+            Tensor: The output of the decoder layer. Shape: :math:`(B, L, E)` or :math:`(L, B, E)`.
+ """ x = tgt if self.norm_first: x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal) diff --git a/torch_uncertainty/metrics/classification/mean_iou.py b/torch_uncertainty/metrics/classification/mean_iou.py index 69362f36..5d9ef56d 100644 --- a/torch_uncertainty/metrics/classification/mean_iou.py +++ b/torch_uncertainty/metrics/classification/mean_iou.py @@ -1,13 +1,9 @@ -from typing import Literal - from torch import Tensor from torchmetrics.classification.stat_scores import MulticlassStatScores from torchmetrics.utilities.compute import _safe_divide class MeanIntersectionOverUnion(MulticlassStatScores): - """Compute the MeanIntersection over Union (IoU) score.""" - is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False @@ -16,16 +12,44 @@ def __init__( self, num_classes: int, top_k: int = 1, - multidim_average: Literal["global", "samplewise"] = "global", ignore_index: int | None = None, validate_args: bool = True, **kwargs, ) -> None: + r"""Computes Mean Intersection over Union (IoU) score. + + Args: + num_classes (int): Integer specifying the number of classes. + top_k (int, optional): Number of highest probability or logit score predictions + considered to find the correct label. Only works when ``preds`` contain + probabilities/logits. Defaults to ``1``. + ignore_index (int | None, optional): Specifies a target value that is ignored and does + not contribute to the metric calculation. Defaults to ``None``. + validate_args (bool, optional): Bool indicating if input arguments and tensors should + be validated for correctness. Set to ``False`` for faster computations. Defaults to + ``True``. + **kwargs: kwargs: Additional keyword arguments, see + `Advanced metric settings `_ + for more info. + + Shape: + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, ...)`` or float tensor of shape ``(B, C, ..)``. + If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert + probabilities/logits into an int tensor. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(B, ...)``. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``mean_iou`` (:class:`~torch.Tensor`): The computed Mean Intersection over Union (IoU) score. + A tensor containing a single float value. + """ super().__init__( num_classes, top_k, "macro", - multidim_average, + "global", ignore_index, validate_args, **kwargs, @@ -34,4 +58,5 @@ def __init__( def compute(self) -> Tensor: """Compute the Means Intersection over Union (MIoU) based on saved inputs.""" tp, fp, _, fn = self._final_state() + return _safe_divide(tp, tp + fp + fn, zero_division=float("nan")).nanmean()