Add repository abstraction (#331)
* Add repository abstraction

Before this change, loading models was done with a collection of
standalone functions for the Hugging Face Hub and fsspec. These
functions had a lot of overlap, and adding yet another storage backend
would have required duplicating the same functions again and littering
them throughout the code base.

This change does away with all the standalone functions and introduces
the `Repository` API. The base class requires implementations to define
a few basic operations; more complex operations are implemented in terms
of these basic operations and are generic across repository types.
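
(For illustration, a minimal sketch of that pattern. Only `Repository`,
`ModelRepository`, and `model_config` appear in this diff; the `open`
primitive and the exact signatures below are assumptions, not the actual
interface.)

import json
from abc import ABC, abstractmethod
from typing import IO

class Repository(ABC):
    # Basic operation every storage backend must provide (assumed name).
    @abstractmethod
    def open(self, path: str) -> IO[bytes]:
        ...

class ModelRepository:
    # Model operations are built from the basic operations only, so one
    # implementation works with any Repository backend.
    def __init__(self, repo: Repository):
        self._repo = repo

    def model_config(self) -> dict:
        with self._repo.open("config.json") as f:
            return json.load(f)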

Initially there are two repository types, `HfHubRepository` and
`FsspecRepository`. There are also two wrappers for `Repository`
instances that implement model operations (`ModelRepository`) and
tokenizer operations (`TokenizerRepository`).
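
(A usage sketch based on the call sites in this diff. The model name is
a placeholder, and the absolute import paths are inferred from the
relative imports shown below.)

import torch
from fsspec.implementations.local import LocalFileSystem

from curated_transformers.models.auto_model import AutoCausalLM
from curated_transformers.repository.fsspec import FsspecRepository
from curated_transformers.repository.hf_hub import HfHubRepository

# The same loading code path serves both backends; only the
# Repository instance differs.
causal_lm = AutoCausalLM.from_repo(
    repo=HfHubRepository(name="example-org/example-model", revision="main"),
    device=torch.device("cpu"),
)

causal_lm = AutoCausalLM.from_repo(
    repo=FsspecRepository(LocalFileSystem(), path="/local/path/to/model"),
)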

* Fixes

Co-authored-by: Madeesh Kannan <[email protected]>

* More specific catch

Co-authored-by: Madeesh Kannan <[email protected]>

* Address PR comments

* Repository example, some fixes

* Add missing ellipsis

---------

Co-authored-by: Madeesh Kannan <[email protected]>
danieldk and shadeMe authored Sep 26, 2023
1 parent 84246c9 commit 1f023dc
Showing 28 changed files with 1,013 additions and 1,152 deletions.
187 changes: 61 additions & 126 deletions curated_transformers/models/auto_model.py
@@ -1,13 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TypeVar
from typing import Dict, Generic, Optional, Type, TypeVar

import torch
from fsspec import AbstractFileSystem

from ..layers.cache import KeyValueCache
from ..quantization.bnb.config import BitsAndBytesConfig
from ..util.fsspec import get_config_model_type as get_config_model_type_fsspec
from ..util.hf import get_config_model_type
from ..repository.fsspec import FsspecArgs, FsspecRepository
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from .albert import ALBERTEncoder
from .bert import BERTEncoder
from .camembert import CamemBERTEncoder
@@ -33,36 +34,12 @@ class AutoModel(ABC, Generic[ModelT]):

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {}

@classmethod
def _resolve_model_cls_fsspec(
cls,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
) -> Type[FromHFHub]:
model_type = get_config_model_type_fsspec(
fs, model_path, fsspec_args=fsspec_args
)
if model_type is None:
raise ValueError(
"The model type is not defined in the model configuration."
)
module_cls = cls._hf_model_type_to_curated.get(model_type)
if module_cls is None:
raise ValueError(
f"Unsupported model type `{model_type}` for {cls.__name__}. "
f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}"
)
assert issubclass(module_cls, FromHFHub)
return module_cls

@classmethod
def _resolve_model_cls(
cls,
name: str,
revision: str,
repo: ModelRepository,
) -> Type[FromHFHub]:
model_type = get_config_model_type(name, revision)
model_type = repo.model_type()
module_cls = cls._hf_model_type_to_curated.get(model_type)
if module_cls is None:
raise ValueError(
@@ -73,36 +50,15 @@ def _resolve_model_cls(
return module_cls

@classmethod
def _instantiate_model_from_fsspec(
cls,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]],
device: Optional[torch.device],
quantization_config: Optional[BitsAndBytesConfig],
) -> FromHFHub:
module_cls = cls._resolve_model_cls_fsspec(fs, model_path)
module = module_cls.from_fsspec(
fs=fs,
model_path=model_path,
fsspec_args=fsspec_args,
device=device,
quantization_config=quantization_config,
)
return module

@classmethod
def _instantiate_model_from_hf_hub(
def _instantiate_model(
cls,
name: str,
revision: str,
repo: Repository,
device: Optional[torch.device],
quantization_config: Optional[BitsAndBytesConfig],
) -> FromHFHub:
module_cls = cls._resolve_model_cls(name, revision)
module = module_cls.from_hf_hub(
name=name,
revision=revision,
module_cls = cls._resolve_model_cls(ModelRepository(repo))
module = module_cls.from_repo(
repo=repo,
device=device,
quantization_config=quantization_config,
)
@@ -114,7 +70,7 @@ def from_fsspec(
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
fsspec_args: Optional[FsspecArgs] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> ModelT:
@@ -135,10 +91,17 @@ def from_fsspec(
:returns:
Module with the parameters loaded.
"""
raise NotImplementedError
return cls.from_repo(
repo=FsspecRepository(
fs,
path=model_path,
fsspec_args=fsspec_args,
),
device=device,
quantization_config=quantization_config,
)

@classmethod
@abstractmethod
def from_hf_hub(
cls,
*,
@@ -161,6 +124,34 @@ def from_hf_hub(
:returns:
Loaded model or generator.
"""
return cls.from_repo(
repo=HfHubRepository(name=name, revision=revision),
device=device,
quantization_config=quantization_config,
)

@classmethod
@abstractmethod
def from_repo(
cls,
*,
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> ModelT:
"""
Construct and load a model or a generator from a repository.
:param repo:
The repository to load from.
:param device:
Device on which to initialize the model.
:param quantization_config:
Configuration for loading quantized weights.
:returns:
Loaded model or generator.
"""

raise NotImplementedError

@classmethod
@@ -181,8 +172,9 @@ def from_hf_hub_to_cache(
:param revision:
Model revision.
"""
module_cls = cls._resolve_model_cls(name, revision)
module_cls.from_hf_hub_to_cache(name=name, revision=revision)
repo = ModelRepository(HfHubRepository(name=name, revision=revision))
repo.model_config()
repo.model_checkpoints()


class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
@@ -199,33 +191,14 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
}

@classmethod
def from_fsspec(
def from_repo(
cls,
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule[TransformerConfig]:
encoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
assert isinstance(encoder, EncoderModule)
return encoder

@classmethod
def from_hf_hub(
cls,
*,
name: str,
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> EncoderModule[TransformerConfig]:
encoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
encoder = cls._instantiate_model(repo, device, quantization_config)
assert isinstance(encoder, EncoderModule)
return encoder

@@ -245,33 +218,14 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]):
}

@classmethod
def from_fsspec(
cls,
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule[TransformerConfig, KeyValueCache]:
decoder = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
assert isinstance(decoder, DecoderModule)
return decoder

@classmethod
def from_hf_hub(
def from_repo(
cls,
*,
name: str,
revision: str = "main",
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> DecoderModule[TransformerConfig, KeyValueCache]:
decoder = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
decoder = cls._instantiate_model(repo, device, quantization_config)
assert isinstance(decoder, DecoderModule)
return decoder

@@ -291,32 +245,13 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]):
}

@classmethod
def from_fsspec(
def from_repo(
cls,
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[TransformerConfig, KeyValueCache]:
causal_lm = cls._instantiate_model_from_fsspec(
fs, model_path, fsspec_args, device, quantization_config
)
assert isinstance(causal_lm, CausalLMModule)
return causal_lm

@classmethod
def from_hf_hub(
cls,
*,
name: str,
revision: str = "main",
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> CausalLMModule[TransformerConfig, KeyValueCache]:
causal_lm = cls._instantiate_model_from_hf_hub(
name, revision, device, quantization_config
)
causal_lm = cls._instantiate_model(repo, device, quantization_config)
assert isinstance(causal_lm, CausalLMModule)
return causal_lm
53 changes: 29 additions & 24 deletions curated_transformers/models/hf_hub.py
@@ -18,12 +18,10 @@

from ..quantization import prepare_module_for_quantization
from ..quantization.bnb.config import BitsAndBytesConfig
from ..util.fsspec import (
get_model_checkpoint_files as get_model_checkpoint_files_fsspec,
)
from ..util.fsspec import get_model_config as get_model_config_fsspec
from ..util.hf import get_model_checkpoint_files, get_model_config
from ..util.serde import ModelCheckpointType, ModelFile, load_model_from_checkpoints
from ..repository.fsspec import FsspecArgs, FsspecRepository
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from ..util.serde import load_model_from_checkpoints
from .module import TransformerModule

# Only provided as typing.Self in Python 3.11+.
@@ -94,16 +92,17 @@ def from_hf_hub_to_cache(
:param revision:
Model revision.
"""
_ = get_model_config(name, revision)
_ = get_model_checkpoint_files(name, revision)
repo = ModelRepository(HfHubRepository(name=name, revision=revision))
repo.model_config()
repo.model_checkpoints()

@classmethod
def from_fsspec(
cls: Type[Self],
*,
fs: AbstractFileSystem,
model_path: str,
fsspec_args: Optional[Dict[str, Any]] = None,
fsspec_args: Optional[FsspecArgs] = None,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> Self:
@@ -124,13 +123,8 @@ def from_fsspec(
:returns:
Module with the parameters loaded.
"""
return cls._create_and_load_model(
get_config=lambda: get_model_config_fsspec(
fs, model_path, fsspec_args=fsspec_args
),
get_checkpoint_files=lambda: get_model_checkpoint_files_fsspec(
fs, model_path, fsspec_args=fsspec_args
),
return cls.from_repo(
repo=FsspecRepository(fs, model_path, fsspec_args),
device=device,
quantization_config=quantization_config,
)
@@ -158,9 +152,8 @@ def from_hf_hub(
:returns:
Module with the parameters loaded.
"""
return cls._create_and_load_model(
get_config=lambda: get_model_config(name, revision),
get_checkpoint_files=lambda: get_model_checkpoint_files(name, revision),
return cls.from_repo(
repo=HfHubRepository(name=name, revision=revision),
device=device,
quantization_config=quantization_config,
)
@@ -182,15 +175,27 @@ def to(
...

@classmethod
def _create_and_load_model(
def from_repo(
cls: Type[Self],
*,
get_config: Callable[[], Dict[Any, str]],
get_checkpoint_files: Callable[[], Tuple[List[ModelFile], ModelCheckpointType]],
repo: Repository,
device: Optional[torch.device] = None,
quantization_config: Optional[BitsAndBytesConfig] = None,
) -> Self:
config = get_config()
"""
Construct and load a model from a repository.
:param repo:
The repository to load from.
:param device:
Device on which to initialize the model.
:param quantization_config:
Configuration for loading quantized weights.
:returns:
Loaded model.
"""
model_repo = ModelRepository(repo)
config = model_repo.model_config()
model = cls.from_hf_config(hf_config=config, device=torch.device("meta"))

# Convert the model to the expected dtype.
@@ -211,7 +216,7 @@ def _create_and_load_model(
tensor2param = None

# Download model and convert HF parameter names to ours.
checkpoint_filenames, checkpoint_type = get_checkpoint_files()
checkpoint_filenames, checkpoint_type = model_repo.model_checkpoints()
load_model_from_checkpoints(
model, # type:ignore
filepaths=checkpoint_filenames,