Add a workaround for invalid outputs of nn.Linear on MPS (#124)
`nn.Linear` produces incorrect outputs with certain matrix sizes when
using the MPS backend:

pytorch/pytorch#97239

The actual issue is in the underlying `torch.nn.functional.linear`
function. Work around this by using an explicit matrix multiplication
when the MPS backend is used.
danieldk authored Mar 22, 2023
1 parent 5fd5f0e commit d78f768
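
The workaround relies on `torch.nn.functional.linear(x, w, b)` computing the same result as `x @ w.t() + b`, so routing MPS tensors through an explicit matmul changes only which kernel runs, not the result. A minimal sketch of that equivalence on CPU (shapes are illustrative only; the sizes that actually trigger the bug are discussed in pytorch/pytorch#97239):

import torch
import torch.nn.functional as F

# Illustrative shapes; on CPU both paths agree within float32 tolerance.
x = torch.randn(2, 8, 16)
w = torch.randn(32, 16)   # (out_features, in_features), as stored by nn.Linear
b = torch.randn(32)

y_fused = F.linear(x, w, b)               # the path that misbehaves on MPS
y_explicit = torch.matmul(x, w.t()) + b   # the explicit workaround path
assert torch.allclose(y_fused, y_explicit, atol=1e-6)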
Showing 4 changed files with 22 additions and 8 deletions.
1 change: 1 addition & 0 deletions curated_transformers/models/pytorch/__init__.py
@@ -3,3 +3,4 @@
 from .albert.encoder import AlbertEncoder
 from .bert.encoder import BertEncoder
 from .roberta.encoder import RobertaEncoder
+from .linear import Linear
11 changes: 5 additions & 6 deletions curated_transformers/models/pytorch/bert/layer.py
@@ -5,6 +5,7 @@
 from .. import GeluNew
 from ..attention import AttentionMask, ScaledDotProductAttention
 from .config import BertAttentionConfig, BertLayerConfig
+from ..linear import Linear
 from ....errors import Errors

@@ -24,8 +25,8 @@ def __init__(self, config: BertAttentionConfig):

         self.dims_per_head = self.model_dim // self.num_heads
         self.attention = ScaledDotProductAttention(dropout_prob=config.dropout_prob)
-        self.input = torch.nn.Linear(self.model_dim, self.model_dim * 3)
-        self.output = torch.nn.Linear(self.model_dim, self.model_dim)
+        self.input = Linear(self.model_dim, self.model_dim * 3)
+        self.output = Linear(self.model_dim, self.model_dim)

     def _split_heads(self, x: Tensor) -> Tensor:
         """
@@ -75,10 +76,8 @@ class BertFeedForward(Module):
     def __init__(self, config: BertLayerConfig):
         super().__init__()

-        self.intermediate = torch.nn.Linear(
-            config.hidden_width, config.intermediate_width
-        )
-        self.output = torch.nn.Linear(config.intermediate_width, config.hidden_width)
+        self.intermediate = Linear(config.hidden_width, config.intermediate_width)
+        self.output = Linear(config.intermediate_width, config.hidden_width)
         if config.hidden_act == "relu":
             self.activation = torch.nn.ReLU()  # type: ignore
         elif config.hidden_act == "gelu":
14 changes: 14 additions & 0 deletions curated_transformers/models/pytorch/linear.py
@@ -0,0 +1,14 @@
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Linear(nn.Linear):
+    def forward(self, input: Tensor) -> Tensor:
+        # Work around an issue with linear on the MPS backend. See:
+        # https://github.com/pytorch/pytorch/issues/97239
+        if hasattr(input, "is_mps") and input.is_mps:
+            return torch.matmul(input, self.weight.t()) + self.bias
+        else:
+            return F.linear(input, self.weight, self.bias)
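
For context, a small usage sketch of the drop-in layer, assuming the re-export from `curated_transformers.models.pytorch` added in the `__init__.py` change above; the device check and layer sizes are illustrative, not part of this commit:

import torch

from curated_transformers.models.pytorch import Linear

# Behaves like torch.nn.Linear; the explicit matmul path is only taken when
# the input tensor lives on an MPS device.
layer = Linear(768, 3072)
device = "mps" if torch.backends.mps.is_available() else "cpu"
layer = layer.to(device)
x = torch.randn(1, 128, 768, device=device)
y = layer(x)  # shape: (1, 128, 3072)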
4 changes: 2 additions & 2 deletions curated_transformers/tests/models/test_hf_model.py
@@ -72,11 +72,11 @@ def test_model_against_hf_transformers(model_config):
     Y_hf_encoder = hf_encoder(X, attention_mask=attention_mask)

     assert torch.allclose(
-        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state
+        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state, atol=1e-6
     )

     # Try to infer the attention mask from padding.
     Y_encoder = encoder(X)
     assert torch.allclose(
-        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state
+        Y_encoder.last_hidden_layer_states, Y_hf_encoder.last_hidden_state, atol=1e-6
     )
