diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py
index 46d2698c..bf1fd7ac 100644
--- a/tests/layers/test_packed.py
+++ b/tests/layers/test_packed.py
@@ -6,6 +6,7 @@
     PackedConv1d,
     PackedConv2d,
     PackedConv3d,
+    PackedLayerNorm,
     PackedLinear,
     PackedMultiheadAttention,
 )
@@ -46,30 +47,6 @@ def voxels_input() -> torch.Tensor:
     return torch.rand((5, 6, 3, 3, 3))
 
 
-@pytest.fixture()
-def unbatched_sequence() -> torch.Tensor:
-    return torch.rand((3, 6))  # (L, Hin)
-
-
-@pytest.fixture()
-def batched_sequence() -> torch.Tensor:
-    return torch.rand((2, 3, 6))  # (B, L, Hin)
-
-
-@pytest.fixture()
-def unbatched_sequences() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    return torch.rand((3, 6)), torch.rand((4, 2)), torch.rand((4, 4))  # (L, Eq), (S, Ek), (S, Ev)
-
-
-@pytest.fixture()
-def batched_sequences() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    return (
-        torch.rand((2, 3, 6)),
-        torch.rand((2, 4, 2)),
-        torch.rand((2, 4, 4)),
-    )  # (B, L, Eq), (B, S, Ek), (B, S, Ev)
-
-
 @pytest.fixture()
 def unbatched_qkv() -> torch.Tensor:
     return torch.rand((3, 6))
@@ -336,9 +313,18 @@ def test_conv3_failures(self):
             _ = PackedConv3d(5, 2, kernel_size=1, alpha=1, num_estimators=1, gamma=-1)
 
 
-class TestPackedGroupNorm:
-    """Testing the PackedGroupNorm layer class."""
+class TestPackedLayerNorm:
+    """Testing the PackedLayerNorm layer class."""
 
+    def test_one_estimator_forward(self, batched_qkv: torch.Tensor):
+        packed_layer_norm = PackedLayerNorm(
+            embed_dim=6,
+            num_estimators=1,
+            alpha=1,
+        )
+        out = packed_layer_norm(batched_qkv)
+        assert out.shape == torch.Size([2, 3, 6])
+
 
 class TestPackedMultiheadAttention:
     """Testing the PackedMultiheadAttention layer class."""
@@ -371,6 +357,7 @@ def test_one_estimator_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch
             alpha=1,
             num_estimators=1,
             batch_first=True,
+            bias=False,
         )
         out, _ = layer(
             query=batched_qkv,
@@ -379,7 +366,11 @@ def test_one_estimator_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch
         )
         assert out.shape == torch.Size([2, 3, 6])
 
-    def test_one_estimator_q_kv(self, unbatched_q_kv: torch.Tensor, batched_q_kv: torch.Tensor):
+    def test_one_estimator_q_kv(
+        self,
+        unbatched_q_kv: tuple[torch.Tensor, torch.Tensor],
+        batched_q_kv: tuple[torch.Tensor, torch.Tensor],
+    ):
         layer = PackedMultiheadAttention(
             embed_dim=6,
             num_heads=2,
@@ -387,6 +378,7 @@ def test_one_estimator_q_kv(self, unbatched_q_kv: torch.Tensor, batched_q_kv: to
             num_estimators=1,
             kdim=2,
             vdim=2,
+            add_zero_attn=True,
         )
         out, _ = layer(
             query=unbatched_q_kv[0],
@@ -418,7 +410,11 @@ def test_one_estimator_q_kv(self, unbatched_q_kv: torch.Tensor, batched_q_kv: to
         )
         assert out.shape == torch.Size([2, 3, 6])
 
-    def test_one_estimator_q_k_v(self, unbatched_q_k_v: torch.Tensor, batched_q_k_v: torch.Tensor):
+    def test_one_estimator_q_k_v(
+        self,
+        unbatched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        batched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ):
         layer = PackedMultiheadAttention(
             embed_dim=6,
             num_heads=2,
@@ -426,6 +422,7 @@ def test_one_estimator_q_k_v(self, unbatched_q_k_v: torch.Tensor, batched_q_k_v:
             num_estimators=1,
             kdim=2,
             vdim=4,
+            add_bias_kv=True,
         )
         out, _ = layer(
             query=unbatched_q_k_v[0],
@@ -452,9 +449,26 @@ def test_one_estimator_q_k_v(self, unbatched_q_k_v: torch.Tensor, batched_q_k_v:
             vdim=4,
             batch_first=True,
         )
+
+        layer.eval()
+
+        attn_mask = torch.zeros(3, 4, dtype=torch.bool)
+        key_padding_mask = torch.zeros(2, 4, dtype=torch.bool)
+
         out, _ = layer(
             query=batched_q_k_v[0],
             key=batched_q_k_v[1],
             value=batched_q_k_v[2],
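The new TestPackedLayerNorm above only smoke-tests the single-estimator case. A minimal standalone sketch of the same shape check for an actual ensemble follows; the import path and the num_estimators=2, alpha=2 configuration are assumptions for illustration, not part of this patch.

import torch
from torch_uncertainty.layers.packed import PackedLayerNorm  # assumed import path

# With alpha=2 the packed embedding has int(6 * 2) = 12 channels,
# split into num_estimators=2 groups (one per ensemble member).
packed_layer_norm = PackedLayerNorm(embed_dim=6, num_estimators=2, alpha=2)
out = packed_layer_norm(torch.rand((2, 3, 12)))  # (B, L, embed_dim * alpha)
assert out.shape == torch.Size([2, 3, 12])       # shape is preserved
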
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
         )
         assert out.shape == torch.Size([2, 3, 6])
+        assert out.isfinite().all()
+
+
+class TestPackedTransformerEncoderLayer:
+    """Testing the PackedTransformerEncoderLayer class."""
+
+
+class TestPackedTransformerDecoderLayer:
+    """Testing the PackedTransformerDecoderLayer class."""
diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py
index 626eb8f9..38fe0b3b 100644
--- a/torch_uncertainty/layers/functional/packed.py
+++ b/torch_uncertainty/layers/functional/packed.py
@@ -34,16 +34,19 @@ def packed_linear(
         block_diag = torch.block_diag(*weight)
         return F.linear(inputs, block_diag, bias)
     if implementation == "sparse":
-        return (inputs @ weight.transpose(0, 1)) + bias
+        out = inputs @ weight.transpose(0, 1)
+        if bias is not None:
+            out += bias
+        return out
     if implementation == "einsum":
-        return (
-            torch.einsum(
-                "...ki,kij->...kj",
-                rearrange(inputs, "... (m d) -> ... m d", m=num_groups),
-                weight.transpose(1, 2),
-            ).flatten(start_dim=-2)
-            + bias
-        )
+        out = torch.einsum(
+            "...ki,kij->...kj",
+            rearrange(inputs, "... (m d) -> ... m d", m=num_groups),
+            weight.transpose(1, 2),
+        ).flatten(start_dim=-2)
+        if bias is not None:
+            out += bias
+        return out
     raise ValueError(f"Unknown implementation: {implementation}")
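The sparse and einsum branches of packed_linear used to add bias unconditionally, which fails with a TypeError when the layer is built without a bias (now exercised by the bias=False test above). Below is a self-contained sketch of the grouped einsum matmul with an optional bias, mirroring the fixed branch; the tensor shapes are illustrative assumptions.

import torch
from einops import rearrange

num_groups, d_in, d_out = 2, 3, 4
inputs = torch.rand(5, num_groups * d_in)     # (..., num_groups * d_in)
weight = torch.rand(num_groups, d_out, d_in)  # one (d_out, d_in) block per group
bias = None                                   # the bias may legitimately be absent

out = torch.einsum(
    "...ki,kij->...kj",
    rearrange(inputs, "... (m d) -> ... m d", m=num_groups),
    weight.transpose(1, 2),
).flatten(start_dim=-2)
if bias is not None:  # the fix: only add the bias when it exists
    out = out + bias
assert out.shape == torch.Size([5, num_groups * d_out])
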
h") class PackedMultiheadAttention(nn.Module): @@ -683,6 +714,12 @@ def __init__( else: self.register_parameter("in_proj_bias", None) + if add_bias_kv: + self.bias_k = nn.Parameter(torch.empty((1, 1, self.embed_dim), **factory_kwargs)) + self.bias_v = nn.Parameter(torch.empty((1, 1, self.embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + self.out_proj = PackedLinear( in_features=embed_dim, out_features=embed_dim, @@ -696,6 +733,8 @@ def __init__( **factory_kwargs, ) + self.add_zero_attn = add_zero_attn + self._reset_parameters() def _reset_parameters(self): @@ -712,19 +751,6 @@ def _reset_parameters(self): nn.init.constant_(self.in_proj_bias, 0.0) nn.init.constant_(self.out_proj.bias, 0.0) - def __setstate__(self, state): - """Support loading old MultiheadAttention checkpoints generated by - v1.1.0. - - Args: - state (_type_): _description_ - """ - # - if "_qkv_same_embed_dim" not in state: - state["_qkv_same_embed_dim"] = True - - super().__setstate__(state) - def forward( self, query: Tensor, @@ -779,9 +805,9 @@ def forward( self.num_groups, self.in_proj_weight, self.in_proj_bias, - None, - None, - False, + self.bias_k, + self.bias_v, + self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias, @@ -809,9 +835,9 @@ def forward( self.num_groups, self.in_proj_weight, self.in_proj_bias, - None, - None, - False, + self.bias_k, + self.bias_v, + self.add_zero_attn, self.dropout, self.out_proj.weight, self.out_proj.bias,