Remove unused model checkpoint context manager #335

Merged
23 changes: 9 additions & 14 deletions curated_transformers/repository/repository.py
@@ -4,7 +4,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 from .._compat import has_safetensors
-from ..util.serde.checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
+from ..util.serde.checkpoint import ModelCheckpointType
 from ._hf import (
     HF_MODEL_CONFIG,
     HF_TOKENIZER_CONFIG,
@@ -146,21 +146,16 @@ def get_checkpoint_paths(
 
             return checkpoint_paths
 
-        checkpoint_type = _MODEL_CHECKPOINT_TYPE.get()
+        # Precedence: Safetensors > PyTorch
+        checkpoint_type = ModelCheckpointType.SAFE_TENSORS
         checkpoint_paths: Optional[List[RepositoryFile]] = None
 
-        if checkpoint_type is None:
-            # Precedence: Safetensors > PyTorch
-            if has_safetensors:
-                try:
-                    checkpoint_type = ModelCheckpointType.SAFE_TENSORS
-                    checkpoint_paths = get_checkpoint_paths(checkpoint_type)
-                except OSError:
-                    pass
-            if checkpoint_paths is None:
-                checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT
+        if has_safetensors:
+            try:
                 checkpoint_paths = get_checkpoint_paths(checkpoint_type)
-        else:
+            except OSError:
+                pass
+        if checkpoint_paths is None:
+            checkpoint_type = ModelCheckpointType.PYTORCH_STATE_DICT
             checkpoint_paths = get_checkpoint_paths(checkpoint_type)
 
         assert checkpoint_paths is not None
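With the context manager gone, checkpoint resolution is fully automatic: safetensors checkpoints are preferred whenever the safetensors package is installed, with PyTorch state dicts as the fallback. A minimal sketch of how this surfaces through ModelRepository.model_checkpoints(), reusing the repository name and imports from the tests below (which branch is taken depends on whether safetensors is installed):

    from pathlib import Path

    from curated_transformers.repository.hf_hub import HfHubRepository
    from curated_transformers.repository.repository import ModelRepository
    from curated_transformers.util.serde.checkpoint import ModelCheckpointType

    repo = ModelRepository(
        HfHubRepository("explosion-testing/safetensors-sharded-test", revision="main")
    )

    # No override hook anymore: the repository picks the checkpoint type itself.
    ckp_paths, ckp_type = repo.model_checkpoints()

    if ckp_type == ModelCheckpointType.SAFE_TENSORS:
        # `safetensors` is installed and the repo ships *.safetensors shards.
        assert all(Path(p.path).suffix == ".safetensors" for p in ckp_paths)
    else:
        # Fallback when `safetensors` is unavailable (or the repo has no such files).
        assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT
        assert all(Path(p.path).suffix == ".bin" for p in ckp_paths)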
29 changes: 0 additions & 29 deletions curated_transformers/tests/models/test_hf_hub.py
@@ -7,7 +7,6 @@
 from curated_transformers.repository.hf_hub import HfHubRepository
 from curated_transformers.repository.repository import ModelRepository
 from curated_transformers.util.serde.checkpoint import ModelCheckpointType
-from curated_transformers.util.serde.load import _use_model_checkpoint_type
 
 from ..compat import has_hf_transformers, has_safetensors
 from ..conftest import TORCH_DEVICES
@@ -61,10 +60,6 @@ def test_checkpoint_type_without_safetensors():
     assert Path(ckp_paths[0].path).suffix == ".bin"
     assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT
 
-    with pytest.raises(ValueError, match="`safetensors` library is required"):
-        with _use_model_checkpoint_type(ModelCheckpointType.SAFE_TENSORS):
-            BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
-
 
 @pytest.mark.skipif(not has_safetensors, reason="requires huggingface safetensors")
 def test_checkpoint_type_with_safetensors():
@@ -81,30 +76,6 @@ def test_checkpoint_type_with_safetensors():
     encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
 
 
-@pytest.mark.skipif(not has_safetensors, reason="requires huggingface safetensors")
-def test_forced_checkpoint_type():
-    with _use_model_checkpoint_type(ModelCheckpointType.PYTORCH_STATE_DICT):
-        repo = ModelRepository(
-            HfHubRepository(
-                "explosion-testing/safetensors-sharded-test", revision="main"
-            )
-        )
-        ckp_paths, ckp_type = repo.model_checkpoints()
-        assert len(ckp_paths) == 3
-        assert all(Path(p.path).suffix == ".bin" for p in ckp_paths)
-        assert ckp_type == ModelCheckpointType.PYTORCH_STATE_DICT
-
-        encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
-
-    with _use_model_checkpoint_type(ModelCheckpointType.SAFE_TENSORS):
-        ckp_paths, ckp_type = repo.model_checkpoints()
-        assert len(ckp_paths) == 3
-        assert all(Path(p.path).suffix == ".safetensors" for p in ckp_paths)
-        assert ckp_type == ModelCheckpointType.SAFE_TENSORS
-
-        encoder = BERTEncoder.from_hf_hub(name="explosion-testing/safetensors-test")
-
-
 @pytest.mark.slow
 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
 @pytest.mark.parametrize("torch_device", TORCH_DEVICES)
9 changes: 1 addition & 8 deletions curated_transformers/util/serde/checkpoint.py
@@ -1,6 +1,5 @@
-from contextvars import ContextVar
 from enum import Enum
-from typing import TYPE_CHECKING, Callable, Iterable, Mapping, Optional
+from typing import TYPE_CHECKING, Callable, Iterable, Mapping
 
 import torch
 
@@ -42,12 +41,6 @@ def pretty_name(self) -> str:
         return ""
 
 
-# When `None`, behaviour is implementation-specific.
-_MODEL_CHECKPOINT_TYPE: ContextVar[Optional[ModelCheckpointType]] = ContextVar(
-    "model_checkpoint_type", default=None
-)
-
-
 def _load_safetensor_state_dicts_from_checkpoints(
     checkpoints: Iterable[RepositoryFile],
 ) -> Iterable[Mapping[str, torch.Tensor]]:
25 changes: 1 addition & 24 deletions curated_transformers/util/serde/load.py
@@ -6,7 +6,7 @@
 
 from ...repository.file import RepositoryFile
 from ..pytorch import ModuleIterator, apply_to_module
-from .checkpoint import _MODEL_CHECKPOINT_TYPE, ModelCheckpointType
+from .checkpoint import ModelCheckpointType
 
 # Args: Parent module, module prefix, parameter name, tensor to convert, device.
 # Returns the new paramater.
@@ -21,29 +21,6 @@
 ]
 
 
-@contextmanager
-def _use_model_checkpoint_type(
-    model_checkpoint_type: ModelCheckpointType,
-):
-    """
-    Specifies which type of model checkpoint to use when loading a serialized model.
-
-    By default, Curated Transformers will attempt to load from the most suitable
-    checkpoint type depending on its availability. This context manager can be used
-    to override the default behaviour.
-
-    .. code-block:: python
-
-        with use_model_checkpoint_type(ModelCheckpointType.SAFETENSORS):
-            encoder = BertEncoder.from_hf_hub(name="bert-base-uncased")
-    """
-    token = _MODEL_CHECKPOINT_TYPE.set(model_checkpoint_type)
-    try:
-        yield
-    finally:
-        _MODEL_CHECKPOINT_TYPE.reset(token)
-
-
 def load_model_from_checkpoints(
     model: Module,
     *,