From 4055d7e5a12105367ca1c3224bd1e3d1d3e362a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?=
Date: Sun, 11 Feb 2024 09:03:25 +0100
Subject: [PATCH] Set torch upper bound to <2.1.0 (#363)

* Set torch upper bound to <2.1.0

Some changes in PyTorch 2.1.0 and later are incompatible with Curated
Transformers 1.x. Fixing these issues would require API changes, so we
set an upper bound on the supported PyTorch versions.

We will soon release Curated Transformers 2.0.0, which is compatible
with the latest PyTorch versions.

* black
---
 curated_transformers/models/falcon/layer.py | 8 +++++---
 curated_transformers/tests/models/util.py   | 4 +---
 .../tokenizers/legacy/legacy_tokenizer.py   | 6 ++----
 requirements.txt                            | 2 +-
 setup.cfg                                   | 2 +-
 5 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/curated_transformers/models/falcon/layer.py b/curated_transformers/models/falcon/layer.py
index 003156dc..5a740911 100644
--- a/curated_transformers/models/falcon/layer.py
+++ b/curated_transformers/models/falcon/layer.py
@@ -70,9 +70,11 @@ def __init__(
                 n_key_value_heads=attention_config.n_key_value_heads,
             ),
             rotary_embeds=rotary_embeds,
-            qkv_mode=QkvMode.MERGED_SPLIT_AFTER
-            if attention_config.n_key_value_heads == 1
-            else QkvMode.MERGED_SPLIT_BEFORE,
+            qkv_mode=(
+                QkvMode.MERGED_SPLIT_AFTER
+                if attention_config.n_key_value_heads == 1
+                else QkvMode.MERGED_SPLIT_BEFORE
+            ),
             use_bias=attention_config.use_bias,
             device=device,
         )
diff --git a/curated_transformers/tests/models/util.py b/curated_transformers/tests/models/util.py
index d9a12926..c229a288 100644
--- a/curated_transformers/tests/models/util.py
+++ b/curated_transformers/tests/models/util.py
@@ -58,9 +58,7 @@ class JITMethod(Enum):
     TorchCompile = 1
     TorchScriptTrace = 2
 
-    def convert(
-        self, model: Module, with_torch_sdp: bool, *args
-    ) -> Tuple[
+    def convert(self, model: Module, with_torch_sdp: bool, *args) -> Tuple[
         Union[Module, torch.ScriptModule],
         Callable[[Union[ModelOutput, Dict[str, torch.Tensor]]], Tensor],
     ]:
diff --git a/curated_transformers/tokenizers/legacy/legacy_tokenizer.py b/curated_transformers/tokenizers/legacy/legacy_tokenizer.py
index 12de0014..9c05cd3a 100644
--- a/curated_transformers/tokenizers/legacy/legacy_tokenizer.py
+++ b/curated_transformers/tokenizers/legacy/legacy_tokenizer.py
@@ -155,12 +155,10 @@ def _convert_strings(
     @abstractmethod
     def _decode(
         self, input: Iterable[Iterable[int]], skip_special_pieces: bool
-    ) -> List[str]:
-        ...
+    ) -> List[str]: ...
 
     @abstractmethod
-    def _encode(self, input: Iterable[MergedInputChunks]) -> PiecesWithIds:
-        ...
+    def _encode(self, input: Iterable[MergedInputChunks]) -> PiecesWithIds: ...
 
 
 class AddBosEosPreEncoder(PreEncoder):
diff --git a/requirements.txt b/requirements.txt
index 8d0472b6..5945169e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 curated-tokenizers>=0.9.1,<1.0.0
 huggingface-hub>=0.14
 tokenizers>=0.13.3
-torch>=1.12.0
+torch>=1.12.0,<2.1.0
 
 # Development dependencies
 mypy>=0.990,<1.1.0; platform_machine != "aarch64"
diff --git a/setup.cfg b/setup.cfg
index 78f793df..3697b742 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,7 +17,7 @@ install_requires =
     curated-tokenizers>=0.9.1,<1.0.0
     huggingface-hub>=0.14
     tokenizers>=0.13.3
-    torch>=1.12.0
+    torch>=1.12.0,<2.1.0
 
 [options.extras_require]
 quantization =
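
Not part of the patch itself: a minimal sketch of how downstream code pinned to
Curated Transformers 1.x could assert the same PyTorch bound at runtime, before
a 2.1.0+ incompatibility surfaces as a confusing error. The guard function name
is hypothetical, and using packaging for version comparison is our assumption,
not something this patch prescribes.

import torch
from packaging.version import Version

# The bound set in requirements.txt/setup.cfg by this patch: torch>=1.12.0,<2.1.0.
SUPPORTED_RANGE = (Version("1.12.0"), Version("2.1.0"))


def check_torch_compat() -> None:
    """Raise early if the installed PyTorch falls outside the supported range."""
    # Strip any local build tag such as "+cu118" before parsing.
    installed = Version(torch.__version__.split("+")[0])
    lower, upper = SUPPORTED_RANGE
    if not (lower <= installed < upper):
        raise RuntimeError(
            f"Curated Transformers 1.x requires torch>=1.12.0,<2.1.0; "
            f"found {torch.__version__}."
        )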