diff --git a/curated_transformers/repository/repository.py b/curated_transformers/repository/repository.py
index 32b74450..8d1e3610 100644
--- a/curated_transformers/repository/repository.py
+++ b/curated_transformers/repository/repository.py
@@ -4,7 +4,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 from .._compat import has_safetensors
-from ..util.serde.checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
+from ..util.serde.checkpoint import ModelCheckpointType
 from ._hf import (
     HF_MODEL_CONFIG,
     HF_TOKENIZER_CONFIG,
@@ -146,21 +146,16 @@ def get_checkpoint_paths(
 
             return checkpoint_paths
 
-        checkpoint_type = _MODEL_CHECKPOINT_TYPE.get()
+        # Precedence: Safetensors > PyTorch
+        checkpoint_type = ModelCheckpointType.SAFE_TENSORS
         checkpoint_paths: Optional[List[RepositoryFile]] = None
-
-        if checkpoint_type is None:
-            # Precedence: Safetensors > PyTorch
-            if has_safetensors:
-                try:
-                    checkpoint_type = ModelCheckpointType.SAFE_TENSORS
-                    checkpoint_paths = get_checkpoint_paths(checkpoint_type)
-                except OSError:
-                    pass
-            if checkpoint_paths is None:
-                checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT
+        if has_safetensors:
+            try:
                 checkpoint_paths = get_checkpoint_paths(checkpoint_type)
-        else:
+            except OSError:
+                pass
+        if checkpoint_paths is None:
+            checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT
             checkpoint_paths = get_checkpoint_paths(checkpoint_type)
 
         assert checkpoint_paths is not None
diff --git a/curated_transformers/tests/models/test_hf_hub.py b/curated_transformers/tests/models/test_hf_hub.py
index 55a5f229..9b57de3f 100644
--- a/curated_transformers/tests/models/test_hf_hub.py
+++ b/curated_transformers/tests/models/test_hf_hub.py
@@ -7,7 +7,6 @@
 from curated_transformers.repository.hf_hub import HfHubRepository
 from curated_transformers.repository.repository import ModelRepository
 from curated_transformers.util.serde.checkpoint import ModelCheckpointType
-from curated_transformers.util.serde.load import _use_model_checkpoint_type
 
 from ..compat import has_hf_transformers, has_safetensors
 from ..conftest import TORCH_DEVICES
@@ -61,10 +60,6 @@ def test_checkpoint_type_without_safetensors():
     assert Path(ckp_paths[0].path).suffix == ".bin"
     assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT
 
-    with pytest.raises(ValueError, match="`safetensors` library is required"):
-        with _use_model_checkpoint_type(ModelCheckpointType.SAFE_TENSORS):
-            BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
-
 
 @pytest.mark.skipif(not has_safetensors, reason="requires huggingface safetensors")
 def test_checkpoint_type_with_safetensors():
@@ -81,30 +76,6 @@
     encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
 
 
-@pytest.mark.skipif(not has_safetensors, reason="requires huggingface safetensors")
-def test_forced_checkpoint_type():
-    with _use_model_checkpoint_type(ModelCheckpointType.PYTORCH_STATE_DICT):
-        repo = ModelRepository(
-            HfHubRepository(
-                "explosion-testing/safetensors-sharded-test", revision="main"
-            )
-        )
-        ckp_paths, ckp_type = repo.model_checkpoints()
-        assert len(ckp_paths) == 3
-        assert all(Path(p.path).suffix == ".bin" for p in ckp_paths)
-        assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT
-
-        encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
-
-    with _use_model_checkpoint_type(ModelCheckpointType.SAFE_TENSORS):
-        ckp_paths, ckp_type = repo.model_checkpoints()
-        assert len(ckp_paths) == 3
-        assert all(Path(p.path).suffix == ".safetensors" for p in ckp_paths)
-        assert ckp_type == ModelCheckpointType.SAFE_TENSORS
-
-        encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
-
-
 @pytest.mark.slow
 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
 @pytest.mark.parametrize("torch_device", TORCH_DEVICES)
diff --git a/curated_transformers/util/serde/checkpoint.py b/curated_transformers/util/serde/checkpoint.py
index 70b348c1..c69c24ad 100644
--- a/curated_transformers/util/serde/checkpoint.py
+++ b/curated_transformers/util/serde/checkpoint.py
@@ -1,6 +1,5 @@
-from contextvars import ContextVar
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional
+from typing import TYPE_CHECKING, Callable, Iterable, Mapping
 
 import torch
 
@@ -42,12 +41,6 @@ def pretty_name(self) -> str:
         return ""
 
 
-# When `None`, behaviour is implementation-specific.
-_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar(
-    "model_checkpoint_type", default=None
-)
-
-
 def _load_safetensor_state_dicts_from_checkpoints(
     checkpoints: Iterable[RepositoryFile],
 ) -> Iterable[Mapping[str, torch.Tensor]]:
diff --git a/curated_transformers/util/serde/load.py b/curated_transformers/util/serde/load.py
index 05939e86..0243dbf2 100644
--- a/curated_transformers/util/serde/load.py
+++ b/curated_transformers/util/serde/load.py
@@ -6,7 +6,7 @@
 
 from ...repository.file import RepositoryFile
 from ..pytorch import ModuleIterator, apply_to_module
-from .checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
+from .checkpoint import ModelCheckpointType
 
 # Args: Parent module, module prefix, parameter name, tensor to convert, device.
 # Returns the new paramater.
@@ -21,29 +21,6 @@
 ]
 
 
-@contextmanager
-def _use_model_checkpoint_type(
-    model_checkpoint_type: ModelCheckpointType,
-):
-    """
-    Specifies which type of model checkpoint to use when loading a serialized model.
-
-    By default, Curated Transformers will attempt to load from the most suitable
-    checkpoint type depending on its availability. This context manager can be used
-    to override the default behaviour.
-
-    .. code-block:: python
-
-        with use_model_checkpoint_type(ModelCheckpointType.SAFETENSORS):
-            encoder = BertEncoder.from_hf_hub(name="bert-base-uncased")
-    """
-    token = _MODEL_CHECKPOINT_TYPE.set(model_checkpoint_type)
-    try:
-        yield
-    finally:
-        _MODEL_CHECKPOINT_TYPE.reset(token)
-
-
 def load_model_from_checkpoints(
     model: Module,
     *,