diff --git a/curated_transformers/__init__.py b/curated_transformers/__init__.py
index e69de29b..0fdf4836 100644
--- a/curated_transformers/__init__.py
+++ b/curated_transformers/__init__.py
@@ -0,0 +1,19 @@
+import catalogue
+from catalogue import Registry
+
+
+class registry(object):
+    """
+    Registry for models. These registries are used by auto classes to
+    discover the available models.
+    """
+
+    causal_lms: Registry = catalogue.create(
+        "curated_transformers", "causal_lms", entry_points=True
+    )
+    decoders: Registry = catalogue.create(
+        "curated_transformers", "decoders", entry_points=True
+    )
+    encoders: Registry = catalogue.create(
+        "curated_transformers", "encoders", entry_points=True
+    )
diff --git a/curated_transformers/models/albert/encoder.py b/curated_transformers/models/albert/encoder.py
index 1a247c23..2c4ab449 100644
--- a/curated_transformers/models/albert/encoder.py
+++ b/curated_transformers/models/albert/encoder.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -99,6 +99,10 @@ def forward(
 
         return ModelOutput(all_outputs=[embeddings, *layer_outputs])
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("albert",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/auto_model.py b/curated_transformers/models/auto_model.py
index 94de95a0..2c2732cd 100644
--- a/curated_transformers/models/auto_model.py
+++ b/curated_transformers/models/auto_model.py
@@ -1,27 +1,21 @@
+import warnings
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Optional, Type, TypeVar
+from typing import Generic, Optional, Type, TypeVar
 
 import torch
+from catalogue import Registry
 from fsspec import AbstractFileSystem
 
+from curated_transformers import registry
+
 from ..layers.cache import KeyValueCache
 from ..quantization.bnb.config import BitsAndBytesConfig
 from ..repository.fsspec import FsspecArgs, FsspecRepository
 from ..repository.hf_hub import HfHubRepository
 from ..repository.repository import ModelRepository, Repository
-from .albert import ALBERTEncoder
-from .bert import BERTEncoder
-from .camembert import CamemBERTEncoder
 from .config import TransformerConfig
-from .falcon import FalconCausalLM, FalconDecoder
-from .gpt_neox import GPTNeoXCausalLM, GPTNeoXDecoder
 from .hf_hub import FromHFHub
-from .llama import LlamaCausalLM, LlamaDecoder
 from .module import CausalLMModule, DecoderModule, EncoderModule
-from .mpt.causal_lm import MPTCausalLM
-from .mpt.decoder import MPTDecoder
-from .roberta import RoBERTaEncoder
-from .xlm_roberta import XLMREncoder
 
 ModelT = TypeVar("ModelT")
 
@@ -32,7 +26,7 @@ class AutoModel(ABC, Generic[ModelT]):
     Face Model Hub.
     """
 
-    _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {}
+    _registry: Registry
 
     @classmethod
     def _resolve_model_cls(
@@ -40,14 +34,25 @@ def _resolve_model_cls(
         repo: ModelRepository,
     ) -> Type[FromHFHub]:
         model_type = repo.model_type()
-        module_cls = cls._hf_model_type_to_curated.get(model_type)
-        if module_cls is None:
-            raise ValueError(
-                f"Unsupported model type `{model_type}` for {cls.__name__}. "
" - f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}" - ) - assert issubclass(module_cls, FromHFHub) - return module_cls + + supported_model_types = set() + for entrypoint, module_cls in cls._registry.get_entry_points().items(): + if not issubclass(module_cls, FromHFHub): + warnings.warn( + f"Entry point `{entrypoint}` cannot load from Hugging Face Hub " + "since the FromHFHub mixin is not implemented" + ) + continue + + module_model_types = module_cls.hf_model_types() + if model_type in module_model_types: + return module_cls + supported_model_types.update(module_model_types) + + raise ValueError( + f"Unsupported model type `{model_type}` for {cls.__name__}. " + f"Supported model types: {', '.join(sorted(supported_model_types))}" + ) @classmethod def _instantiate_model( @@ -182,13 +187,7 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]): Encoder model loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "bert": BERTEncoder, - "albert": ALBERTEncoder, - "camembert": CamemBERTEncoder, - "roberta": RoBERTaEncoder, - "xlm-roberta": XLMREncoder, - } + _registry: Registry = registry.encoders @classmethod def from_repo( @@ -208,14 +207,7 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]): Decoder module loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "falcon": FalconDecoder, - "gpt_neox": GPTNeoXDecoder, - "llama": LlamaDecoder, - "mpt": MPTDecoder, - "RefinedWeb": FalconDecoder, - "RefinedWebModel": FalconDecoder, - } + _registry = registry.decoders @classmethod def from_repo( @@ -235,14 +227,7 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]): Causal LM model loaded from the Hugging Face Model Hub. """ - _hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = { - "falcon": FalconCausalLM, - "gpt_neox": GPTNeoXCausalLM, - "llama": LlamaCausalLM, - "mpt": MPTCausalLM, - "RefinedWeb": FalconCausalLM, - "RefinedWebModel": FalconCausalLM, - } + _registry: Registry = registry.causal_lms @classmethod def from_repo( diff --git a/curated_transformers/models/bert/encoder.py b/curated_transformers/models/bert/encoder.py index c9b4c6be..73d81c82 100644 --- a/curated_transformers/models/bert/encoder.py +++ b/curated_transformers/models/bert/encoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -113,6 +113,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None) ] ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("bert",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/camembert/encoder.py b/curated_transformers/models/camembert/encoder.py index 0c18c345..6a16919f 100644 --- a/curated_transformers/models/camembert/encoder.py +++ b/curated_transformers/models/camembert/encoder.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No The encoder. 
""" super().__init__(config, device=device) + + @classmethod + def hf_model_types(cls) -> Tuple[str, ...]: + return ("camembert",) diff --git a/curated_transformers/models/falcon/causal_lm.py b/curated_transformers/models/falcon/causal_lm.py index 91860d06..4c5661b2 100644 --- a/curated_transformers/models/falcon/causal_lm.py +++ b/curated_transformers/models/falcon/causal_lm.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Set, Type, TypeVar +from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -52,6 +52,10 @@ def state_dict_from_hf( ) -> Mapping[str, Tensor]: return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("falcon", "RefinedWeb", "RefinedWebModel") + @classmethod def state_dict_to_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/falcon/decoder.py b/curated_transformers/models/falcon/decoder.py index 5e08d928..659496cd 100644 --- a/curated_transformers/models/falcon/decoder.py +++ b/curated_transformers/models/falcon/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -87,6 +87,10 @@ def __init__( device=device, ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("falcon", "RefinedWeb", "RefinedWebModel") + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/gpt_neox/causal_lm.py b/curated_transformers/models/gpt_neox/causal_lm.py index eae67a2d..302f304a 100644 --- a/curated_transformers/models/gpt_neox/causal_lm.py +++ b/curated_transformers/models/gpt_neox/causal_lm.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional, Set, Type, TypeVar +from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar import torch from torch import Tensor @@ -46,6 +46,10 @@ def __init__( device=device, ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("gpt_neox",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/gpt_neox/decoder.py b/curated_transformers/models/gpt_neox/decoder.py index 37f3e237..3f6cd19f 100644 --- a/curated_transformers/models/gpt_neox/decoder.py +++ b/curated_transformers/models/gpt_neox/decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from torch import Tensor @@ -122,6 +122,10 @@ def __init__( hidden_width, config.layer.layer_norm_eps, device=device ) + @classmethod + def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]: + return ("gpt_neox",) + @classmethod def state_dict_from_hf( cls: Type[Self], params: Mapping[str, Tensor] diff --git a/curated_transformers/models/hf_hub/mixin.py b/curated_transformers/models/hf_hub/mixin.py index 7f5e32f2..393d4254 100644 --- a/curated_transformers/models/hf_hub/mixin.py +++ b/curated_transformers/models/hf_hub/mixin.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Mapping, Optional, Type, TypeVar +from typing import Any, Mapping, Optional, Tuple, Type, TypeVar import torch from fsspec import AbstractFileSystem @@ -168,6 +168,17 @@ def from_hf_hub( 
             quantization_config=quantization_config,
         )
 
+    @classmethod
+    @abstractmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        """
+        Get the Hugging Face model types supported by this model.
+
+        :returns:
+            The supported model types.
+        """
+        ...
+
     @abstractmethod
     def to(
         self,
diff --git a/curated_transformers/models/llama/causal_lm.py b/curated_transformers/models/llama/causal_lm.py
index aa9f5320..0487aed1 100644
--- a/curated_transformers/models/llama/causal_lm.py
+++ b/curated_transformers/models/llama/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Set, Type, TypeVar
+from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -47,6 +47,10 @@ def __init__(
             device=device,
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("llama",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/llama/decoder.py b/curated_transformers/models/llama/decoder.py
index 0fb58644..d660d29b 100644
--- a/curated_transformers/models/llama/decoder.py
+++ b/curated_transformers/models/llama/decoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -128,6 +128,10 @@ def __init__(
             hidden_width, eps=config.layer.layer_norm_eps, device=device
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("llama",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/mpt/causal_lm.py b/curated_transformers/models/mpt/causal_lm.py
index dfc9c72a..de2c3916 100644
--- a/curated_transformers/models/mpt/causal_lm.py
+++ b/curated_transformers/models/mpt/causal_lm.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping, Optional, Set, Type, TypeVar
+from typing import Any, List, Mapping, Optional, Set, Tuple, Type, TypeVar
 
 import torch
 import torch.nn.functional as F
@@ -84,6 +84,10 @@ def forward(
             logits=logits,
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("mpt",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/mpt/decoder.py b/curated_transformers/models/mpt/decoder.py
index 5ee41909..e05de9c6 100644
--- a/curated_transformers/models/mpt/decoder.py
+++ b/curated_transformers/models/mpt/decoder.py
@@ -1,4 +1,4 @@
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -123,6 +123,10 @@ def layer_norm():
 
         self.output_layer_norm = layer_norm()
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("mpt",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/roberta/encoder.py b/curated_transformers/models/roberta/encoder.py
index 88d4317c..d98e980e 100644
--- a/curated_transformers/models/roberta/encoder.py
+++ b/curated_transformers/models/roberta/encoder.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Mapping, Optional, Type, TypeVar
+from typing import Any, Mapping, Optional, Tuple, Type, TypeVar
 
 import torch
 from torch import Tensor
@@ -113,6 +113,10 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
             ]
         )
 
+    @classmethod
+    def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
+        return ("roberta",)
+
     @classmethod
     def state_dict_from_hf(
         cls: Type[Self], params: Mapping[str, Tensor]
diff --git a/curated_transformers/models/xlm_roberta/encoder.py b/curated_transformers/models/xlm_roberta/encoder.py
index 300a7a22..ee398958 100644
--- a/curated_transformers/models/xlm_roberta/encoder.py
+++ b/curated_transformers/models/xlm_roberta/encoder.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
@@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
             The encoder.
         """
         super().__init__(config, device=device)
+
+    @classmethod
+    def hf_model_types(cls) -> Tuple[str, ...]:
+        return ("xlm-roberta",)
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 03cf662b..4bece367 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -9,6 +9,7 @@ API
    decoders
    causal-lm
    generation
+   registries
    repositories
    tokenizers
    quantization
diff --git a/docs/source/registries.rst b/docs/source/registries.rst
new file mode 100644
index 00000000..3fa3545d
--- /dev/null
+++ b/docs/source/registries.rst
@@ -0,0 +1,26 @@
+Registries
+==========
+
+All models in Curated Transformers are added to a registry, and each auto
+class queries a registry to discover the available models. This mechanism
+lets third-party models hook into the auto classes, so construction methods
+such as :py:meth:`~curated_transformers.models.AutoEncoder.from_hf_hub` also
+work with third-party models.
+
+Third-party packages can register models in the ``options.entry_points``
+section of ``setup.cfg``. For example, if the ``models`` module of the
+``extra-transformers`` package contains the ``FooCausalLM``, ``BarDecoder``,
+and ``BazEncoder`` classes, they can be registered in ``setup.cfg`` as
+follows:
+
+.. code-block:: ini
+
+   [options.entry_points]
+   curated_transformers_causal_lms =
+       extra-transformers.FooCausalLM = extra_transformers.models:FooCausalLM
+
+   curated_transformers_decoders =
+       extra-transformers.BarDecoder = extra_transformers.models:BarDecoder
+
+   curated_transformers_encoders =
+       extra-transformers.BazEncoder = extra_transformers.models:BazEncoder
diff --git a/requirements.txt b/requirements.txt
index 8d0472b6..dd5b846a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+catalogue>=2.0.4,<2.1.0
 curated-tokenizers>=0.9.1,<1.0.0
 huggingface-hub>=0.14
 tokenizers>=0.13.3
diff --git a/setup.cfg b/setup.cfg
index 0917474a..c921a2d6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -14,6 +14,7 @@ zip_safe = true
 include_package_data = true
 python_requires = >=3.8
 install_requires =
+    catalogue>=2.0.4,<2.1.0
     curated-tokenizers>=0.9.1,<1.0.0
     huggingface-hub>=0.14
     tokenizers>=0.13.3
@@ -24,9 +25,29 @@
 quantization =
     bitsandbytes>=0.40
     # bitsandbytes has a dependency on scipy but doesn't
    # list it as one for pip installs. So, we'll pull that
-    # in too (until its rectified upstream).
+    # in too (until it's rectified upstream).
     scipy>=1.11
 
+[options.entry_points]
+curated_transformers_causal_lms =
+    curated-transformers.LlamaCausalLM = curated_transformers.models:LlamaCausalLM
+    curated-transformers.FalconCausalLM = curated_transformers.models:FalconCausalLM
+    curated-transformers.GPTNeoXCausalLM = curated_transformers.models:GPTNeoXCausalLM
+    curated-transformers.MPTCausalLM = curated_transformers.models:MPTCausalLM
+
+curated_transformers_decoders =
+    curated-transformers.LlamaDecoder = curated_transformers.models:LlamaDecoder
+    curated-transformers.FalconDecoder = curated_transformers.models:FalconDecoder
+    curated-transformers.GPTNeoXDecoder = curated_transformers.models:GPTNeoXDecoder
+    curated-transformers.MPTDecoder = curated_transformers.models:MPTDecoder
+
+curated_transformers_encoders =
+    curated-transformers.ALBERTEncoder = curated_transformers.models:ALBERTEncoder
+    curated-transformers.BERTEncoder = curated_transformers.models:BERTEncoder
+    curated-transformers.CamemBERTEncoder = curated_transformers.models:CamemBERTEncoder
+    curated-transformers.RoBERTaEncoder = curated_transformers.models:RoBERTaEncoder
+    curated-transformers.XLMREncoder = curated_transformers.models:XLMREncoder
+
 [bdist_wheel]
 universal = true
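
Note: a class referenced by these entry points is only discoverable by the
auto classes if it implements the ``FromHFHub`` mixin, including the
``hf_model_types`` classmethod added in this patch. As a minimal sketch of a
third-party model (the ``extra_transformers`` package and ``BazEncoder``
class are hypothetical, for illustration only), an encoder can inherit
everything else from an existing implementation:

.. code-block:: python

   from typing import Tuple

   from curated_transformers.models import BERTEncoder


   class BazEncoder(BERTEncoder):
       """A BERT-style encoder for checkpoints with ``"model_type": "baz"``."""

       @classmethod
       def hf_model_types(cls) -> Tuple[str, ...]:
           # AutoEncoder._resolve_model_cls matches the checkpoint's model
           # type against the values returned here.
           return ("baz",)

Registered under the ``curated_transformers_encoders`` entry point (as in the
``registries.rst`` example), ``AutoEncoder.from_hf_hub`` would then resolve
checkpoints with this model type to ``BazEncoder``.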