Restore TorchScript functionality (necessary for quantization) (#129)
* Add TorchScript and deepcopy tests

* TorchScript fix: ModuleList can only be indexed with literals

* TorchScript fix: **kwargs is not allowed

* TorchScript fix: we cannot condition on global state

So remove our custom workaround for macOS 13.2; this is fixed
in 13.3.

* TorchScript fix: TorchScript does not allow Module type annotation

* TorchScript fixes: many fixes for the Attention class

- Ensure TorchScript type inference works.
- We can't reference global variables, including errors.
- Dataclasses do not work well.
- We need __init__ that can be found in source (not synthesized).
- The tuple type only works when fully specified (not Tuple or Tuple[int, ...]).

* Revert "Add support for Torch `scaled_dot_product_attention` (#128)"

This reverts commit 68a355a.

The functionality introduced in this PR uses global state to detect
whether `scaled_dot_product_attention` is available and to check
whether the user wants to use it. However, we cannot rely on global
state in TorchScript.

* Attempt to fix CI pip issues

* Describe some TorchScript rules of thumb in DEVELOP.md

* Simplify TorchScript type inference

* Remove unused imports
danieldk authored Apr 18, 2023
1 parent 68a355a commit 9808fa0
Showing 12 changed files with 192 additions and 76 deletions.
22 changes: 11 additions & 11 deletions .github/workflows/test.yml
@@ -32,11 +32,11 @@ jobs:

- name: Install requirements
run: |
pip install --upgrade pip setuptools wheel build
pip install -r requirements.txt
python -m pip install --upgrade pip setuptools wheel build
python -m pip install -r requirements.txt
- name: Install optional dependencies for Japanese models
run: pip install fugashi[unidic-lite] spacy[ja]
run: python -m pip install fugashi[unidic-lite] spacy[ja]

- name: Build sdist
run: python -m build --sdist
@@ -52,29 +52,29 @@ jobs:

- name: Uninstall all packages
run: |
pip freeze
pip freeze --exclude pywin32 > installed.txt
pip uninstall -y -r installed.txt
python -m pip freeze
python -m pip freeze --exclude pywin32 > installed.txt
python -m pip uninstall -y -r installed.txt
- name: Install from sdist
run: |
SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
pip install dist/$SDIST
python -m pip install dist/$SDIST
shell: bash

- name: Install test dependencies
run: |
pip install -r requirements.txt
pip install fugashi[unidic-lite] spacy[ja]
python -m pip install -r requirements.txt
python -m pip install fugashi[unidic-lite] spacy[ja]
- name: Run pytest
run: python -m pytest --pyargs curated_transformers --slow

- name: Install HF transformers
id: init-hf-transformers
run: |
pip install '${{ env.hf-transformers-pip }}'
echo "version=$(pip show transformers | grep -i version:)" >> $GITHUB_OUTPUT
python -m pip install '${{ env.hf-transformers-pip }}'
echo "version=$(python -m pip show transformers | grep -i version:)" >> $GITHUB_OUTPUT
shell: bash

# Disabled for Windows as the models appear to
92 changes: 92 additions & 0 deletions DEVELOP.md
@@ -0,0 +1,92 @@
# Development

## TorchScript

Every `torch.nn.Module` in this project must have a TorchScript conversion test.
TorchScript only supports a subset of Python, and we want to make sure that all
models are convertible to TorchScript.
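
A minimal conversion test could look like the sketch below (the `Foo` module
and the test layout are illustrative, not this project's actual test code):

```python
import torch
from torch import Tensor
from torch.nn import Linear, Module


class Foo(Module):
    def __init__(self):
        super().__init__()
        self.inner = Linear(16, 16)

    def forward(self, X: Tensor) -> Tensor:
        return self.inner(X)


def test_foo_is_scriptable():
    model = Foo()
    # torch.jit.script raises if the module uses unsupported Python features.
    scripted = torch.jit.script(model)
    X = torch.rand(2, 4, 16)
    torch.testing.assert_close(scripted(X), model(X))
```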

In this section we will give some rules of thumb to avoid conversion errors.

### Do not use global state

TorchScript cannot use global state. One form of global state that we have
in this project is the `Errors` class. Consequently, we cannot use `Errors`
in `Module`s. The following is therefore invalid:

```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        # Problem: Errors fields are global state.
        raise ValueError(Errors.E042)
```

In these cases we have to use an inline string instead:

```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        raise ValueError("This module does not do anything yet.")
```

For the same reason, we also cannot rely on `has_*` bools in a module:

```python
class Foo(nn.Module):
    def forward(self, X: Tensor) -> Tensor:
        # Problem: conditional on global state.
        if has_torch_feature:
            ...
```
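
One way to keep such feature checks out of scripted code is to make the
decision in ordinary Python outside the `Module`, e.g. in a factory function.
The following is only an illustrative sketch with hypothetical class names; it
is not what this commit does (the commit reverts the optional
`scaled_dot_product_attention` path entirely):

```python
import torch
import torch.nn as nn
from torch import Tensor


class EagerAttention(nn.Module):
    # Stand-in for an attention implementation written out by hand.
    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        scores = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)
        return scores.softmax(dim=-1) @ v


class TorchSDPAttention(nn.Module):
    # Stand-in for a thin wrapper around torch's fused attention kernel.
    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


def create_attention(use_torch_sdpa: bool) -> nn.Module:
    # The branch runs at construction time in ordinary Python, so the
    # returned module never references global state and can be scripted.
    return TorchSDPAttention() if use_torch_sdpa else EagerAttention()
```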

### Typing limitations

TorchScript only supports a small [subset of Python types](https://pytorch.org/docs/stable/jit_language_reference.html#supported-type).
This also applies to type annotations. For instance, the following will not work, because
TorchScript only supports fully-specified tuple types:

```python
class Foo(nn.Module):
    # Problem: underspecified tuple
    def shape(self) -> Tuple:
        ...

    # Problem: underspecified tuple
    def shape(self) -> Tuple[int, ...]:
        ...
```

The following is ok, because it is a valid TorchScript type:

```python
class Foo(nn.Module):
    def shape(self) -> Tuple[int, int]:
        ...
```
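
The same restriction applies to attributes that start out as `None`: without an
explicit annotation, TorchScript infers their type as `NoneType` rather than
`Optional[...]`. The `AttentionMask` rewrite in this commit uses
`torch.jit.annotate` for this; below is a minimal sketch of the pattern (the
`Foo` class and `_cached` attribute are illustrative):

```python
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module


class Foo(Module):
    _cached: Optional[Tensor]

    def __init__(self):
        super().__init__()
        # Annotate so that TorchScript treats the attribute as
        # Optional[Tensor] instead of NoneType.
        self._cached = torch.jit.annotate(Optional[Tensor], None)

    def forward(self, X: Tensor) -> Tensor:
        if self._cached is None:
            self._cached = X.sum(dim=-1)
        # Narrow the type for TorchScript before returning.
        cached = self._cached
        assert cached is not None
        return cached
```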

### Do not use `**kwargs` arguments

TorchScript does not support `**kwargs` wildcards, so the following is
invalid:

```python
class Foo(nn.Module):
    ...

    def forward(self, X: Tensor, **kwargs) -> Tensor:
        hidden = self.inner1(X)
        return self.inner2(hidden, **kwargs)
```

Instead, we have to spell out all arguments, e.g.:

```python
class Foo(nn.Module):
    ...

    def forward(self, X: Tensor, attn_mask: AttentionMask) -> Tensor:
        hidden = self.inner1(X)
        return self.inner2(hidden, attn_mask=attn_mask)
```
6 changes: 0 additions & 6 deletions curated_transformers/_compat.py
@@ -1,10 +1,4 @@
from os import environ
from spacy.lang.ja import try_sudachi_import
import torch

has_torch_sdp_attention = hasattr(
torch.nn.functional, "scaled_dot_product_attention"
) and environ.get("CURATED_TORCH_SDPA")

try:
import transformers
3 changes: 0 additions & 3 deletions curated_transformers/errors.py
@@ -26,9 +26,6 @@ class Errors(metaclass=ErrorsWithCodes):
"divisible by the number of self-attention heads ({num_heads})")
E004 = ("The point-wise feed-forward network in the transformer only "
"supports the following activation functions: {activation_funcs}")
E005 = ("Expected the attention mask to be of dtype 'torch.bool' but "
"found it be '{dtype}' instead")
E006 = ("The attention mask must be a 2D-tensor of shape [batch, seq_len]")
E007 = ("Attempting to load the weights of an unsupported Hugging "
"Face `transformers` model ({unsupported_model}). Currently "
"supported models: {supported_models}")
1 change: 0 additions & 1 deletion curated_transformers/models/pytorch/__init__.py
@@ -3,4 +3,3 @@
from .albert.encoder import AlbertEncoder
from .bert.encoder import BertEncoder
from .roberta.encoder import RobertaEncoder
from .linear import Linear
9 changes: 4 additions & 5 deletions curated_transformers/models/pytorch/albert/encoder.py
@@ -63,11 +63,10 @@ def forward(
layers_per_group = self.num_hidden_layers // len(self.groups)

layer_outputs = []
for i in range(self.num_hidden_layers):
layer_output = self.groups[i // layers_per_group](
layer_output, attn_mask=attention_mask
)
layer_outputs.append(layer_output)
for group in self.groups:
for _ in range(layers_per_group):
layer_output = group(layer_output, attn_mask=attention_mask)
layer_outputs.append(layer_output)

return PyTorchTransformerOutput(
embedding_output=embeddings, layer_hidden_states=layer_outputs
5 changes: 3 additions & 2 deletions curated_transformers/models/pytorch/albert/layer_group.py
@@ -1,6 +1,7 @@
from torch import Tensor
from torch.nn import Module, ModuleList

from ..attention import AttentionMask
from ..bert.config import BertAttentionConfig
from ..bert.layer import BertAttentionConfig, BertEncoderLayer
from .config import AlbertLayerConfig
@@ -19,12 +20,12 @@ def __init__(
]
)

def forward(self, input: Tensor, **kwargs) -> Tensor:
def forward(self, input: Tensor, attn_mask: AttentionMask) -> Tensor:
"""
Shapes:
input - (batch, seq_len, width)
"""
layer_output = input
for layer in self.group_layers:
layer_output = layer(layer_output, **kwargs)
layer_output = layer(layer_output, attn_mask)
return layer_output
26 changes: 16 additions & 10 deletions curated_transformers/models/pytorch/attention.py
@@ -1,5 +1,4 @@
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Optional
import math
import torch
from torch import Tensor
@@ -8,28 +7,33 @@
from ...errors import Errors


@dataclass
class AttentionMask:
bool_mask: Tensor
_logit_mask: Optional[Tensor] = None
_logit_mask: Optional[Tensor]

def __post_init__(self):
if self.bool_mask.dtype != torch.bool:
raise ValueError(Errors.E005.format(dtype=self.bool_mask.dtype))
def __init__(self, bool_mask: Tensor):
if bool_mask.dtype != torch.bool:
raise ValueError("Expected the attention mask to be of dtype 'torch.bool'")
self.bool_mask = bool_mask
self._logit_mask = torch.jit.annotate(Optional[Tensor], None)

@property
def logit_mask(self) -> Tensor:
if self._logit_mask is None:
# The value is `torch.finfo(attn_scores.dtype).min`. Unfortunately,
# we cannot use `torch.finfo` in TorchScript.
self._logit_mask = (1.0 - self.bool_mask.int()) * -3.4028234663852886e38
return self._logit_mask

# Narrow type for TorchScript.
logit_mask = self._logit_mask
assert logit_mask is not None
return logit_mask

def dim(self) -> int:
return self.bool_mask.dim()

@property
def shape(self) -> Tuple:
def shape(self):
return self.bool_mask.shape


@@ -49,7 +53,9 @@ def forward(
"""

if attn_mask.dim() != 2:
raise ValueError(Errors.E006)
raise ValueError(
"The attention mask must be a 2D-tensor of shape [batch, seq_len]"
)

model_dim = k.shape[-1]
attn_scores = q @ k.transpose(-2, -1)
23 changes: 4 additions & 19 deletions curated_transformers/models/pytorch/bert/layer.py
@@ -1,12 +1,10 @@
import torch
from torch.nn import Module
from torch.nn import Linear, Module
from torch import Tensor

from .. import GeluNew
from ..attention import AttentionMask, ScaledDotProductAttention
from .config import BertAttentionConfig, BertLayerConfig
from ..linear import Linear
from ...._compat import has_torch_sdp_attention
from ....errors import Errors


@@ -15,7 +13,6 @@ class BertSelfAttention(Module):
def __init__(self, config: BertAttentionConfig):
super().__init__()

self.dropout_prob = config.dropout_prob
self.model_dim = config.hidden_width
self.num_heads = config.num_attention_heads
if self.model_dim % self.num_heads != 0:
@@ -58,6 +55,7 @@ def forward(self, x: Tensor, attn_mask: AttentionMask) -> Tensor:
x - (batch, seq_len, width)
attn_mask - (batch, seq_len)
"""

proj = self.input(x)
q, k, v = proj.chunk(3, dim=-1)

@@ -66,21 +64,8 @@ def forward(self, x: Tensor, attn_mask: AttentionMask) -> Tensor:
q = self._split_heads(q)
v = self._split_heads(v)

# attn: (batch, head, seq_len, with_per_head)
if has_torch_sdp_attention:
batch, seq_len = attn_mask.shape
mask = attn_mask.bool_mask.view(batch, 1, 1, seq_len)
# Ignore because we still support torch<2.0.0 and older versions
# do not have this attribute.
attn = torch.nn.functional.scaled_dot_product_attention( # type: ignore[attr-defined]
q, k, v, mask, self.dropout_prob if self.training else 0.0
)
else:
attn = self.attention(k, q, v, attn_mask)

# attn: (batch, seq_len, width)
attn = self._combine_heads(attn)

# (batch, seq_len, width)
attn = self._combine_heads(self.attention(k, q, v, attn_mask))
out = self.output(attn)

return out
9 changes: 4 additions & 5 deletions curated_transformers/models/pytorch/curated_transformer.py
@@ -17,13 +17,12 @@ class CuratedTransformer(Generic[CuratedEncoderT], Module):
"""Simple wrapper for encoders. Currently only used to add a predictable
prefix (curated_encoder) to encoders."""

curated_encoder: CuratedEncoderT

__slots__ = ["curated_encoder"]

def __init__(self, encoder: CuratedEncoderT) -> None:
super().__init__()
self.curated_encoder = encoder

# Type ignore, because TorchScript does not allow Module
# as a class variable type.
self.curated_encoder = encoder # type: ignore[var-annotated]

def forward(
self,
14 changes: 0 additions & 14 deletions curated_transformers/models/pytorch/linear.py

This file was deleted.
