From fcbdfef1b6dbdf4c023e211f3e962960fbe628ff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Thu, 5 Oct 2023 09:11:36 +0200
Subject: [PATCH] Register models using `catalogue`

So far we have hardcoded the available encoders/decoders/causal LMs in
the auto classes. This has the downside that the auto classes only work
with models that are provided by Curated Transformers.

This change adds registries for encoders/decoders/causal LMs. The auto
classes query the relevant registry and check which registered model
supports the downloaded model (through the `hf_model_types` method of
the `FromHFHub` mixin). This makes it possible to register external
models with Curated Transformers, so that they can also be used with
the auto classes.

Adding registries for tokenizers and generators is deferred to future
PRs.
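
For example (hypothetical package and model names), once a third-party
package has registered its encoder under the
`curated_transformers_encoders` entry point group, the existing
construction methods work unchanged:

    from curated_transformers.models import AutoEncoder

    # Resolves because the registered encoder's `hf_model_types`
    # reports the repo's `model_type`.
    encoder = AutoEncoder.from_hf_hub(name="extra-transformers/foo-base")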
+ """ + + causal_lms: Registry = catalogue.create( + "curated_transformers", "causal_lms", entry_points=True + ) + decoders: Registry = catalogue.create( + "curated_transformers", "decoders", entry_points=True + ) + encoders: Registry = catalogue.create( + "curated_transformers", "encoders", entry_points=True + ) diff --git a/curated_transformers/models/albert/encoder.py b/curated_transformers/models/albert/encoder.py index 1a247c23..2c4ab449 100644 --- a/curated_transformers/models/albert/encoder.py +++ b/curated_transformers/models/albert/encoder.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -99,6 +99,10 @@ def forward( return ModelOutput(all_outputs=[embeddings, *layer_outputs]) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("albert",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/auto_model.py b/curated_transformers/models/auto_model.py index 94de95a0..e1523a30 100644 --- a/curated_transformers/models/auto_model.py +++ b/curated_transformers/models/auto_model.py @@ -1,27 +1,21 @@ +import warnings from abc import ABC, abstractmethod -from typing import Dict, Generic, Optional, Type, TypeVar +from typing import Generic, Optional, Type, TypeVar import torch +from catalogue import Registry from fsspec import AbstractFileSystem +from curated_transformers import registry + from ..layers.cache import KeyValueCache from ..quantization.bnb.config import BitsAndBytesConfig from ..repository.fsspec import FsspecArgs, FsspecRepository from ..repository.hf_hub import HfHubRepository from ..repository.repository import ModelRepository, Repository -from .albert import ALBERTEncoder -from .bert import BERTEncoder -from .camembert import CamemBERTEncoder from .config import TransformerConfig -from .falcon import FalconCausalLM, FalconDecoder -from .gpt_neox import GPTNeoXCausalLM, GPTNeoXDecoder from .hf_hub import FromHFHub -from .llama import LlamaCausalLM, LlamaDecoder from .module import CausalLMModule, DecoderModule, EncoderModule -from .mpt.causal_lm import MPTCausalLM -from .mpt.decoder import MPTDecoder -from .roberta import RoBERTaEncoder -from .xlm_roberta import XLMREncoder ModelT = TypeVar("ModelT") @@ -32,7 +26,7 @@ class AutoModel(ABC, Generic[ModelT]): Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {} + _registry: Registry @classmethod def _resolve_model_cls( @@ -40,14 +34,25 @@ def _resolve_model_cls( repo: ModelRepository, ) -> Type[FromHFHub]: model_type = repo.model_type() - module_cls = cls._hf_model_type_to_curated.get(model_type) - if module_cls is None: - raise ValueError( - f"Unsupported model type `{model_type}` for {cls.__name__}. 
" - f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}" - ) - assert issubclass(module_cls, FromHFHub) - return module_cls + + supported_model_types = set() + for entrypoint, module_cls in cls._registry.get_entry_points().items(): + if not issubclass(module_cls, FromHFHub): + warnings.warn( + f"Entry point {entrypoint} cannot load from Huggingface Hub " + "since the FromHFHub mixin is not implemented" + ) + continue + + module_model_types = module_cls.hf_model_types() + if model_type in module_model_types: + return module_cls + supported_model_types.update(module_model_types) + + raise ValueError( + f"Unsupported model type `{model_type}` for {cls.__name__}. " + f"Supported model types: {', '.join(sorted(supported_model_types))}" + ) @classmethod def _instantiate_model( @@ -182,13 +187,7 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]): Encoder model loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "bert": BERTEncoder, - "albert": ALBERTEncoder, - "camembert": CamemBERTEncoder, - "roberta": RoBERTaEncoder, - "xlm-roberta": XLMREncoder, - } + _registry: Registry = registry.encoders @classmethod def from_repo( @@ -208,14 +207,7 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]): Decoder module loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "falcon": FalconDecoder, - "gpt_neox": GPTNeoXDecoder, - "llama": LlamaDecoder, - "mpt": MPTDecoder, - "RefinedWeb": FalconDecoder, - "RefinedWebModel": FalconDecoder, - } + _registry = registry.decoders @classmethod def from_repo( @@ -235,14 +227,7 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]): Causal LM model loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "falcon": FalconCausalLM, - "gpt_neox": GPTNeoXCausalLM, - "llama": LlamaCausalLM, - "mpt": MPTCausalLM, - "RefinedWeb": FalconCausalLM, - "RefinedWebModel": FalconCausalLM, - } + _registry: Registry = registry.causal_lms @classmethod def from_repo( diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py index 363b2875..acbdb389 100644 --- a/curated_transformers/models/bert/encoder.py +++ b/curated_transformers/models/bert/encoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -105,6 +105,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None) ] ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("bert",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/camembert/encoder.py b/curated_transformers/models/camembert/encoder.py index 0c18c345..6a16919f 100644 --- a/curated_transformers/models/camembert/encoder.py +++ b/curated_transformers/models/camembert/encoder.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No The encoder. 
""" super().__init__(config, device=device) + + @classmethod + def hf_model_types(cls) -> Tuple[str, ...]: + return ("camembert",) diff --git a/curated_transformers/models/falcon/causal_lm.py b/curated_transformers/models/falcon/causal_lm.py index 91860d06..4c5661b2 100644 --- a/curated_transformers/models/falcon/causal_lm.py +++ b/curated_transformers/models/falcon/causal_lm.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Set, Type, TypeVar +from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -52,6 +52,10 @@ def state_dict_from_hf( ) -> Mapping[str, Tensor]: return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("falcon", "RefinedWeb", "RefinedWebModel") + @classmethod def state_dict_to_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py index 478a70d0..6cef0fe1 100644 --- a/curated_transformers/models/falcon/decoder.py +++ b/curated_transformers/models/falcon/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -86,6 +86,10 @@ def __init__( device=device, ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("falcon", "RefinedWeb", "RefinedWebModel") + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/gpt_neox/causal_lm.py b/curated_transformers/models/gpt_neox/causal_lm.py index eae67a2d..302f304a 100644 --- a/curated_transformers/models/gpt_neox/causal_lm.py +++ b/curated_transformers/models/gpt_neox/causal_lm.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Set, Type, TypeVar +from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -46,6 +46,10 @@ def __init__( device=device, ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("gpt_neox",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py index c7eb634a..7c41ea5c 100644 --- a/curated_transformers/models/gpt_neox/decoder.py +++ b/curated_transformers/models/gpt_neox/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -114,6 +114,10 @@ def __init__( hidden_width, config.layer.layer_norm_eps, device=device ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("gpt_neox",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/hf_hub/mixin.py b/curated_transformers/models/hf_hub/mixin.py index b0807c8c..2ee59ac2 100644 --- a/curated_transformers/models/hf_hub/mixin.py +++ b/curated_transformers/models/hf_hub/mixin.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from fsspec import AbstractFileSystem @@ -168,6 +168,17 @@ def from_hf_hub( 
             quantization_config=quantization_config,
         )
 
+    @classmethod
+    @abstractmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        """
+        Get the Hugging Face model types supported by this model.
+
+        :returns:
+            The supported model types.
+        """
+        ...
+
     @abstractmethod
     def to(
         self,
diff --git a/curated_transformers/models/llama/causal_lm.py b/curated_transformers/models/llama/causal_lm.py
index aa9f5320..0487aed1 100644
--- a/curated_transformers/models/llama/causal_lm.py
+++ b/curated_transformers/models/llama/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Set, Type, TypeVar
+from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -47,6 +47,10 @@ def __init__(
             device=device,
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("llama",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py
index 1e4b7e69..e3e51aa6 100644
--- a/curated_transformers/models/llama/decoder.py
+++ b/curated_transformers/models/llama/decoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -120,6 +120,10 @@ def __init__(
             hidden_width, eps=config.layer.layer_norm_eps, device=device
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("llama",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/mpt/causal_lm.py b/curated_transformers/models/mpt/causal_lm.py
index dfc9c72a..de2c3916 100644
--- a/curated_transformers/models/mpt/causal_lm.py
+++ b/curated_transformers/models/mpt/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping, Optional, Set, Type, TypeVar
+from typing import Any, List, Mapping, Optional, Set, Tuple, Type, TypeVar
 
 import torch
 import torch.nn.functional as F
@@ -84,6 +84,10 @@ def forward(
             logits=logits,
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("mpt",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py
index f6acc05d..f72b9b31 100644
--- a/curated_transformers/models/mpt/decoder.py
+++ b/curated_transformers/models/mpt/decoder.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -120,6 +120,10 @@ def layer_norm():
 
         self.output_layer_norm = layer_norm()
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("mpt",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py
index b9cfc3f3..c642a6fa 100644
--- a/curated_transformers/models/roberta/encoder.py
+++ b/curated_transformers/models/roberta/encoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -105,6 +105,10 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
             ]
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("roberta",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/xlm_roberta/encoder.py b/curated_transformers/models/xlm_roberta/encoder.py
index 300a7a22..ee398958 100644
--- a/curated_transformers/models/xlm_roberta/encoder.py
+++ b/curated_transformers/models/xlm_roberta/encoder.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
@@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
         The encoder.
         """
         super().__init__(config, device=device)
+
+    @classmethod
+    def hf_model_types(cls) -> Tuple[str, ...]:
+        return ("xlm-roberta",)
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 03cf662b..4bece367 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -9,6 +9,7 @@ API
    decoders
    causal-lm
    generation
+   registries
    repositories
    tokenizers
    quantization
diff --git a/docs/source/registries.rst b/docs/source/registries.rst
new file mode 100644
index 00000000..49c6cd89
--- /dev/null
+++ b/docs/source/registries.rst
@@ -0,0 +1,38 @@
+Registries
+==========
+
+All models in Curated Transformers are added to a registry. Each auto class
+uses a registry to query which models are available. This mechanism allows
+third-party models to hook into the auto classes, so that construction
+methods such as ``AutoModel.from_hf_hub`` also work with them.
+
+Third-party packages can register models in the ``options.entry_points``
+section of ``setup.cfg``. For example, if the ``models`` module of the
+``extra-transformers`` package contains the ``FooCausalLM``, ``BarDecoder``,
+and ``BazEncoder`` classes, they can be registered in ``setup.cfg`` as follows:
+
+.. code-block:: ini
+
+    [options.entry_points]
+    curated_transformers_causal_lms =
+        extra-transformers.FooCausalLM = extra_transformers.models:FooCausalLM
+
+    curated_transformers_decoders =
+        extra-transformers.BarDecoder = extra_transformers.models:BarDecoder
+
+    curated_transformers_encoders =
+        extra-transformers.BazEncoder = extra_transformers.models:BazEncoder
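+
+Registered classes must implement the ``FromHFHub`` mixin. The auto classes
+call the ``hf_model_types`` classmethod of this mixin to check which Hugging
+Face ``model_type`` values a registered model supports. A minimal sketch for
+the hypothetical ``FooCausalLM`` above:
+
+.. code-block:: python
+
+    from typing import Tuple
+
+    class FooCausalLM:  # In a real package, also derive from CausalLMModule and FromHFHub.
+        @classmethod
+        def hf_model_types(cls) -> Tuple[str, ...]:
+            return ("foo",)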
diff --git a/requirements.txt b/requirements.txt
index 8d0472b6..dd5b846a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+catalogue>=2.0.4,<2.1.0
 curated-tokenizers>=0.9.1,<1.0.0
 huggingface-hub>=0.14
 tokenizers>=0.13.3
diff --git a/setup.cfg b/setup.cfg
index 70b92871..29a6d811 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,7 @@ zip_safe = true
 include_package_data = true
 python_requires = >=3.8
 install_requires =
+    catalogue>=2.0.4,<2.1.0
     curated-tokenizers>=0.9.1,<1.0.0
     huggingface-hub>=0.14
     tokenizers>=0.13.3
@@ -24,9 +25,29 @@
 quantization =
     bitsandbytes>=0.40
     # bitsandbytes has a dependency on scipy but doesn't
     # list it as one for pip installs. So, we'll pull that
-    # in too (until its rectified upstream).
+    # in too (until it's rectified upstream).
     scipy>=1.11
 
+[options.entry_points]
+curated_transformers_causal_lms =
+    curated-transformers.LlamaCausalLM = curated_transformers.models:LlamaCausalLM
+    curated-transformers.FalconCausalLM = curated_transformers.models:FalconCausalLM
+    curated-transformers.GPTNeoXCausalLM = curated_transformers.models:GPTNeoXCausalLM
+    curated-transformers.MPTCausalLM = curated_transformers.models:MPTCausalLM
+
+curated_transformers_decoders =
+    curated-transformers.LlamaDecoder = curated_transformers.models:LlamaDecoder
+    curated-transformers.FalconDecoder = curated_transformers.models:FalconDecoder
+    curated-transformers.GPTNeoXDecoder = curated_transformers.models:GPTNeoXDecoder
+    curated-transformers.MPTDecoder = curated_transformers.models:MPTDecoder
+
+curated_transformers_encoders =
+    curated-transformers.ALBERTEncoder = curated_transformers.models:ALBERTEncoder
+    curated-transformers.BERTEncoder = curated_transformers.models:BERTEncoder
+    curated-transformers.CamemBERTEncoder = curated_transformers.models:CamemBERTEncoder
+    curated-transformers.RoBERTaEncoder = curated_transformers.models:RoBERTaEncoder
+    curated-transformers.XLMREncoder = curated_transformers.models:XLMREncoder
+
 [bdist_wheel]
 universal = true