Remove support for TorchScript tracing #361

Merged
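This change removes the `jit_rewrap` helpers on `AttentionMask` and `KeyValueCache`, drops their `DataclassAsDict`/`DataclassAsTuple` bases, and deletes the TorchScript-trace tests. The helpers only existed so that traced models could pass masks and caches around as dicts or tuples; callers now construct and pass the dataclasses directly. A minimal sketch of the direct construction, assuming the constructors shown in the diff below (module paths follow the changed files; tensor shapes are illustrative):

```python
import torch

from curated_transformers.layers.attention import AttentionMask
from curated_transformers.layers.cache import KeyValueCache

# No dict/tuple rewrapping any more: masks and caches are always the dataclass types.
mask = AttentionMask(bool_mask=torch.ones((1, 8), dtype=torch.bool))
cache = KeyValueCache(
    key=torch.zeros(1, 4, 8, 64),
    value=torch.zeros(1, 4, 8, 64),
)
```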
39 changes: 2 additions & 37 deletions curated_transformers/layers/attention.py
@@ -12,7 +12,6 @@
from torch.nn import Dropout, Linear, Module

from ..semver import Default, FutureMandatory
from ..util.dataclass import DataclassAsDict
from .cache import KeyValueCache
from .embeddings import QueryKeyRotaryEmbeddings

@@ -47,7 +46,7 @@ def enable_torch_sdp(use_torch_sdp: bool = True):


@dataclass
class AttentionMask(DataclassAsDict):
class AttentionMask:
"""
Mask for attention calculation. Sequence elements for which the
corresponding mask element is set to ``False`` are ignored during
@@ -74,33 +73,6 @@ def __init__(self, bool_mask: Tensor):
"[batch_len, heads, query_len, key_len]"
)

self.__post_init__()

@classmethod
def jit_rewrap(
cls: Type["AttentionMask"],
attention_mask: Union["AttentionMask", Dict[str, Tensor]],
) -> "AttentionMask":
"""
Rewrap TorchScript dictionary conversion of an attention mask
as an ``AttentionMask``.

:param attention_mask:
The attention mask or its dictionary representation. If the
value is an ``AttentionMask``, the value will be returned as-is.
:returns:
The attention mask.
"""
if isinstance(attention_mask, AttentionMask):
return attention_mask

bool_mask = attention_mask.get("bool_mask")
if bool_mask is None:
raise ValueError(
"Attention mask is not of the `AttentionMask` type, nor a dict with 'bool_mask'."
)
return AttentionMask(bool_mask=bool_mask)

def apply_logit_mask(self, input: Tensor) -> Tensor:
"""
Use the attention mask to mask attention logits.
@@ -919,20 +891,13 @@ def forward(

*Shape:* ``(batch_size, seq_len, width)``
"""
# The attention mask is converted to a dict for traced models. Rewrap as
# AttentionMask to get validation and utility methods.
attention_mask = AttentionMask.jit_rewrap(attention_mask)

query, key, value = self._query_key_value(input)

if self.rotary_embeds is not None:
query, key = self.rotary_embeds(
query=query, key=key, cache=cache, positions=positions
)

# The key-value is converted to a dict for traced models. Rewrap as
# KeyValueCache to get validation and utility methods.
cache = KeyValueCache.jit_rewrap(cache)
if cache is not None:
cache_k = cache.key
cache_v = cache.value
Expand All @@ -953,7 +918,7 @@ def forward(
output = self.output(attn)

if store_cache:
return output, KeyValueCache(key, value)
return output, KeyValueCache(key=key, value=value)
else:
return output, None

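Since `forward` no longer rewraps a dict into an `AttentionMask`, it expects the dataclass itself; utilities such as `apply_logit_mask` are otherwise unchanged. A small usage sketch, with assumed (illustrative) tensor shapes:

```python
import torch

from curated_transformers.layers.attention import AttentionMask

# True = attend, False = ignore; a 2-D (batch, key_len) mask is assumed here.
mask = AttentionMask(torch.tensor([[True, True, False]]))

# Attention logits of shape (batch, heads, query_len, key_len).
logits = torch.randn(1, 2, 3, 3)

# Logits at positions where the mask is False are masked out before the softmax.
masked_logits = mask.apply_logit_mask(logits)
```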
43 changes: 2 additions & 41 deletions curated_transformers/layers/cache.py
@@ -4,8 +4,6 @@
import torch
from torch import Tensor

from ..util.dataclass import DataclassAsTuple

CacheProtocolSelf = TypeVar("CacheProtocolSelf", bound="CacheProtocol")


@@ -27,7 +25,7 @@ def filter_batch_items(self: CacheProtocolSelf, mask: Tensor) -> CacheProtocolSelf:


@dataclass
class KeyValueCache(DataclassAsTuple):
class KeyValueCache:
"""
Cache type for layers that cache keys and values.

@@ -52,41 +50,4 @@ def filter_batch_items(self, mask: Tensor) -> "KeyValueCache":
if mask.dtype != torch.bool:
raise ValueError(f"Cache mask dtype must be bool, was: {mask.dtype}.")

return KeyValueCache(self.key[mask], self.value[mask])

@classmethod
def jit_rewrap(
cls: Type["KeyValueCache"],
key_value_cache: Optional[Union["KeyValueCache", Tuple[Tensor, Tensor]]],
) -> Optional["KeyValueCache"]:
"""
Rewrap TorchScript dictionary conversion of a key-value cache.

:param key_value_cache:
The key-value cache or its dictionary representation. If the
value is a ``KeyValueCache`` or ``None``, it will be
returned as-is.
:returns:
The key-value cache.
"""
if key_value_cache is None or isinstance(key_value_cache, KeyValueCache):
return key_value_cache

if (
not isinstance(key_value_cache, tuple)
or len(key_value_cache) != 2
or not all(isinstance(item, Tensor) for item in key_value_cache)
):
raise ValueError(
f"Key-value cache is not of the `KeyValueCache` type, nor `Tuple[Tensor, Tensor]`: `{type(key_value_cache).__name__}`"
)

key_cache = key_value_cache[0]
value_cache = key_value_cache[1]

if key_cache.shape != value_cache.shape:
raise ValueError(
f"Key cache ({key_cache.shape}) and value cache ({value_cache.shape}) must have same shapes."
)

return cls(key_cache, value_cache)
return KeyValueCache(key=self.key[mask], value=self.value[mask])
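Apart from losing `jit_rewrap`, the `KeyValueCache` API is unchanged. A brief sketch of `filter_batch_items`, with an assumed `(batch, heads, seq_len, head_width)` layout:

```python
import torch

from curated_transformers.layers.cache import KeyValueCache

cache = KeyValueCache(
    key=torch.zeros(4, 8, 16, 64),
    value=torch.zeros(4, 8, 16, 64),
)

# Keep only the batch items that are still being generated.
still_active = torch.tensor([True, False, True, True])
filtered = cache.filter_batch_items(still_active)
# filtered.key.shape[0] == 3
```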
4 changes: 0 additions & 4 deletions curated_transformers/layers/embeddings.py
@@ -278,10 +278,6 @@ def forward(
head_width = self.head_width
rotary_width = self.rotary_width

# The key-value is converted to a dict for traced models. Rewrap as
# KeyValueCache to get validation and utility methods.
cache = KeyValueCache.jit_rewrap(cache)

# If a cache was provided, but no positions, assume that the
# positions of the current batch continue from the cache.
if cache is not None and positions is None:
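The retained comment in `embeddings.py` describes the behaviour when a cache is provided without positions: the positions of the new tokens continue from the cache. A rough sketch of that logic (not the module's actual code):

```python
import torch

def positions_after_cache(batch_size: int, seq_len: int, cache_len: int) -> torch.Tensor:
    # New tokens are assumed to continue directly after the cached time steps.
    return torch.arange(cache_len, cache_len + seq_len).repeat(batch_size, 1)

# One new token per sequence after 12 cached steps:
positions_after_cache(batch_size=2, seq_len=1, cache_len=12)
# tensor([[12],
#         [12]])
```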
3 changes: 1 addition & 2 deletions curated_transformers/models/output.py
@@ -4,13 +4,12 @@
from torch import Tensor

from ..layers.cache import CacheProtocol
from ..util.dataclass import DataclassAsTuple

CacheT = TypeVar("CacheT", bound=CacheProtocol)


@dataclass
class ModelOutput(DataclassAsTuple):
class ModelOutput:
"""
Base class for model outputs.

14 changes: 0 additions & 14 deletions curated_transformers/tests/models/albert/test_encoder.py
@@ -44,20 +44,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_encoder_output_equals_hf(
ALBERTEncoder,
"explosion-testing/albert-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_encoder_hf_serializtion_roundtrip(torch_device):
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/bert/test_encoder.py
@@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_encoder_output_equals_hf(
BERTEncoder,
"explosion-testing/bert-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_encoder_hf_serializtion_roundtrip(torch_device):
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/camembert/test_encoder.py
@@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_encoder_output_equals_hf(
CamemBERTEncoder,
"explosion-testing/camembert-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_encoder_hf_serializtion_roundtrip(torch_device):
21 changes: 0 additions & 21 deletions curated_transformers/tests/models/falcon/test_decoder.py
@@ -100,27 +100,6 @@ def test_decoder_with_torch_compile(torch_device, model_revision, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model_revision", FALCON_TEST_MODELS)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, model_revision, with_torch_sdp):
model, revision = model_revision
assert_decoder_output_equals_hf(
FalconDecoder,
model,
torch_device,
model_revision=revision,
trust_remote_code=True,
with_cache=False,
with_mask=False,
with_positions=False,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model_revision", FALCON_TEST_MODELS)
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/gpt_neox/test_causal_lm.py
@@ -39,20 +39,6 @@ def test_causal_lm_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp):
assert_causal_lm_output_equals_hf(
GPTNeoXCausalLM,
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_causal_lm_hf_serializtion_roundtrip(torch_device):
16 changes: 0 additions & 16 deletions curated_transformers/tests/models/gpt_neox/test_decoder.py
@@ -45,22 +45,6 @@ def test_decoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_decoder_output_equals_hf(
GPTNeoXDecoder,
"trl-internal-testing/tiny-random-GPTNeoXForCausalLM",
torch_device,
with_cache=True,
with_positions=True,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_decoder_hf_serializtion_roundtrip(torch_device):
15 changes: 0 additions & 15 deletions curated_transformers/tests/models/llama/test_causal_lm.py
@@ -48,21 +48,6 @@ def test_causal_lm_torch_compile(torch_device, model, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_causal_lm_with_torchscript_trace(torch_device, model, with_torch_sdp):
assert_causal_lm_output_equals_hf(
LlamaCausalLM,
model,
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
15 changes: 0 additions & 15 deletions curated_transformers/tests/models/llama/test_decoder.py
@@ -45,21 +45,6 @@ def test_decoder_with_torch_compile(torch_device, model, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, model, with_torch_sdp):
assert_decoder_output_equals_hf(
LlamaDecoder,
model,
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("model", LLAMA_TEST_MODELS)
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
14 changes: 0 additions & 14 deletions curated_transformers/tests/models/mpt/test_causal_lm.py
@@ -40,20 +40,6 @@ def test_causal_lm_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp):
assert_causal_lm_output_equals_hf(
MPTCausalLM,
"explosion-testing/mpt-test",
torch_device,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_causal_lm_hf_serializtion_roundtrip(torch_device):
16 changes: 0 additions & 16 deletions curated_transformers/tests/models/mpt/test_decoder.py
@@ -48,22 +48,6 @@ def test_decoder_with_torch_compile(torch_device, with_torch_sdp):
)


@pytest.mark.slow
@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
@pytest.mark.parametrize("with_torch_sdp", [False, True])
def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp):
assert_decoder_output_equals_hf(
MPTDecoder,
"explosion-testing/mpt-test",
torch_device,
with_cache=True,
with_positions=False,
jit_method=JITMethod.TorchScriptTrace,
with_torch_sdp=with_torch_sdp,
)


@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
@pytest.mark.parametrize("torch_device", TORCH_DEVICES)
def test_decoder_hf_serializtion_roundtrip(torch_device):