Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register models using catalogue #351

Merged
merged 3 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions curated_transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import catalogue
from catalogue import Registry


class registry(object):
"""
Registry for models. These registries are used by auto classes to
discover the available models.
"""

causal_lms: Registry = catalogue.create(
"curated_transformers", "causal_lms", entry_points=True
)
decoders: Registry = catalogue.create(
"curated_transformers", "decoders", entry_points=True
)
encoders: Registry = catalogue.create(
"curated_transformers", "encoders", entry_points=True
)
6 changes: 5 additions & 1 deletion curated_transformers/models/albert/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -99,6 +99,10 @@ def forward(

return ModelOutput(all_outputs=[embeddings, *layer_outputs])

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("albert",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
71 changes: 28 additions & 43 deletions curated_transformers/models/auto_model.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Generic, Optional, Type, TypeVar
from typing import Generic, Optional, Type, TypeVar

import torch
from catalogue import Registry
from fsspec import AbstractFileSystem

from curated_transformers import registry

from ..layers.cache import KeyValueCache
from ..quantization.bnb.config import BitsAndBytesConfig
from ..repository.fsspec import FsspecArgs, FsspecRepository
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from .albert import ALBERTEncoder
from .bert import BERTEncoder
from .camembert import CamemBERTEncoder
from .config import TransformerConfig
from .falcon import FalconCausalLM, FalconDecoder
from .gpt_neox import GPTNeoXCausalLM, GPTNeoXDecoder
from .hf_hub import FromHFHub
from .llama import LlamaCausalLM, LlamaDecoder
from .module import CausalLMModule, DecoderModule, EncoderModule
from .mpt.causal_lm import MPTCausalLM
from .mpt.decoder import MPTDecoder
from .roberta import RoBERTaEncoder
from .xlm_roberta import XLMREncoder

ModelT = TypeVar("ModelT")

Expand All @@ -32,22 +26,33 @@ class AutoModel(ABC, Generic[ModelT]):
Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {}
_registry: Registry

@classmethod
def _resolve_model_cls(
cls,
repo: ModelRepository,
) -> Type[FromHFHub]:
model_type = repo.model_type()
module_cls = cls._hf_model_type_to_curated.get(model_type)
if module_cls is None:
raise ValueError(
f"Unsupported model type `{model_type}` for {cls.__name__}. "
f"Supported model types: {tuple(cls._hf_model_type_to_curated.keys())}"
)
assert issubclass(module_cls, FromHFHub)
return module_cls

supported_model_types = set()
for entrypoint, module_cls in cls._registry.get_entry_points().items():
if not issubclass(module_cls, FromHFHub):
warnings.warn(
f"Entry point `{entrypoint}` cannot load from Hugging Face Hub "
"since the FromHFHub mixin is not implemented"
)
continue

module_model_types = module_cls.hf_model_types()
if model_type in module_model_types:
return module_cls
supported_model_types.update(module_model_types)

raise ValueError(
f"Unsupported model type `{model_type}` for {cls.__name__}. "
f"Supported model types: {', '.join(sorted(supported_model_types))}"
)

@classmethod
def _instantiate_model(
Expand Down Expand Up @@ -182,13 +187,7 @@ class AutoEncoder(AutoModel[EncoderModule[TransformerConfig]]):
Encoder model loaded from the Hugging Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {
"bert": BERTEncoder,
"albert": ALBERTEncoder,
"camembert": CamemBERTEncoder,
"roberta": RoBERTaEncoder,
"xlm-roberta": XLMREncoder,
}
_registry: Registry = registry.encoders

@classmethod
def from_repo(
Expand All @@ -208,14 +207,7 @@ class AutoDecoder(AutoModel[DecoderModule[TransformerConfig, KeyValueCache]]):
Decoder module loaded from the Hugging Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {
"falcon": FalconDecoder,
"gpt_neox": GPTNeoXDecoder,
"llama": LlamaDecoder,
"mpt": MPTDecoder,
"RefinedWeb": FalconDecoder,
"RefinedWebModel": FalconDecoder,
}
_registry = registry.decoders

@classmethod
def from_repo(
Expand All @@ -235,14 +227,7 @@ class AutoCausalLM(AutoModel[CausalLMModule[TransformerConfig, KeyValueCache]]):
Causal LM model loaded from the Hugging Face Model Hub.
"""

_hf_model_type_to_curated: Dict[str, Type[FromHFHub]] = {
"falcon": FalconCausalLM,
"gpt_neox": GPTNeoXCausalLM,
"llama": LlamaCausalLM,
"mpt": MPTCausalLM,
"RefinedWeb": FalconCausalLM,
"RefinedWebModel": FalconCausalLM,
}
_registry: Registry = registry.causal_lms

@classmethod
def from_repo(
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/bert/encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -105,6 +105,10 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
]
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("bert",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/camembert/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

import torch

Expand All @@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
The encoder.
"""
super().__init__(config, device=device)

@classmethod
def hf_model_types(cls) -> Tuple[str, ...]:
return ("camembert",)
6 changes: 5 additions & 1 deletion curated_transformers/models/falcon/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Type, TypeVar
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -52,6 +52,10 @@ def state_dict_from_hf(
) -> Mapping[str, Tensor]:
return state_dict_from_hf(params, CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("falcon", "RefinedWeb", "RefinedWebModel")

@classmethod
def state_dict_to_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/falcon/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -86,6 +86,10 @@ def __init__(
device=device,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("falcon", "RefinedWeb", "RefinedWebModel")

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/gpt_neox/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Type, TypeVar
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -46,6 +46,10 @@ def __init__(
device=device,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("gpt_neox",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/gpt_neox/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -114,6 +114,10 @@ def __init__(
hidden_width, config.layer.layer_norm_eps, device=device
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("gpt_neox",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
13 changes: 12 additions & 1 deletion curated_transformers/models/hf_hub/mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from fsspec import AbstractFileSystem
Expand Down Expand Up @@ -168,6 +168,17 @@ def from_hf_hub(
quantization_config=quantization_config,
)

@classmethod
@abstractmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
"""
Get the Hugging Face model types supported by this model.

:returns:
The supported model types.
"""
...

@abstractmethod
def to(
self,
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/llama/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Set, Type, TypeVar
from typing import Any, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -47,6 +47,10 @@ def __init__(
device=device,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("llama",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/llama/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -120,6 +120,10 @@ def __init__(
hidden_width, eps=config.layer.layer_norm_eps, device=device
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("llama",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/mpt/causal_lm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Mapping, Optional, Set, Type, TypeVar
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, TypeVar

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -84,6 +84,10 @@ def forward(
logits=logits,
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("mpt",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/mpt/decoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -120,6 +120,10 @@ def layer_norm():

self.output_layer_norm = layer_norm()

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("mpt",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/roberta/encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Mapping, Optional, Type, TypeVar
from typing import Any, Mapping, Optional, Tuple, Type, TypeVar

import torch
from torch import Tensor
Expand Down Expand Up @@ -105,6 +105,10 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
]
)

@classmethod
def hf_model_types(cls: Type[Self]) -> Tuple[str, ...]:
return ("roberta",)

@classmethod
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
Expand Down
6 changes: 5 additions & 1 deletion curated_transformers/models/xlm_roberta/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

import torch

Expand All @@ -25,3 +25,7 @@ def __init__(self, config: RoBERTaConfig, *, device: Optional[torch.device] = No
The encoder.
"""
super().__init__(config, device=device)

@classmethod
def hf_model_types(cls) -> Tuple[str, ...]:
return ("xlm-roberta",)
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ API
decoders
causal-lm
generation
registries
repositories
tokenizers
quantization
Expand Down
Loading