diff --git a/curated_transformers/layers/attention.py b/curated_transformers/layers/attention.py index 5c4fa225..4a31bc5f 100644 --- a/curated_transformers/layers/attention.py +++ b/curated_transformers/layers/attention.py @@ -12,7 +12,6 @@ from torch.nn import Dropout, Linear, Module from ..semver import Default, FutureMandatory -from ..util.dataclass import DataclassAsDict from .cache import KeyValueCache from .embeddings import QueryKeyRotaryEmbeddings @@ -47,7 +46,7 @@ def enable_torch_sdp(use_torch_sdp: bool = True): @dataclass -class AttentionMask(DataclassAsDict): +class AttentionMask: """ Mask for attention calculation. Sequence elements for which the corresponding mask element is set to ``False`` are ignored during @@ -74,33 +73,6 @@ def __init__(self, bool_mask: Tensor): "[batch_len, heads, query_len, key_len]" ) - self.__post_init__() - - @classmethod - def jit_rewrap( - cls: Type["AttentionMask"], - attention_mask: Union["AttentionMask", Dict[str, Tensor]], - ) -> "AttentionMask": - """ - Rewrap TorchScript dictionary conversion of an attention mask - as an ``AttentionMask``. - - :param attention_mask: - The attention mask or its dictionary representation. If the - value is an ``AttentionMask``, the value will be returned as-is. - :returns: - The attention mask. - """ - if isinstance(attention_mask, AttentionMask): - return attention_mask - - bool_mask = attention_mask.get("bool_mask") - if bool_mask is None: - raise ValueError( - "Attention mask is not of the `AttentionMask` type, nor a dict with 'bool_mask'." - ) - return AttentionMask(bool_mask=bool_mask) - def apply_logit_mask(self, input: Tensor) -> Tensor: """ Use the attention mask to mask attention logits. @@ -919,10 +891,6 @@ def forward( *Shape:* ``(batch_size, seq_len, width)`` """ - # The attention mask is converted to a dict for traced models. Rewrap as - # AttentionMask to get validation and utility methods. 
- attention_mask = AttentionMask.jit_rewrap(attention_mask) - query, key, value = self._query_key_value(input) if self.rotary_embeds is not None: @@ -930,9 +898,6 @@ def forward( query=query, key=key, cache=cache, positions=positions ) - # The key-value is converted to a dict for traced models. Rewrap as - # KeyValueCache to get validation and utility methods. - cache = KeyValueCache.jit_rewrap(cache) if cache is not None: cache_k = cache.key cache_v = cache.value @@ -953,7 +918,7 @@ def forward( output = self.output(attn) if store_cache: - return output, KeyValueCache(key, value) + return output, KeyValueCache(key=key, value=value) else: return output, None diff --git a/curated_transformers/layers/cache.py b/curated_transformers/layers/cache.py index b08cb560..1a4e22ff 100644 --- a/curated_transformers/layers/cache.py +++ b/curated_transformers/layers/cache.py @@ -4,8 +4,6 @@ import torch from torch import Tensor -from ..util.dataclass import DataclassAsTuple - CacheProtocolSelf = TypeVar("CacheProtocolSelf", bound="CacheProtocol") @@ -27,7 +25,7 @@ def filter_batch_items(self: CacheProtocolSelf, mask: Tensor) -> CacheProtocolSe @dataclass -class KeyValueCache(DataclassAsTuple): +class KeyValueCache: """ Cache type for layers that cache keys and values. @@ -52,41 +50,4 @@ def filter_batch_items(self, mask: Tensor) -> "KeyValueCache": if mask.dtype != torch.bool: raise ValueError(f"Cache mask dtype must be bool, was: {mask.dtype}.") - return KeyValueCache(self.key[mask], self.value[mask]) - - @classmethod - def jit_rewrap( - cls: Type["KeyValueCache"], - key_value_cache: Optional[Union["KeyValueCache", Tuple[Tensor, Tensor]]], - ) -> Optional["KeyValueCache"]: - """ - Rewrap TorchScript dictionary conversion of a key-value cache. - - :param key_value_cache: - The key-value cache or its dictionary representation. If the - value is a ``KeyValueCache`` or ``None``, it will be - returned as-is. - :returns: - The key-value cache. 
- """ - if key_value_cache is None or isinstance(key_value_cache, KeyValueCache): - return key_value_cache - - if ( - not isinstance(key_value_cache, tuple) - or len(key_value_cache) != 2 - or not all(isinstance(item, Tensor) for item in key_value_cache) - ): - raise ValueError( - f"Key-value cache is not of the `KeyValueCache` type, nor `Tuple[Tensor, Tensor]`: `{type(key_value_cache).__name__}`" - ) - - key_cache = key_value_cache[0] - value_cache = key_value_cache[1] - - if key_cache.shape != value_cache.shape: - raise ValueError( - f"Key cache ({key_cache.shape}) and value cache ({value_cache.shape}) must have same shapes." - ) - - return cls(key_cache, value_cache) + return KeyValueCache(key=self.key[mask], value=self.value[mask]) diff --git a/curated_transformers/layers/embeddings.py b/curated_transformers/layers/embeddings.py index 09f20c91..f29d8626 100644 --- a/curated_transformers/layers/embeddings.py +++ b/curated_transformers/layers/embeddings.py @@ -278,10 +278,6 @@ def forward( head_width = self.head_width rotary_width = self.rotary_width - # The key-value is converted to a dict for traced models. Rewrap as - # KeyValueCache to get validation and utility methods. - cache = KeyValueCache.jit_rewrap(cache) - # If a cache was provided, but no positions, assume that the # positions of the current batch continue from the cache. if cache is not None and positions is None: diff --git a/curated_transformers/models/output.py b/curated_transformers/models/output.py index 730586c2..db8f7dec 100644 --- a/curated_transformers/models/output.py +++ b/curated_transformers/models/output.py @@ -4,13 +4,12 @@ from torch import Tensor from ..layers.cache import CacheProtocol -from ..util.dataclass import DataclassAsTuple CacheT = TypeVar("CacheT", bound=CacheProtocol) @dataclass -class ModelOutput(DataclassAsTuple): +class ModelOutput: """ Base class for model outputs. 
diff --git a/curated_transformers/tests/models/albert/test_encoder.py b/curated_transformers/tests/models/albert/test_encoder.py index 6be65b27..dd6e9f79 100644 --- a/curated_transformers/tests/models/albert/test_encoder.py +++ b/curated_transformers/tests/models/albert/test_encoder.py @@ -44,20 +44,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): - assert_encoder_output_equals_hf( - ALBERTEncoder, - "explosion-testing/albert-test", - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_encoder_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/models/bert/test_encoder.py b/curated_transformers/tests/models/bert/test_encoder.py index 3736ab09..174c2551 100644 --- a/curated_transformers/tests/models/bert/test_encoder.py +++ b/curated_transformers/tests/models/bert/test_encoder.py @@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): - assert_encoder_output_equals_hf( - BERTEncoder, - "explosion-testing/bert-test", - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") 
@pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_encoder_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/models/camembert/test_encoder.py b/curated_transformers/tests/models/camembert/test_encoder.py index 0eeb6e73..a5c921dd 100644 --- a/curated_transformers/tests/models/camembert/test_encoder.py +++ b/curated_transformers/tests/models/camembert/test_encoder.py @@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): - assert_encoder_output_equals_hf( - CamemBERTEncoder, - "explosion-testing/camembert-test", - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_encoder_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/models/falcon/test_decoder.py b/curated_transformers/tests/models/falcon/test_decoder.py index 144d520e..79c100a1 100644 --- a/curated_transformers/tests/models/falcon/test_decoder.py +++ b/curated_transformers/tests/models/falcon/test_decoder.py @@ -100,27 +100,6 @@ def test_decoder_with_torch_compile(torch_device, model_revision, with_torch_sdp ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("model_revision", FALCON_TEST_MODELS) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_decoder_with_torchscript_trace(torch_device, model_revision, with_torch_sdp): - model, revision = model_revision - 
assert_decoder_output_equals_hf( - FalconDecoder, - model, - torch_device, - model_revision=revision, - trust_remote_code=True, - with_cache=False, - with_mask=False, - with_positions=False, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) @pytest.mark.parametrize("model_revision", FALCON_TEST_MODELS) diff --git a/curated_transformers/tests/models/gpt_neox/test_causal_lm.py b/curated_transformers/tests/models/gpt_neox/test_causal_lm.py index 055f16c4..7a072191 100644 --- a/curated_transformers/tests/models/gpt_neox/test_causal_lm.py +++ b/curated_transformers/tests/models/gpt_neox/test_causal_lm.py @@ -39,20 +39,6 @@ def test_causal_lm_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp): - assert_causal_lm_output_equals_hf( - GPTNeoXCausalLM, - "trl-internal-testing/tiny-random-GPTNeoXForCausalLM", - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_causal_lm_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/models/gpt_neox/test_decoder.py b/curated_transformers/tests/models/gpt_neox/test_decoder.py index 4a782bd6..11b60ffe 100644 --- a/curated_transformers/tests/models/gpt_neox/test_decoder.py +++ b/curated_transformers/tests/models/gpt_neox/test_decoder.py @@ -45,22 +45,6 @@ def test_decoder_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow 
-@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp): - assert_decoder_output_equals_hf( - GPTNeoXDecoder, - "trl-internal-testing/tiny-random-GPTNeoXForCausalLM", - torch_device, - with_cache=True, - with_positions=True, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_decoder_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/models/llama/test_causal_lm.py b/curated_transformers/tests/models/llama/test_causal_lm.py index 04ac0461..2b31e83e 100644 --- a/curated_transformers/tests/models/llama/test_causal_lm.py +++ b/curated_transformers/tests/models/llama/test_causal_lm.py @@ -48,21 +48,6 @@ def test_causal_lm_torch_compile(torch_device, model, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("model", LLAMA_TEST_MODELS) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_causal_lm_with_torchscript_trace(torch_device, model, with_torch_sdp): - assert_causal_lm_output_equals_hf( - LlamaCausalLM, - model, - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("model", LLAMA_TEST_MODELS) @pytest.mark.parametrize("torch_device", TORCH_DEVICES) diff --git a/curated_transformers/tests/models/llama/test_decoder.py b/curated_transformers/tests/models/llama/test_decoder.py index b07ede54..9467e16b 100644 --- 
a/curated_transformers/tests/models/llama/test_decoder.py +++ b/curated_transformers/tests/models/llama/test_decoder.py @@ -45,21 +45,6 @@ def test_decoder_with_torch_compile(torch_device, model, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("model", LLAMA_TEST_MODELS) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_decoder_with_torchscript_trace(torch_device, model, with_torch_sdp): - assert_decoder_output_equals_hf( - LlamaDecoder, - model, - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("model", LLAMA_TEST_MODELS) @pytest.mark.parametrize("torch_device", TORCH_DEVICES) diff --git a/curated_transformers/tests/models/mpt/test_causal_lm.py b/curated_transformers/tests/models/mpt/test_causal_lm.py index 1efeb583..53409e2f 100644 --- a/curated_transformers/tests/models/mpt/test_causal_lm.py +++ b/curated_transformers/tests/models/mpt/test_causal_lm.py @@ -40,20 +40,6 @@ def test_causal_lm_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_causal_lm_with_torchscript_trace(torch_device, with_torch_sdp): - assert_causal_lm_output_equals_hf( - MPTCausalLM, - "explosion-testing/mpt-test", - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_causal_lm_hf_serializtion_roundtrip(torch_device): diff --git 
a/curated_transformers/tests/models/mpt/test_decoder.py b/curated_transformers/tests/models/mpt/test_decoder.py index 185397f3..f29b3bf0 100644 --- a/curated_transformers/tests/models/mpt/test_decoder.py +++ b/curated_transformers/tests/models/mpt/test_decoder.py @@ -48,22 +48,6 @@ def test_decoder_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_decoder_with_torchscript_trace(torch_device, with_torch_sdp): - assert_decoder_output_equals_hf( - MPTDecoder, - "explosion-testing/mpt-test", - torch_device, - with_cache=True, - with_positions=False, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_decoder_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/models/roberta/test_encoder.py b/curated_transformers/tests/models/roberta/test_encoder.py index 73739fb1..80485bd2 100644 --- a/curated_transformers/tests/models/roberta/test_encoder.py +++ b/curated_transformers/tests/models/roberta/test_encoder.py @@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): - assert_encoder_output_equals_hf( - RoBERTaEncoder, - "explosion-testing/roberta-test", - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires 
huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_encoder_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py index f9bf0847..74bfdef2 100644 --- a/curated_transformers/tests/models/util.py +++ b/curated_transformers/tests/models/util.py @@ -58,28 +58,16 @@ def forward( class JITMethod(Enum): Disable = 0 TorchCompile = 1 - TorchScriptTrace = 2 def convert(self, model: Module, with_torch_sdp: bool, *args) -> Tuple[ - Union[Module, torch.ScriptModule], + Module, Callable[[Union[ModelOutput, Dict[str, torch.Tensor]]], Tensor], ]: with enable_torch_sdp(with_torch_sdp): - if self == JITMethod.Disable: - return model, lambda s: s - elif self == JITMethod.TorchCompile: + if self == JITMethod.TorchCompile: return torch.compile(model), lambda s: s - else: - if isinstance(model, EncoderModule): - cls = ModelOutput - elif isinstance(model, DecoderModule): - cls = ModelOutputWithCache - elif isinstance(model, CausalLMModule): - cls = ModelOutputWithCache - return ( - torch.jit.trace(model, tuple(args)), - lambda s: s, - ) + else: # JITMethod.Disable + return model, lambda s: s def assert_causal_lm_output_equals_hf( @@ -117,7 +105,7 @@ def assert_causal_lm_output_equals_hf( X = torch.randint(0, hf_model.config.vocab_size, (2, 10), device=torch_device) mask = torch.ones_like(X, dtype=torch.bool) with torch.no_grad(): - Y = get_output(model(X, AttentionMask(mask)))[1] + Y = get_output(model(X, AttentionMask(mask))).logits Y_hf = hf_model(X).logits torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol) @@ -128,7 +116,7 @@ def assert_causal_lm_output_equals_hf( mask = torch.rand((2, 10), dtype=torch.float, device=torch_device) < 0.5 with torch.no_grad(): - Y = get_output(model(X, AttentionMask(mask)))[1] * mask.unsqueeze(-1) + Y = get_output(model(X, AttentionMask(mask))).logits * mask.unsqueeze(-1) Y_hf = hf_model(X, attention_mask=mask).logits * 
mask.unsqueeze(-1) torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol) @@ -173,7 +161,7 @@ def assert_decoder_output_equals_hf( X = torch.randint(0, hf_model.config.vocab_size, (2, 10), device=torch_device) mask = torch.ones_like(X, dtype=torch.bool) with torch.no_grad(): - Y = output(model(X, AttentionMask(mask)))[0][-1] + Y = output(model(X, AttentionMask(mask))).last_hidden_layer_state Y_hf = hf_model(X).last_hidden_state torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol) @@ -234,7 +222,7 @@ def assert_encoder_output_equals_hf( mask = torch.ones_like(X, dtype=torch.bool) with torch.no_grad(): - Y = output(model(X, AttentionMask(mask)))[0][-1] + Y = output(model(X, AttentionMask(mask))).last_hidden_layer_state Y_hf = hf_model(X).last_hidden_state torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol) @@ -276,7 +264,7 @@ def assert_decoder_with_cache_output_equals_hf( device=torch_device, ) empty_cache_jit = [ - KeyValueCache(empty_kv_jit, empty_kv_jit) + KeyValueCache(key=empty_kv_jit, value=empty_kv_jit) ] * hf_model.config.num_hidden_layers X = torch.randint(0, hf_model.config.vocab_size, (2, 10), device=torch_device) @@ -286,7 +274,9 @@ def assert_decoder_with_cache_output_equals_hf( with torch.no_grad(): Y = model(X, AttentionMask(mask), empty_cache_jit) Y_hf = hf_model(X, use_cache=True) - Y = output(model(X_rest, AttentionMask(mask_rest), cache=output(Y)[1]))[0][-1] + Y = output( + model(X_rest, AttentionMask(mask_rest), cache=output(Y).cache) + ).last_hidden_layer_state Y_hf = hf_model(X_rest, past_key_values=Y_hf.past_key_values).last_hidden_state torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol) @@ -310,7 +300,9 @@ def assert_with_mask_output_equals_hf( X = torch.randint(0, hf_model.config.vocab_size, (2, 10), device=torch_device) mask = torch.rand((2, 10), dtype=torch.float, device=torch_device) < 0.5 with torch.no_grad(): - Y = output(model(X, AttentionMask(mask)))[0][-1] * mask.unsqueeze(-1) + Y = output( + model(X, AttentionMask(mask)) + 
).last_hidden_layer_state * mask.unsqueeze(-1) Y_hf = hf_model(X, attention_mask=mask).last_hidden_state * mask.unsqueeze(-1) torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol) @@ -349,7 +341,9 @@ def assert_decoder_with_positions_equals_hf( mask = torch.ones_like(X, dtype=torch.bool) positions = torch.randint(0, 10, (2, 10), device=torch_device) with torch.no_grad(): - Y = output(model(X, AttentionMask(mask), positions=positions))[0][-1] + Y = output( + model(X, AttentionMask(mask), positions=positions) + ).last_hidden_layer_state Y_hf = hf_model(X, position_ids=positions).last_hidden_state torch_assertclose(Y, Y_hf, atol=atol, rtol=rtol) diff --git a/curated_transformers/tests/models/xlm_roberta/test_encoder.py b/curated_transformers/tests/models/xlm_roberta/test_encoder.py index 1fc3773e..2130d4bf 100644 --- a/curated_transformers/tests/models/xlm_roberta/test_encoder.py +++ b/curated_transformers/tests/models/xlm_roberta/test_encoder.py @@ -38,20 +38,6 @@ def test_encoder_with_torch_compile(torch_device, with_torch_sdp): ) -@pytest.mark.slow -@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") -@pytest.mark.parametrize("torch_device", TORCH_DEVICES) -@pytest.mark.parametrize("with_torch_sdp", [False, True]) -def test_encoder_with_torchscript_trace(torch_device, with_torch_sdp): - assert_encoder_output_equals_hf( - XLMREncoder, - "explosion-testing/xlm-roberta-test", - torch_device, - jit_method=JITMethod.TorchScriptTrace, - with_torch_sdp=with_torch_sdp, - ) - - @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers") @pytest.mark.parametrize("torch_device", TORCH_DEVICES) def test_encoder_hf_serializtion_roundtrip(torch_device): diff --git a/curated_transformers/tests/util/test_dataclass.py b/curated_transformers/tests/util/test_dataclass.py deleted file mode 100644 index 92844126..00000000 --- a/curated_transformers/tests/util/test_dataclass.py +++ /dev/null @@ -1,112 +0,0 @@ -from dataclasses 
import dataclass -from typing import Any, List, Optional, Tuple - -import pytest -import torch -from torch import Tensor - -from curated_transformers.util.dataclass import DataclassAsDict, DataclassAsTuple - - -@dataclass -class AsDict(DataclassAsDict): - foo: Tensor - bar: Tensor - - -@dataclass -class AsTuple(DataclassAsTuple): - foo: Tensor - bar: List[Tensor] - baz: Optional[Tensor] - quux: List[Tuple[Any]] - - -@dataclass -class InvalidAsTuple(DataclassAsTuple): - bar: int - - -@dataclass -class InvalidAsDict(DataclassAsDict): - foo: Tensor - bar: int - - -def test_as_dict(): - d = AsDict(foo=torch.zeros(1, 1), bar=torch.full((2, 2), 42)) - assert "foo" in d - assert "bar" in d - assert d["foo"] is d.foo - assert d["bar"] is d.bar - - d.bar = torch.full((3, 3), 80) - assert d["bar"] is d.bar - - d["foo"] = torch.ones((4, 4)) - assert d["foo"] is d.foo - - -def test_as_dict_non_tensor_member_rejected(): - with pytest.raises(TypeError, match=r"bar.*int"): - InvalidAsDict(foo=torch.zeros(1, 1), bar=42) - - -def test_as_dict_invalid_operations(): - d = AsDict(foo=torch.zeros(1, 1), bar=torch.full((2, 2), 42)) - with pytest.raises(TypeError, match=r"foo.*str"): - d.foo = "Glove80" - with pytest.raises(TypeError, match=r"foo.*str"): - d["foo"] = "Glove80" - with pytest.raises(TypeError, match=r"non-`str`.*int"): - d[83] = "Glove80" - with pytest.raises(NotImplementedError): - del d["foo"] - with pytest.raises(NotImplementedError): - delattr(d, "foo") - - -def test_as_tuple(): - t = AsTuple( - torch.full((2, 2), 42), - [torch.ones(3, 3), torch.zeros(4, 4)], - None, - [(42, 80)], - ) - - assert t[0] is t.foo - - assert len(t[1]) == 2 - assert t[1][0] is t.bar[0] - assert t[1][1] is t.bar[1] - - assert len(t[2]) == 1 - assert t[2][0] is t.quux[0] - - -def test_as_tuple_incorrect(): - with pytest.raises(TypeError, match=r"Tensor, str"): - AsTuple( - torch.full((2, 2), 42), - [torch.ones(3, 3), "Glove80"], - None, - [AsDict(torch.full((5, 5), 80), torch.full((6, 6), 
-1))], - ) - - with pytest.raises(TypeError, match=r"unsupported.*int"): - InvalidAsTuple(42) - - -def test_as_tuple_is_immutable(): - t = AsTuple( - torch.full((2, 2), 42), - [torch.ones(3, 3), torch.zeros(4, 4)], - None, - [(42, 80)], - ) - - with pytest.raises(TypeError, match=r"does not support attribute assignment"): - t.foo = torch.zeros(5, 5) - - with pytest.raises(TypeError, match=r"does not support attribute deletion"): - del t.foo diff --git a/curated_transformers/util/dataclass.py b/curated_transformers/util/dataclass.py deleted file mode 100644 index 59696981..00000000 --- a/curated_transformers/util/dataclass.py +++ /dev/null @@ -1,144 +0,0 @@ -from dataclasses import fields -from typing import Any, Generator, OrderedDict - -from torch import Tensor - - -class DataclassAsDict(OrderedDict[str, Tensor]): - """ - Dataclasses that derive from this struct are also a dictionary. - - Since this class should only be used for dataclasses that are - ``torch.jit.trace``-able, only ``Tensor`` fields are supported. - - Only dataclass fields and keys corresponding to those fields - can be changed. Fields and keys cannot be removed. 
- """ - - def __post_init__(self): - for field in fields(self): - value = getattr(self, field.name) - if not isinstance(value, Tensor): - raise TypeError( - f"`DataclassAsDict` only supports `Tensor` members, but field '{field.name}' has type `{field.type.__name__}`" - ) - - super().__setitem__(field.name, value) - - def __delitem__(self, key: str): - raise NotImplementedError() - - def __delattr__(self, name: str): - raise NotImplementedError() - - def __setattr__(self, name: str, value: Any) -> None: - if not isinstance(value, Tensor): - raise TypeError( - f"Field '{name}' cannot be set to non-Tensor type `{type(value).__name__}`" - ) - - super().__setitem__(name, value) - super().__setattr__(name, value) - - def __setitem__(self, key: str, value: Tensor) -> None: - if not isinstance(key, str): - raise TypeError( - f"Key cannot be set to non-`str` type `{type(key).__name__}`" - ) - - if not isinstance(value, Tensor): - raise TypeError( - f"Field '{key}' cannot be set to non-Tensor type `{type(value).__name__}`" - ) - - super().__setattr__(key, value) - super().__setitem__(key, value) - - -class _InterceptGeneratorMeta(type): - """ - Tuples can take a generator as their only constructor argument, - dataclasses will see them them as a single argument. Intercept - single generator argument, evaluate the generator, and pass - the values as args. - """ - - def __new__(cls, name, bases, dict): - # Add tuple as a base class. MyPy complains about having tuple - # directly as a base class. See: https://github.com/python/mypy/issues/14818 - return super().__new__(cls, name, bases + (tuple,), dict) - - def __call__(cls, *args, **kwargs): - # Convert generator argument to a tuple, so that we can pass - # the values as regular arguments. 
- if len(args) == 1 and kwargs == {} and isinstance(args[0], Generator): - args = tuple(args[0]) - - obj = super().__call__(*args, **kwargs) - obj._is_frozen = True - - return obj - - -class DataclassAsTuple(metaclass=_InterceptGeneratorMeta): - """ - Dataclasses that derive from this class are also a tuple. - - Since this class should only be used for dataclasses that are - ``torch.jit.trace``-able, only the following types of fields - are supported: - - * ``Tensor`` - * ``List[Tensor]`` - * ``List[Tuple[...]]`` - * ``Optional`` of any of the above. - - Fields that have the value ``None`` are skipped. - """ - - def __new__(cls, *args, **kwargs): - values = [] - for idx, field in enumerate(fields(cls)): - if idx < len(args): - value = args[idx] - elif field.name in kwargs: - value = kwargs[field.name] - else: - # Field is not specified, consider it optional. If it is - # a mandatory field, the dataclass machinery will complain - # later. - continue - - if value is None: - continue - - values.append(DataclassAsTuple._convert_value(value)) - - return tuple.__new__(cls, values) - - @staticmethod - def _convert_value(value): - if isinstance(value, Tensor): - return value - elif isinstance(value, list): - if all(isinstance(item, Tensor) for item in value): - return tuple(value) - elif all(isinstance(item, tuple) for item in value): - return tuple(value) - else: - type_names = ", ".join(sorted({type(item).__name__ for item in value})) - raise TypeError( - f"List must be `List[Tensor]` or `List[Tuple[...]]`, found types: {type_names}" - ) - else: - raise TypeError(f"Field has unsupported type `{type(value).__name__}`") - - def __delattr__(self, name: str): - if getattr(self, "_is_frozen", False): - raise TypeError("`AsTuple` object does not support attribute deletion") - super().__delattr__(name) - - def __setattr__(self, name: str, value: Any) -> None: - if getattr(self, "_is_frozen", False): - raise TypeError("`AsTuple` object does not support attribute assignment") - 
super().__setattr__(name, value) diff --git a/docs/source/deployment.rst b/docs/source/deployment.rst index 7b41152f..c741b1a1 100644 --- a/docs/source/deployment.rst +++ b/docs/source/deployment.rst @@ -6,107 +6,4 @@ e.g., in a `Flask-based REST service `_. In other cases, such as deployment to non-CUDA accelerators, additional model transformations might be needed. On this page, we cover several deployment -scenarios. - -TorchScript Tracing -------------------- - -Many deployment methods start from `TorchScript`_. For instance, ONNX conversion -converts the TorchScript representation of a model. TorchScript is a -statically-typed subset of Python. It only supports the types that are necessary -for representing neural network models. - -Curated Transformers supports TorchScript through `tracing`_. -Tracing runs the model with some example inputs and records the computation -graph. The TorchScript code is then generated from this computation graph, -discarding all other Python code. - -.. _TorchScript: https://pytorch.org/docs/stable/jit.html -.. _tracing: https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace - -Tracing a Model -^^^^^^^^^^^^^^^ - -Models are traced using the |torch.jit.trace|_ function. The first argument to -this function is the model that you would like to trace, the second argument the -inputs as a tuple. For example, we can trace a decoder as follows: - - -.. 
code-block:: python - - import torch - import torch.jit - from curated_transformers.layers import AttentionMask - from curated_transformers.models import AutoDecoder - - device = torch.device("cuda", index=0) - decoder = AutoDecoder.from_hf_hub(name="tiiuae/falcon-7b", device=device) - X_example = torch.zeros(4, 20, dtype=torch.long, device=device) - mask_example = AttentionMask(torch.ones((4, 20), dtype=torch.bool, device=device)) - traced = torch.jit.trace(decoder, (X_example, mask_example)) - -As you can see, we are feeding the model with an all-zeros piece identifier tensor and -an all-ones mask tensor during tracing. This is not really an issue - as long as -the inputs allow the model to run normally, tracing can do its work. - -In the example above, ``traced`` is a TorchScript module. From the surface, it -behaves like any other module. We can feed it some piece identifiers to get -their hidden representations: - -.. code-block:: python - - from curated_transformers.tokenizers import AutoTokenizer - - tokenizer = AutoTokenizer.from_hf_hub(name="tiiuae/falcon-7b") - pieces = tokenizer(["Hello world!", "This is a test"]) - Y = traced(pieces.padded_tensor(device=device), - pieces.attention_mask(device=device)) - assert isinstance(Y, tuple) - last_layer = Y[0][-1] - -The model works as before, albeit with one catch. Normally a decoder returns a -:py:class:`~curated_transformers.models.ModelOutputWithCache` instance, -but the traced model returns a tuple instead. The reason is that TorchScript only -supports a limited set of types. Since arbitrary types are not supported, we -convert the :py:class:`~curated_transformers.models.ModelOutputWithCache` -instance to a tuple in a traced model. The tuple will have the same ordering as the -fields in the untraced model's output, excluding fields that are set to -``None``. In this case we don't ask the decoder to return a key-value cache, so -the ``cache`` field is ``None`` and will not be represented in the tuple. 
- -Handling Complex Model Signatures -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The previous section describes how we can trace a model. In some cases, it can be -difficult to provide a working argument tuple to |torch.jit.trace|_. Suppose -that we would like to trace a decoder with an attention mask and positions, but -without using a cache. In the -:py:class:`~curated_transformers.models.DecoderModule` API, the ``cache`` -argument is interspersed between the ``attention_mask`` and ``positions`` -arguments. This turns out to be problematic since we cannot pass ``None`` -arguments to the |torch.jit.trace|_ function. While |torch.jit.trace|_ provides an -``example_kwarg_inputs`` parameter to pass arguments by keyword, we have -found that this mechanism often skips over arguments. - -In such cases, we recommend you to make a simple wrapper around a model that only -has the desired arguments. For instance, in the above case you could define a -class ``DecoderWithPositions``: - -.. code-block:: python - - class DecoderWithPositions(Module): - def __init__(self, decoder: DecoderModule): - super().__init__() - self.inner = decoder - - def forward(self, - input_ids: Tensor, - attention_mask: AttentionMask, - positions: Tensor): - return self.inner.forward(input_ids, attention_mask, positions=positions) - -You can then wrap a decoder with this class and trace it using the two mandatory -arguments. - -.. |torch.jit.trace| replace:: ``torch.jit.trace`` -.. _torch.jit.trace: https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace +scenarios in the future. diff --git a/docs/source/develop.rst b/docs/source/develop.rst index 64253d36..1a0ef088 100644 --- a/docs/source/develop.rst +++ b/docs/source/develop.rst @@ -13,54 +13,3 @@ We use two branches during regular development. If the current version is Following semver, only bug fixes must be pushed to ``v1.0.x``. When applicable, a bug should first be fixed in the ``main`` branch using a PR. 
After that, a backport PR should be made for the ``v1.0.x`` branch with the *backport* label. - -TorchScript ------------ - -Tracing -^^^^^^^ - -We support TorchScript tracing and test it with all models when using -the ``--slow`` flag. - -Tracing only accepts a small number of types for the arguments and return values -of a traced module. For our purposes, these types are: ``Tensor``, ``Dict[str, -Tensor]``, or tuples of these types. This has ramifications for our models -because they take different argument types (e.g., ``AttentionMask`` and -``KeyValueCache``) and return ``ModelOutput`` or one of its subclasses. What -complicates this is that we want to keep strong typing outside TorchScript. We -have addressed these issues as described below. - -Module Arguments -"""""""""""""""" - -Our argument types are dataclasses with only ``Tensor`` fields. These types can -be represented as ``Dict[str, Tensor]`` without any loss of information. To this -end, we have made a ``DataclassAsDict`` base class. Dataclasses that inherit -from this class are also proper dictionaries. This allows us to pass these data -structures to traced models. When such a type is passed to a traced model, the -original type information is erased and inside the model, the argument will be a -regular dictionary. To handle these arguments uniformly and retain access to -utility methods and properties, we rewrap the dictionary as a class. For instance, a -method that uses ``AttentionMask`` can rewrap ``Union[AttentionMask, Dict[str, -Tensor]]`` as an ``AttentionMask``: - -.. code-block:: python - - attention_mask = AttentionMask.jit_rewrap(attention_mask) - -Module Return Values -"""""""""""""""""""" - -The ``ModelOutput``-based return types can contain nested dataclasses. For -instance, ``ModelOutputWithCache`` contains an ``Optional[List[CacheT]]`` field -where ``CacheT`` can be ``KeyValueCache``. Consequently, not every -``ModelOutput`` can be represented as a ``Dict[str, Tensor]``.
For that reason, -we represent model outputs as tuples instead. Dataclasses that inherit from -``DataclassAsTuple`` are also a tuple. - -Scripting -^^^^^^^^^ - -We **do not** support TorchScript scripting, since it would require too many -compromises to code quality (e.g., we cannot use ``torch.finfo``).