diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 78834193..7a87e39f 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -32,11 +32,11 @@ jobs:
 
       - name: Install requirements
         run: |
-          pip install --upgrade pip setuptools wheel build
-          pip install -r requirements.txt
+          python -m pip install --upgrade pip setuptools wheel build
+          python -m pip install -r requirements.txt
 
       - name: Install optional dependencies for Japanese models
-        run: pip install fugashi[unidic-lite] spacy[ja]
+        run: python -m pip install fugashi[unidic-lite] spacy[ja]
 
       - name: Build sdist
         run: python -m build --sdist
@@ -52,20 +52,20 @@ jobs:
 
       - name: Uninstall all packages
         run: |
-          pip freeze
-          pip freeze --exclude pywin32 > installed.txt
-          pip uninstall -y -r installed.txt
+          python -m pip freeze
+          python -m pip freeze --exclude pywin32 > installed.txt
+          python -m pip uninstall -y -r installed.txt
 
       - name: Install from sdist
         run: |
           SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
-          pip install dist/$SDIST
+          python -m pip install dist/$SDIST
         shell: bash
 
       - name: Install test dependencies
         run: |
-          pip install -r requirements.txt
-          pip install fugashi[unidic-lite] spacy[ja]
+          python -m pip install -r requirements.txt
+          python -m pip install fugashi[unidic-lite] spacy[ja]
 
       - name: Run pytest
         run: python -m pytest --pyargs curated_transformers --slow
@@ -73,8 +73,8 @@
       - name: Install HF transformers
         id: init-hf-transformers
         run: |
-          pip install '${{ env.hf-transformers-pip }}'
-          echo "version=$(pip show transformers | grep -i version:)" >> $GITHUB_OUTPUT
+          python -m pip install '${{ env.hf-transformers-pip }}'
+          echo "version=$(python -m pip show transformers | grep -i version:)" >> $GITHUB_OUTPUT
         shell: bash
 
       # Disabled for Windows as the models appear to
diff --git a/DEVELOP.md b/DEVELOP.md
new file mode 100644
index 00000000..dc72c3e5
--- /dev/null
+++ b/DEVELOP.md
@@ -0,0 +1,118 @@
+# Development
+
+## TorchScript
+
+Every `torch.nn.Module` in this project must have a TorchScript conversion test.
+TorchScript only supports a subset of Python and we want to make sure that all
+models are convertible to TorchScript.
+
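+A conversion test builds a model, extracts the wrapped PyTorch encoder from the
+Thinc shim, and scripts it. The following is a condensed sketch of the BERT
+case; the full parametrized tests live in
+`curated_transformers/tests/models/pytorch/test_torchscript.py`:
+
+```python
+import torch
+
+from curated_transformers.models.architectures import build_bert_transformer_model_v1
+from curated_transformers.models.with_strided_spans import build_with_strided_spans_v1
+from curated_transformers.tokenization.wordpiece_encoder import build_wordpiece_encoder_v1
+
+
+def test_bert_encoder_torchscript():
+    # Use a small vocab to limit memory use.
+    model = build_bert_transformer_model_v1(
+        piece_encoder=build_wordpiece_encoder_v1(),
+        with_spans=build_with_strided_spans_v1(),
+        vocab_size=128,
+    )
+    model.initialize()
+    # Scripting fails if the module uses unsupported Python features.
+    encoder = model.get_ref("transformer").shims[0]._model
+    torch.jit.script(encoder)
+```
+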
+In this section we will give some rules of thumb to avoid conversion errors.
+
+### Do not use global state
+
+TorchScript cannot use global state. One form of global state that we have
+in this project is the `Errors` class. Consequently, we cannot use `Errors`
+in `Module`s. The following is therefore invalid:
+
+```python
+class Foo(nn.Module):
+    def forward(self, X: Tensor) -> Tensor:
+        # Problem: Errors fields are global state.
+        raise ValueError(Errors.E042)
+```
+
+In these cases we have to use an inline string instead:
+
+```python
+class Foo(nn.Module):
+    def forward(self, X: Tensor) -> Tensor:
+        raise ValueError("This module does not do anything yet.")
+```
+
+For the same reason, we also cannot rely on `has_*` bools in a module:
+
+```python
+class Foo(nn.Module):
+    def forward(self, X: Tensor) -> Tensor:
+        # Problem: conditional on global state.
+        if has_torch_feature:
+            ...
+```
+
+### Typing limitations
+
+TorchScript only supports a small [subset of Python types](https://pytorch.org/docs/stable/jit_language_reference.html#supported-type).
+This also applies to type annotations. For instance, the following will not work, because
+TorchScript only supports fully-specified tuple types:
+
+```python
+class Foo(nn.Module):
+    # Problem: underspecified tuple
+    def shape(self) -> Tuple:
+        ...
+
+    # Problem: underspecified tuple
+    def shape(self) -> Tuple[int, ...]:
+        ...
+```
+
+The following is OK, because it is a valid TorchScript type:
+
+```python
+class Foo(nn.Module):
+    def shape(self) -> Tuple[int, int]:
+        ...
+```
+
+### Do not use `**kwargs` arguments
+
+TorchScript does not support `**kwargs` wildcards. So the following is
+invalid:
+
+```python
+class Foo(nn.Module):
+    ...
+
+    def forward(self, X: Tensor, **kwargs) -> Tensor:
+        hidden = self.inner1(X)
+        return self.inner2(hidden, **kwargs)
+
+```
+
+Instead, we have to spell out all arguments, e.g.:
+
+```python
+class Foo(nn.Module):
+    ...
+
+    def forward(self, X: Tensor, attn_mask: AttentionMask) -> Tensor:
+        hidden = self.inner1(X)
+        return self.inner2(hidden, attn_mask=attn_mask)
+
+```
diff --git a/curated_transformers/_compat.py b/curated_transformers/_compat.py
index d8bb3493..5a172f98 100644
--- a/curated_transformers/_compat.py
+++ b/curated_transformers/_compat.py
@@ -1,10 +1,4 @@
-from os import environ
 from spacy.lang.ja import try_sudachi_import
-import torch
-
-has_torch_sdp_attention = hasattr(
-    torch.nn.functional, "scaled_dot_product_attention"
-) and environ.get("CURATED_TORCH_SDPA")
 
 try:
     import transformers
diff --git a/curated_transformers/errors.py b/curated_transformers/errors.py
index e71e8a3f..f0a2ea67 100644
--- a/curated_transformers/errors.py
+++ b/curated_transformers/errors.py
@@ -26,9 +26,6 @@ class Errors(metaclass=ErrorsWithCodes):
             "divisible by the number of self-attention heads ({num_heads})")
     E004 = ("The point-wise feed-forward network in the transformer only "
             "supports the following activation functions: {activation_funcs}")
-    E005 = ("Expected the attention mask to be of dtype 'torch.bool' but "
-            "found it be '{dtype}' instead")
-    E006 = ("The attention mask must be a 2D-tensor of shape [batch, seq_len]")
     E007 = ("Attempting to load the weights of an unsupported Hugging "
             "Face `transformers` model ({unsupported_model}). Currently "
             "supported models: {supported_models}")
diff --git a/curated_transformers/models/pytorch/__init__.py b/curated_transformers/models/pytorch/__init__.py
index 868c94a8..92c14829 100644
--- a/curated_transformers/models/pytorch/__init__.py
+++ b/curated_transformers/models/pytorch/__init__.py
@@ -3,4 +3,3 @@
 from .albert.encoder import AlbertEncoder
 from .bert.encoder import BertEncoder
 from .roberta.encoder import RobertaEncoder
-from .linear import Linear
diff --git a/curated_transformers/models/pytorch/albert/encoder.py b/curated_transformers/models/pytorch/albert/encoder.py
index da040f9d..649163be 100644
--- a/curated_transformers/models/pytorch/albert/encoder.py
+++ b/curated_transformers/models/pytorch/albert/encoder.py
@@ -63,11 +63,10 @@ def forward(
         layers_per_group = self.num_hidden_layers // len(self.groups)
 
         layer_outputs = []
-        for i in range(self.num_hidden_layers):
-            layer_output = self.groups[i // layers_per_group](
-                layer_output, attn_mask=attention_mask
-            )
-            layer_outputs.append(layer_output)
+        for group in self.groups:
+            for _ in range(layers_per_group):
+                layer_output = group(layer_output, attn_mask=attention_mask)
+                layer_outputs.append(layer_output)
 
         return PyTorchTransformerOutput(
             embedding_output=embeddings, layer_hidden_states=layer_outputs
diff --git a/curated_transformers/models/pytorch/albert/layer_group.py b/curated_transformers/models/pytorch/albert/layer_group.py
index e25bf355..396fd505 100644
--- a/curated_transformers/models/pytorch/albert/layer_group.py
+++ b/curated_transformers/models/pytorch/albert/layer_group.py
@@ -1,6 +1,7 @@
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
+from ..attention import AttentionMask
 from ..bert.config import BertAttentionConfig
 from ..bert.layer import BertAttentionConfig, BertEncoderLayer
 from .config import AlbertLayerConfig
@@ -19,12 +20,12 @@ def __init__(
             ]
         )
 
-    def forward(self, input: Tensor, **kwargs) -> Tensor:
+    def forward(self, input: Tensor, attn_mask: AttentionMask) -> Tensor:
         """
         Shapes:
             input - (batch, seq_len, width)
         """
         layer_output = input
         for layer in self.group_layers:
-            layer_output = layer(layer_output, **kwargs)
+            layer_output = layer(layer_output, attn_mask)
         return layer_output
diff --git a/curated_transformers/models/pytorch/attention.py b/curated_transformers/models/pytorch/attention.py
index 83ac3ac1..fad5e02b 100644
--- a/curated_transformers/models/pytorch/attention.py
+++ b/curated_transformers/models/pytorch/attention.py
@@ -1,5 +1,4 @@
-from typing import Optional, Tuple
-from dataclasses import dataclass
+from typing import Optional
 import math
 import torch
 from torch import Tensor
@@ -8,14 +7,15 @@
 from ...errors import Errors
 
 
-@dataclass
 class AttentionMask:
     bool_mask: Tensor
-    _logit_mask: Optional[Tensor] = None
+    _logit_mask: Optional[Tensor]
 
-    def __post_init__(self):
-        if self.bool_mask.dtype != torch.bool:
-            raise ValueError(Errors.E005.format(dtype=self.bool_mask.dtype))
+    def __init__(self, bool_mask: Tensor):
+        if bool_mask.dtype != torch.bool:
+            raise ValueError("Expected the attention mask to be of dtype 'torch.bool'")
+        self.bool_mask = bool_mask
+        self._logit_mask = torch.jit.annotate(Optional[Tensor], None)
 
     @property
     def logit_mask(self) -> Tensor:
@@ -23,13 +23,17 @@ def logit_mask(self) -> Tensor:
             # The value is `torch.finfo(attn_scores.dype).min`. Unfortunately,
             # we cannot use `torch.finfo` in TorchScript.
             self._logit_mask = (1.0 - self.bool_mask.int()) * -3.4028234663852886e38
-        return self._logit_mask
+
+        # Narrow type for TorchScript.
+        logit_mask = self._logit_mask
+        assert logit_mask is not None
+        return logit_mask
 
     def dim(self) -> int:
         return self.bool_mask.dim()
 
     @property
-    def shape(self) -> Tuple:
+    def shape(self):
         return self.bool_mask.shape
 
 
@@ -49,7 +53,9 @@ def forward(
         """
 
         if attn_mask.dim() != 2:
-            raise ValueError(Errors.E006)
+            raise ValueError(
+                "The attention mask must be a 2D-tensor of shape [batch, seq_len]"
+            )
 
         model_dim = k.shape[-1]
         attn_scores = q @ k.transpose(-2, -1)
diff --git a/curated_transformers/models/pytorch/bert/layer.py b/curated_transformers/models/pytorch/bert/layer.py
index f807623a..edadb98d 100644
--- a/curated_transformers/models/pytorch/bert/layer.py
+++ b/curated_transformers/models/pytorch/bert/layer.py
@@ -1,12 +1,10 @@
 import torch
-from torch.nn import Module
+from torch.nn import Linear, Module
 from torch import Tensor
 
 from .. import GeluNew
 from ..attention import AttentionMask, ScaledDotProductAttention
 from .config import BertAttentionConfig, BertLayerConfig
-from ..linear import Linear
-from ...._compat import has_torch_sdp_attention
 from ....errors import Errors
 
 
@@ -15,7 +13,6 @@ class BertSelfAttention(Module):
     def __init__(self, config: BertAttentionConfig):
         super().__init__()
 
-        self.dropout_prob = config.dropout_prob
         self.model_dim = config.hidden_width
         self.num_heads = config.num_attention_heads
         if self.model_dim % self.num_heads != 0:
@@ -58,6 +55,7 @@ def forward(self, x: Tensor, attn_mask: AttentionMask) -> Tensor:
             x - (batch, seq_len, width)
             attn_mask - (batch, seq_len)
         """
+
         proj = self.input(x)
         q, k, v = proj.chunk(3, dim=-1)
 
@@ -66,21 +64,8 @@ def forward(self, x: Tensor, attn_mask: AttentionMask) -> Tensor:
         q = self._split_heads(q)
         v = self._split_heads(v)
 
-        # attn: (batch, head, seq_len, with_per_head)
-        if has_torch_sdp_attention:
-            batch, seq_len = attn_mask.shape
-            mask = attn_mask.bool_mask.view(batch, 1, 1, seq_len)
-            # Ignore because we still support torch<2.0.0 and older versions
-            # do not have this attribute.
-            attn = torch.nn.functional.scaled_dot_product_attention(  # type: ignore[attr-defined]
-                q, k, v, mask, self.dropout_prob if self.training else 0.0
-            )
-        else:
-            attn = self.attention(k, q, v, attn_mask)
-
-        # attn: (batch, seq_len, width)
-        attn = self._combine_heads(attn)
-
+        # (batch, seq_len, width)
+        attn = self._combine_heads(self.attention(k, q, v, attn_mask))
 
         out = self.output(attn)
         return out
diff --git a/curated_transformers/models/pytorch/curated_transformer.py b/curated_transformers/models/pytorch/curated_transformer.py
index 2c1d1201..c440634e 100644
--- a/curated_transformers/models/pytorch/curated_transformer.py
+++ b/curated_transformers/models/pytorch/curated_transformer.py
@@ -17,13 +17,12 @@ class CuratedTransformer(Generic[CuratedEncoderT], Module):
     """Simple wrapper for encoders. Currently only used to add a predictable
     prefix (curated_encoder) to encoders."""
 
-    curated_encoder: CuratedEncoderT
-
-    __slots__ = ["curated_encoder"]
-
     def __init__(self, encoder: CuratedEncoderT) -> None:
         super().__init__()
-        self.curated_encoder = encoder
+
+        # Type ignore, because TorchScript does not allow Module
+        # as a class variable type.
+        self.curated_encoder = encoder  # type: ignore[var-annotated]
 
     def forward(
         self,
diff --git a/curated_transformers/models/pytorch/linear.py b/curated_transformers/models/pytorch/linear.py
deleted file mode 100644
index a50c906d..00000000
--- a/curated_transformers/models/pytorch/linear.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import torch
-from torch import Tensor
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Linear(nn.Linear):
-    def forward(self, input: Tensor) -> Tensor:
-        # Work around issue with linear with the MPS backend. See:
-        # https://github.com/pytorch/pytorch/issues/97239
-        if hasattr(input, "is_mps") and input.is_mps:
-            return torch.matmul(input, self.weight.t()) + self.bias
-        else:
-            return F.linear(input, self.weight, self.bias)
diff --git a/curated_transformers/tests/models/pytorch/test_torchscript.py b/curated_transformers/tests/models/pytorch/test_torchscript.py
new file mode 100644
index 00000000..b0c9a81d
--- /dev/null
+++ b/curated_transformers/tests/models/pytorch/test_torchscript.py
@@ -0,0 +1,58 @@
+import copy
+import pytest
+import torch.jit
+
+from curated_transformers.models.with_strided_spans import (
+    build_with_strided_spans_v1,
+    with_strided_spans,
+)
+from curated_transformers.tokenization.wordpiece_encoder import (
+    build_wordpiece_encoder_v1,
+)
+from curated_transformers.models.architectures import (
+    build_albert_transformer_model_v1,
+    build_bert_transformer_model_v1,
+    build_camembert_transformer_model_v1,
+    build_roberta_transformer_model_v1,
+    build_xlmr_transformer_model_v1,
+)
+
+MODEL_CONSTRUCTORS = [
+    build_albert_transformer_model_v1,
+    build_bert_transformer_model_v1,
+    build_camembert_transformer_model_v1,
+    build_roberta_transformer_model_v1,
+    build_xlmr_transformer_model_v1,
+]
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("model_factory", MODEL_CONSTRUCTORS)
+def test_encoder_deepcopy(model_factory):
+    # Not necessarily a TorchScript test, but we often want to
+    # copy a module before TorchScript conversion (see e.g.
+    # quantization).
+
+    # Use a small vocab to limit memory use.
+    model = model_factory(
+        piece_encoder=build_wordpiece_encoder_v1(),
+        with_spans=build_with_strided_spans_v1(),
+        vocab_size=128,
+    )
+    model.initialize()
+    encoder = model.get_ref("transformer").shims[0]._model
+    copy.deepcopy(encoder)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("model_factory", MODEL_CONSTRUCTORS)
+def test_encoder_torchscript(model_factory):
+    # Use a small vocab to limit memory use.
+    model = model_factory(
+        piece_encoder=build_wordpiece_encoder_v1(),
+        with_spans=build_with_strided_spans_v1(),
+        vocab_size=128,
+    )
+    model.initialize()
+    encoder = model.get_ref("transformer").shims[0]._model
+    torch.jit.script(encoder)