Commit d78f768

Add a workaround for invalid outputs of nn.Linear on MPS (#124)
`nn.Linear` produces incorrect outputs with certain matrix sizes when using the MPS backend (pytorch/pytorch#97239). The actual issue is in the underlying `torch.nn.functional.linear` function. Work around this by using an explicit matrix multiplication whenever the MPS backend is used.
Parent commit: 5fd5f0e

File tree: 4 files changed (+22, -8 lines)
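Conceptually, the workaround sidesteps `torch.nn.functional.linear` whenever the input tensor lives on the MPS device and computes an explicit matrix multiplication plus bias add instead. A minimal sketch of the idea (the `safe_linear` helper below is illustrative only; the commit implements it as the `Linear` subclass shown in the diffs that follow):

import torch
import torch.nn.functional as F
from torch import Tensor


def safe_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
    # Illustrative helper, not part of the commit.
    if getattr(x, "is_mps", False):
        # MPS backend: avoid the affected fused kernel, multiply explicitly.
        return torch.matmul(x, weight.t()) + bias
    return F.linear(x, weight, bias)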

curated_transformers/models/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from .albert.encoder import AlbertEncoder
 from .bert.encoder import BertEncoder
 from .roberta.encoder import RobertaEncoder
+from .linear import Linear

curated_transformers/models/pytorch/bert/layer.py

Lines changed: 5 additions & 6 deletions
@@ -5,6 +5,7 @@
 from .. import GeluNew
 from ..attention import AttentionMask, ScaledDotProductAttention
 from .config import BertAttentionConfig, BertLayerConfig
+from ..linear import Linear
 from ....errors import Errors


@@ -24,8 +25,8 @@ def __init__(self, config: BertAttentionConfig):

         self.dims_per_head = self.model_dim // self.num_heads
         self.attention = ScaledDotProductAttention(dropout_prob=config.dropout_prob)
-        self.input = torch.nn.Linear(self.model_dim, self.model_dim * 3)
-        self.output = torch.nn.Linear(self.model_dim, self.model_dim)
+        self.input = Linear(self.model_dim, self.model_dim * 3)
+        self.output = Linear(self.model_dim, self.model_dim)

     def _split_heads(self, x: Tensor) -> Tensor:
         """
@@ -75,10 +76,8 @@ class BertFeedForward(Module):
     def __init__(self, config: BertLayerConfig):
         super().__init__()

-        self.intermediate = torch.nn.Linear(
-            config.hidden_width, config.intermediate_width
-        )
-        self.output = torch.nn.Linear(config.intermediate_width, config.hidden_width)
+        self.intermediate = Linear(config.hidden_width, config.intermediate_width)
+        self.output = Linear(config.intermediate_width, config.hidden_width)
         if config.hidden_act == "relu":
             self.activation = torch.nn.ReLU()  # type: ignore
         elif config.hidden_act == "gelu":
curated_transformers/models/pytorch/linear.py (new file)

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Linear(nn.Linear):
+    def forward(self, input: Tensor) -> Tensor:
+        # Work around issue with linear with the MPS backend. See:
+        # https://github.com/pytorch/pytorch/issues/97239
+        if hasattr(input, "is_mps") and input.is_mps:
+            return torch.matmul(input, self.weight.t()) + self.bias
+        else:
+            return F.linear(input, self.weight, self.bias)
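Since the new `Linear` subclasses `nn.Linear`, it is a drop-in replacement: parameter names and shapes are unchanged, and only the forward pass differs on MPS. A hypothetical usage sketch (not part of the commit; the layer sizes are made up):

import torch
from curated_transformers.models.pytorch import Linear

layer = Linear(768, 3072)
x = torch.randn(2, 16, 768)  # (batch, seq, hidden)
y = layer(x)                 # F.linear on CPU/CUDA, matmul + bias on MPS
print(y.shape)               # torch.Size([2, 16, 3072])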

curated_transformers/tests/models/test_hf_model.py

Lines changed: 2 additions & 2 deletions
@@ -72,11 +72,11 @@ def test_model_against_hf_transformers(model_config):
     Y_hf_encoder = hf_encoder(X, attention_mask=attention_mask)

     assert torch.allclose(
-        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state
+        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state, atol=1e-6
     )

     # Try to infer the attention mask from padding.
     Y_encoder = encoder(X)
     assert torch.allclose(
-        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state
+        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state, atol=1e-6
     )
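The comparison tolerance is relaxed to `atol=1e-6`, presumably because the workaround's matmul-plus-bias path is not guaranteed to be bitwise identical to the fused `torch.nn.functional.linear` path used by the Hugging Face reference model. A hypothetical illustration of the kind of difference such a tolerance absorbs (shapes are made up):

import torch
import torch.nn.functional as F

x = torch.randn(8, 256)
w = torch.randn(512, 256)
b = torch.randn(512)

fused = F.linear(x, w, b)             # single fused kernel
unfused = torch.matmul(x, w.t()) + b  # workaround path
print((fused - unfused).abs().max())  # tiny, possibly nonzero
print(torch.allclose(fused, unfused, atol=1e-6))  # True within tolerance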
