Add AttentionScorer abstraction #349

Merged · 2 commits · Oct 5, 2023
2 changes: 2 additions & 0 deletions curated_transformers/layers/__init__.py
@@ -4,6 +4,7 @@
AttentionHeads,
AttentionLinearBiases,
AttentionMask,
AttentionScorer,
QkvMode,
QkvSplit,
ScaledDotProductAttention,
@@ -32,6 +33,7 @@
"AttentionHeads",
"AttentionLinearBiases",
"AttentionMask",
"AttentionScorer",
"CacheProtocol",
"DecoderLayer",
"EmbeddingDropouts",
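
With this re-export, the new base class is reachable from the public `layers` namespace next to the existing attention primitives. A minimal import sketch, limited to names visible in this diff:

```python
# Names re-exported by curated_transformers/layers/__init__.py after this change.
from curated_transformers.layers import (
    AttentionHeads,
    AttentionMask,
    AttentionScorer,
    ScaledDotProductAttention,
)
```
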
161 changes: 85 additions & 76 deletions curated_transformers/layers/attention.py
@@ -9,7 +9,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import Dropout, Linear, Module

from ..semver import Default, FutureMandatory
from ..util.dataclass import DataclassAsDict
@@ -633,34 +633,12 @@ def forward(self, *, attention_scores: Tensor, inplace: bool = True) -> Tensor:
return attention_scores + biases


class ScaledDotProductAttention(Module):
class AttentionScorer(Module, ABC):
"""
Scaled dot-product attention (`Vaswani et al., 2017`_).

.. _Vaswani et al., 2017: https://arxiv.org/abs/1706.03762
Base class of attention scoring implementations.
"""

linear_biases: Optional[AttentionLinearBiases]

def __init__(
self, *, dropout_prob: float, linear_biases: Optional[AttentionLinearBiases]
):
"""
Construct a scaled dot-product attention module.

:param dropout_prob:
Dropout to apply to the final hidden representation.
:param linear_biases:
ALiBi (`Press et al., 2022`_) for attention scores.
Not applied if ``None``.

.. _Press et al., 2022: https://arxiv.org/abs/2108.12409
"""
super().__init__()

self.dropout = torch.nn.Dropout(p=dropout_prob)
self.linear_biases = linear_biases

@abstractmethod
def forward(
self,
*,
@@ -670,9 +648,9 @@ def forward(
attention_mask: AttentionMask,
) -> Tensor:
"""
Apply attention layer to the given key, query and value.
Apply attention scores to the given key, query and value.

Sequence elements that are marked with `False` in the attention mask
Sequence elements that are marked with ``False`` in the attention mask
are ignored by the attention mechanism (if a mask is provided).

:param query:
@@ -696,18 +674,78 @@

*Shape:* ``(batch_size, heads, seq_len, width)``
"""
model_width = key.shape[-1]
attn_scores = query @ key.transpose(-2, -1)
attn_scores /= math.sqrt(model_width)
...


class ScaledDotProductAttention(AttentionScorer):
"""
Scaled dot-product attention (`Vaswani et al., 2017`_).

.. _Vaswani et al., 2017: https://arxiv.org/abs/1706.03762
"""

linear_biases: Optional[AttentionLinearBiases]

def __init__(
self, *, dropout_prob: float, linear_biases: Optional[AttentionLinearBiases]
):
"""
Construct a scaled dot-product attention module.

:param dropout_prob:
Dropout to apply to the final hidden representation.
:param linear_biases:
ALiBi (`Press et al., 2022`_) for attention scores.
Not applied if ``None``.

.. _Press et al., 2022: https://arxiv.org/abs/2108.12409
"""
super().__init__()

if self.linear_biases is not None:
attn_scores = self.linear_biases(attention_scores=attn_scores)
self.dropout = Dropout(p=dropout_prob)
self.linear_biases = linear_biases

def forward(
self,
*,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: AttentionMask,
) -> Tensor:
if _TORCH_SDP.get():
attn_mask = attention_mask.logit_mask(query.dtype)

# Add AliBi to the logit mask
if self.linear_biases is not None:
biases = self.linear_biases.calculate_biases(key.size(-2)).to(
dtype=query.dtype, device=query.device
)
bool_mask = attention_mask.bool_mask
attn_mask = torch.where(bool_mask, biases, attn_mask)

# We can't pass a bool mask, because it is currently broken:
# https://github.com/pytorch/pytorch/issues/103749
return F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
)
else:
width = key.shape[-1]
attn_scores = query @ key.transpose(-2, -1)
attn_scores /= math.sqrt(width)

if self.linear_biases is not None:
attn_scores = self.linear_biases(attention_scores=attn_scores)

attn_scores = attention_mask.apply_logit_mask(attn_scores)
attn_weights = attn_scores.softmax(dim=-1)
attn_values = self.dropout(attn_weights @ value)
attn_scores = attention_mask.apply_logit_mask(attn_scores)
attn_weights = attn_scores.softmax(dim=-1)
attn_values = self.dropout(attn_weights @ value)

return attn_values
return attn_values


class SelfAttention(Module):
@@ -723,11 +761,10 @@ def __init__(
self,
*,
attention_heads: AttentionHeads,
dropout_prob: float,
attention_scorer: AttentionScorer,
hidden_width: int,
qkv_mode: QkvMode,
rotary_embeds: Optional[QueryKeyRotaryEmbeddings] = None,
attention_biases: Optional[AttentionLinearBiases] = None,
use_bias: bool,
device: Optional[torch.device] = None,
):
@@ -737,12 +774,10 @@

:param attention_heads:
Attention head configuration.
:param dropout_prob:
Dropout to apply between the self-attention and output layers.
:param attention_scorer:
Attention scorer used to calculate the attention values.
:param hidden_width:
Hidden width of the layer.
:param attention_biases:
ALiBi biases. ALiBi will not be used when set to ``None``.
:param qkv_mode:
Handling mode for query, key and value.
:param rotary_embeds:
@@ -756,7 +791,6 @@

super().__init__()

self.dropout_prob = dropout_prob
self.attention_heads = attention_heads
if hidden_width % attention_heads._n_query_heads != 0:
raise ValueError(
@@ -766,13 +800,10 @@

self.head_width = hidden_width // attention_heads._n_query_heads
self.qkv_mode = qkv_mode
self.use_alibi = attention_biases is not None

self.rotary_embeds = rotary_embeds

self.attention = ScaledDotProductAttention(
dropout_prob=dropout_prob, linear_biases=attention_biases
)
self.attention_scorer = attention_scorer

if (
qkv_mode == QkvMode.MERGED_SPLIT_BEFORE
@@ -877,34 +908,12 @@ def forward(
causal_mask = create_causal_mask(query, key)
combined_mask = combined_mask.merge_mask(causal_mask)

if _TORCH_SDP.get():
attn_mask = combined_mask.logit_mask(query.dtype)

# Add AliBi to the logit mask
if self.use_alibi:
assert self.attention.linear_biases is not None
biases = self.attention.linear_biases.calculate_biases(key.size(-2)).to(
dtype=query.dtype, device=query.device
)
bool_mask = combined_mask.bool_mask
attn_mask = torch.where(bool_mask, biases, attn_mask)

# We can't pass a bool mask, because it is currently broken:
# https://github.com/pytorch/pytorch/issues/103749
attn = F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
)
else:
attn = self.attention(
query=query,
key=key,
value=value,
attention_mask=combined_mask,
)
attn = self.attention_scorer(
query=query,
key=key,
value=value,
attention_mask=combined_mask,
)

attn = combine_heads(attn)

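
The net effect of this file is that `SelfAttention` no longer hard-codes scaled dot-product scoring: it delegates to whatever `AttentionScorer` it is given. As a rough sketch of what the abstraction permits (not part of this PR; the class name and the temperature parameter are made up for illustration), a custom scorer only has to implement the keyword-only `forward` shown above:

```python
from torch import Tensor

from curated_transformers.layers import AttentionMask, AttentionScorer


class TemperatureDotProductAttention(AttentionScorer):
    """
    Hypothetical scorer: dot-product attention with a fixed temperature
    instead of the usual 1/sqrt(width) scaling.
    """

    def __init__(self, *, temperature: float):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        *,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: AttentionMask,
    ) -> Tensor:
        # Same shape convention as the base class:
        # (batch_size, heads, seq_len, width) in and out.
        attn_scores = (query @ key.transpose(-2, -1)) / self.temperature
        attn_scores = attention_mask.apply_logit_mask(attn_scores)
        attn_weights = attn_scores.softmax(dim=-1)
        return attn_weights @ value
```

Such a scorer would then be handed to `SelfAttention` through the new `attention_scorer` argument, exactly as the built-in `ScaledDotProductAttention` is in the model files below.
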
13 changes: 11 additions & 2 deletions curated_transformers/models/albert/layer_group.py
@@ -5,7 +5,13 @@
from torch import Tensor
from torch.nn import LayerNorm, Module, ModuleList

from ...layers.attention import AttentionHeads, AttentionMask, QkvMode, SelfAttention
from ...layers.attention import (
AttentionHeads,
AttentionMask,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.feedforward import PointwiseFeedForward
from ...layers.transformer import (
EncoderLayer,
@@ -41,7 +47,10 @@ def __init__(
attention_heads=AttentionHeads.uniform(
attention_config.n_query_heads
),
dropout_prob=attention_config.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=attention_config.dropout_prob,
linear_biases=None,
),
hidden_width=layer_config.feedforward.hidden_width,
qkv_mode=QkvMode.SEPARATE,
rotary_embeds=None,
12 changes: 10 additions & 2 deletions curated_transformers/models/bert/encoder.py
@@ -5,7 +5,12 @@
from torch import Tensor
from torch.nn import Dropout, LayerNorm

from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
from ...layers.attention import (
AttentionHeads,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.feedforward import PointwiseFeedForward
from ...layers.transformer import (
EmbeddingDropouts,
@@ -77,7 +82,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
attention_heads=AttentionHeads.uniform(
config.layer.attention.n_query_heads
),
dropout_prob=config.layer.attention.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
linear_biases=None,
),
hidden_width=config.layer.feedforward.hidden_width,
qkv_mode=QkvMode.SEPARATE,
rotary_embeds=None,
9 changes: 6 additions & 3 deletions curated_transformers/models/falcon/decoder.py
@@ -3,12 +3,13 @@

import torch
from torch import Tensor
from torch.nn import Dropout, Embedding, LayerNorm, ModuleList
from torch.nn import Dropout, LayerNorm, ModuleList

from ...layers.attention import (
AttentionHeads,
AttentionLinearBiases,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.embeddings import QueryKeyRotaryEmbeddings
@@ -150,12 +151,14 @@ def _create_new_decoder_architecture_layer(
)
return DecoderLayer(
attention_layer=SelfAttention(
attention_biases=attention_biases,
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=n_attention_heads,
n_key_value_heads=config.layer.attention.n_key_value_heads,
),
dropout_prob=config.layer.attention.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
linear_biases=attention_biases,
),
hidden_width=hidden_width,
qkv_mode=QkvMode.MERGED_SPLIT_AFTER,
rotary_embeds=rotary_embeds,
7 changes: 5 additions & 2 deletions curated_transformers/models/falcon/layer.py
@@ -10,6 +10,7 @@
AttentionMask,
KeyValueCache,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.embeddings import QueryKeyRotaryEmbeddings
@@ -62,8 +63,10 @@ def __init__(
else None
)
self.mha = SelfAttention(
attention_biases=attention_biases,
dropout_prob=attention_config.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=attention_config.dropout_prob,
linear_biases=attention_biases,
),
hidden_width=hidden_width,
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=attention_config.n_query_heads,
12 changes: 10 additions & 2 deletions curated_transformers/models/gpt_neox/decoder.py
@@ -5,7 +5,12 @@
from torch import Tensor
from torch.nn import Dropout, LayerNorm, ModuleList

from ...layers.attention import AttentionHeads, QkvMode, SelfAttention
from ...layers.attention import (
AttentionHeads,
QkvMode,
ScaledDotProductAttention,
SelfAttention,
)
from ...layers.embeddings import QueryKeyRotaryEmbeddings
from ...layers.feedforward import PointwiseFeedForward
from ...layers.transformer import (
@@ -78,7 +83,10 @@ def __init__(
DecoderLayer(
attention_layer=SelfAttention(
attention_heads=AttentionHeads.uniform(n_attention_heads),
dropout_prob=config.layer.attention.dropout_prob,
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
linear_biases=None,
),
hidden_width=hidden_width,
qkv_mode=QkvMode.MERGED_SPLIT_BEFORE,
rotary_embeds=QueryKeyRotaryEmbeddings(
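
Taken together, the model changes above all follow one migration pattern: `SelfAttention` call sites drop `dropout_prob` (and, in Falcon, `attention_biases`) and pass an explicit scorer instead. A minimal sketch of the new call shape, assuming the constructor keywords exactly as in the diff; the concrete values (12 heads, width 768, dropout 0.1) are placeholders:

```python
from curated_transformers.layers.attention import (
    AttentionHeads,
    QkvMode,
    ScaledDotProductAttention,
    SelfAttention,
)

# Before this PR (old keyword arguments, shown for contrast):
#     SelfAttention(
#         attention_heads=AttentionHeads.uniform(12),
#         dropout_prob=0.1,
#         hidden_width=768,
#         qkv_mode=QkvMode.SEPARATE,
#         use_bias=True,
#     )

# After this PR: dropout and optional ALiBi biases live in the scorer.
attention = SelfAttention(
    attention_heads=AttentionHeads.uniform(12),
    attention_scorer=ScaledDotProductAttention(
        dropout_prob=0.1,
        linear_biases=None,  # pass an AttentionLinearBiases instance to enable ALiBi
    ),
    hidden_width=768,
    qkv_mode=QkvMode.SEPARATE,
    use_bias=True,
)
```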