Skip to content

Commit

Permalink
Remove support for TorchScript tracing
Browse files Browse the repository at this point in the history
We added support for TorchScript tracing a while back, so that models
can be exported to ONNX. However, the support relies on metaclasses,
which breaks with torch.compile in the latest PyTorch versions. Fortunately,
PyTorch now provides a TorchDynamo-based ONNX exporter:

https://pytorch.org/docs/stable/onnx_dynamo.html

So it's time to yank TorchScript tracing support and remove all the
fragile dataclass/tuple/dict polymorphism.
  • Loading branch information
danieldk committed Feb 8, 2024
1 parent 130df32 commit fd49ae7
Show file tree
Hide file tree
Showing 21 changed files with 24 additions and 700 deletions.
39 changes: 2 additions & 37 deletions curated_transformers/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from torch.nn import Dropout, Linear, Module

from ..semver import Default, FutureMandatory
from ..util.dataclass import DataclassAsDict
from .cache import KeyValueCache
from .embeddings import QueryKeyRotaryEmbeddings

Expand Down Expand Up @@ -47,7 +46,7 @@ def enable_torch_sdp(use_torch_sdp: bool = True):


@dataclass
class AttentionMask(DataclassAsDict):
class AttentionMask:
"""
Mask for attention calculation. Sequence elements for which the
corresponding mask element is set to ``False`` are ignored during
Expand All @@ -74,33 +73,6 @@ def __init__(self, bool_mask: Tensor):
"[batch_len, heads, query_len, key_len]"
)

self.__post_init__()

@classmethod
def jit_rewrap(
cls: Type["AttentionMask"],
attention_mask: Union["AttentionMask", Dict[str, Tensor]],
) -> "AttentionMask":
"""
Rewrap TorchScript dictionary conversion of an attention mask
as an ``AttentionMask``.
:param attention_mask:
The attention mask or its dictionary representation. If the
value is an ``AttentionMask``, the value will be returned as-is.
:returns:
The attention mask.
"""
if isinstance(attention_mask, AttentionMask):
return attention_mask

bool_mask = attention_mask.get("bool_mask")
if bool_mask is None:
raise ValueError(
"Attention mask is not of the `AttentionMask` type, nor a dict with 'bool_mask'."
)
return AttentionMask(bool_mask=bool_mask)

def apply_logit_mask(self, input: Tensor) -> Tensor:
"""
Use the attention mask to mask attention logits.
Expand Down Expand Up @@ -917,20 +889,13 @@ def forward(
*Shape:* ``(batch_size, seq_len, width)``
"""
# The attention mask is converted to a dict for traced models. Rewrap as
# AttentionMask to get validation and utility methods.
attention_mask = AttentionMask.jit_rewrap(attention_mask)

query, key, value = self._query_key_value(input)

if self.rotary_embeds is not None:
query, key = self.rotary_embeds(
query=query, key=key, cache=cache, positions=positions
)

# The key-value is converted to a dict for traced models. Rewrap as
# KeyValueCache to get validation and utility methods.
cache = KeyValueCache.jit_rewrap(cache)
if cache is not None:
cache_k = cache.key
cache_v = cache.value
Expand All @@ -951,7 +916,7 @@ def forward(
output = self.output(attn)

if store_cache:
return output, KeyValueCache(key, value)
return output, KeyValueCache(key=key, value=value)
else:
return output, None

Expand Down
43 changes: 2 additions & 41 deletions curated_transformers/layers/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import torch
from torch import Tensor

from ..util.dataclass import DataclassAsTuple

CacheProtocolSelf = TypeVar("CacheProtocolSelf", bound="CacheProtocol")


Expand All @@ -27,7 +25,7 @@ def filter_batch_items(self: CacheProtocolSelf, mask: Tensor) -> CacheProtocolSe


@dataclass
class KeyValueCache(DataclassAsTuple):
class KeyValueCache:
"""
Cache type for layers that cache keys and values.
Expand All @@ -52,41 +50,4 @@ def filter_batch_items(self, mask: Tensor) -> "KeyValueCache":
if mask.dtype != torch.bool:
raise ValueError(f"Cache mask dtype must be bool, was: {mask.dtype}.")

return KeyValueCache(self.key[mask], self.value[mask])

@classmethod
def jit_rewrap(
cls: Type["KeyValueCache"],
key_value_cache: Optional[Union["KeyValueCache", Tuple[Tensor, Tensor]]],
) -> Optional["KeyValueCache"]:
"""
Rewrap TorchScript dictionary conversion of a key-value cache.
:param key_value_cache:
The key-value cache or its dictionary representation. If the
value is a ``KeyValueCache`` or ``None``, it will be
returned as-is.
:returns:
The key-value cache.
"""
if key_value_cache is None or isinstance(key_value_cache, KeyValueCache):
return key_value_cache

if (
not isinstance(key_value_cache, tuple)
or len(key_value_cache) != 2
or not all(isinstance(item, Tensor) for item in key_value_cache)
):
raise ValueError(
f"Key-value cache is not of the `KeyValueCache` type, nor `Tuple[Tensor, Tensor]`: `{type(key_value_cache).__name__}`"
)

key_cache = key_value_cache[0]
value_cache = key_value_cache[1]

if key_cache.shape != value_cache.shape:
raise ValueError(
f"Key cache ({key_cache.shape}) and value cache ({value_cache.shape}) must have same shapes."
)

return cls(key_cache, value_cache)
return KeyValueCache(key=self.key[mask], value=self.value[mask])
4 changes: 0 additions & 4 deletions curated_transformers/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,6 @@ def forward(
head_width = self.head_width
rotary_width = self.rotary_width

# The key-value is converted to a dict for traced models. Rewrap as
# KeyValueCache to get validation and utility methods.
cache = KeyValueCache.jit_rewrap(cache)

# If a cache was provided, but no positions, assume that the
# positions of the current batch continue from the cache.
if cache is not None and positions is None:
Expand Down
3 changes: 1 addition & 2 deletions curated_transformers/models/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from torch import Tensor

from ..layers.cache import CacheProtocol
from ..util.dataclass import DataclassAsTuple

CacheT = TypeVar("CacheT", bound=CacheProtocol)


@dataclass
class ModelOutput(DataclassAsTuple):
class ModelOutput:
"""
Base class for model outputs.
Expand Down
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/albert/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_encoder_output_equals_hf(
ALBERTEncoder,
"explosion-testing/albert-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_encoder_hf_serializtion_roundtrip(torch_device):
Expand Down
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/bert/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_encoder_output_equals_hf(
BERTEncoder,
"explosion-testing/bert-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_encoder_hf_serializtion_roundtrip(torch_device):
Expand Down
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/camembert/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_encoder_output_equals_hf(
CamemBERTEncoder,
"explosion-testing/camembert-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_encoder_hf_serializtion_roundtrip(torch_device):
Expand Down
21 changes: 0 additions & 21 deletions curated_transformers/tests/models/falcon/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,27 +100,6 @@ def test_decoder_with_torch_compile(torch_device, model_revision, with_torch_sdp
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model_revision", FALCON_TEST_MODELS)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, model_revision, with_torch_sdp):
model, revision = model_revision
assert_decoder_output_equals_hf(
FalconDecoder,
model,
torch_device,
model_revision=revision,
trust_remote_code=True,
with_cache=False,
with_mask=False,
with_positions=False,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model_revision", FALCON_TEST_MODELS)
Expand Down
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/gpt_neox/test_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,6 @@ def test_causal_lm_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp):
assert_causal_lm_output_equals_hf(
GPTNeoXCausalLM,
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_causal_lm_hf_serializtion_roundtrip(torch_device):
Expand Down
16 changes: 0 additions & 16 deletions curated_transformers/tests/models/gpt_neox/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,6 @@ def test_decoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_decoder_output_equals_hf(
GPTNeoXDecoder,
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
torch_device,
with_cache=True,
with_positions=True,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_decoder_hf_serializtion_roundtrip(torch_device):
Expand Down
15 changes: 0 additions & 15 deletions curated_transformers/tests/models/llama/test_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,6 @@ def test_causal_lm_torch_compile(torch_device, model, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_causal_lm_with_torchscript_trace(torch_device, model, with_torch_sdp):
assert_causal_lm_output_equals_hf(
LlamaCausalLM,
model,
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
Expand Down
15 changes: 0 additions & 15 deletions curated_transformers/tests/models/llama/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,6 @@ def test_decoder_with_torch_compile(torch_device, model, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, model, with_torch_sdp):
assert_decoder_output_equals_hf(
LlamaDecoder,
model,
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
Expand Down
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/mpt/test_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,6 @@ def test_causal_lm_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp):
assert_causal_lm_output_equals_hf(
MPTCausalLM,
"explosion-testing/mpt-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_causal_lm_hf_serializtion_roundtrip(torch_device):
Expand Down
16 changes: 0 additions & 16 deletions curated_transformers/tests/models/mpt/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,6 @@ def test_decoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_decoder_output_equals_hf(
MPTDecoder,
"explosion-testing/mpt-test",
torch_device,
with_cache=True,
with_positions=False,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_decoder_hf_serializtion_roundtrip(torch_device):
Expand Down
Loading

0 comments on commit fd49ae7

Please sign in to comment.