Commit bb95cdc
Add support for converting Curated Transformer state dicts to Hugging Face compatible state dicts.
shadeMe committed Sep 26, 2023
1 parent 1f023dc commit bb95cdc
Showing 36 changed files with 869 additions and 521 deletions.
109 changes: 40 additions & 69 deletions curated_transformers/models/albert/_hf.py
@@ -1,25 +1,46 @@
import re
from types import MappingProxyType
from typing import Any, Callable, Dict, Mapping, Tuple, Union

from torch import Tensor
from typing import Any, Callable, Dict, List, Tuple, Union

from ...layers.activations import Activation
from ..hf_hub import _process_hf_keys
from ...util.string import StringTransform, StrLStrip, StrRepl, StrSub, StrSubInv
from ..hf_hub.conversion import process_hf_keys
from .config import ALBERTConfig

HF_KEY_TO_CURATED_KEY = MappingProxyType(
{
"embeddings.word_embeddings.weight": "embeddings.piece_embeddings.weight",
"embeddings.token_type_embeddings.weight": "embeddings.type_embeddings.weight",
"embeddings.position_embeddings.weight": "embeddings.position_embeddings.weight",
"embeddings.LayerNorm.weight": "embeddings.embed_output_layer_norm.weight",
"embeddings.LayerNorm.bias": "embeddings.embed_output_layer_norm.bias",
# Embedding projection
"encoder.embedding_hidden_mapping_in.weight": "embeddings.projection.weight",
"encoder.embedding_hidden_mapping_in.bias": "embeddings.projection.bias",
}
)
# Order-dependent.
HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
# Prefixes.
StrLStrip("albert.", reversible=False),
StrSub(
(r"^encoder\.(embedding_|albert_layer)", "\\1"),
(r"^(embedding_|albert_layer)", "encoder.\\1"),
),
# Layer groups
StrSub(
(r"^albert_layer_groups\.", "groups."), (r"^groups\.", "albert_layer_groups.")
),
# Inner layers.
StrSubInv((".albert_layers.", ".group_layers.")),
# Attention blocks.
StrSubInv((".attention.", ".mha.")),
StrSubInv((".mha.LayerNorm", ".attn_residual_layer_norm")),
StrSubInv((".mha.dense", ".mha.output")),
# Pointwise feed-forward layers.
StrSubInv((".ffn.", ".ffn.intermediate.")),
StrSubInv((".ffn_output.", ".ffn.output.")),
StrSubInv((".full_layer_layer_norm.", ".ffn_residual_layer_norm.")),
# Embeddings.
StrRepl("embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"),
StrRepl(
"embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
),
StrRepl(
"embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
),
StrRepl("embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"),
StrRepl("embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"),
# Embedding projection.
StrRepl("embedding_hidden_mapping_in.weight", "embeddings.projection.weight"),
StrRepl("embedding_hidden_mapping_in.bias", "embeddings.projection.bias"),
]

HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = {
"attention_probs_dropout_prob": "attention_probs_dropout_prob",
@@ -40,55 +61,5 @@


def convert_hf_config(hf_config: Any) -> ALBERTConfig:
kwargs = _process_hf_keys("ALBERT", hf_config, HF_CONFIG_KEY_MAPPING)
kwargs = process_hf_keys("ALBERT", hf_config, HF_CONFIG_KEY_MAPPING)
return ALBERTConfig(model_max_length=hf_config["max_position_embeddings"], **kwargs)


def convert_hf_state_dict(params: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
# Strip the `albert` prefix from ALBERT model parameters.
stripped_params = {re.sub(r"^albert\.", "", k): v for k, v in params.items()}

# The ALBERT encoder parameters have the following form:
#
# encoder.albert_layer_groups.{hidden_group}.albert_layers.{inner_layer}.{param_name}
#
# hidden_group is in [0, n_hidden_group)
# inner_layer is in [0, n_layers_per_group)

out = {}
for name, parameter in stripped_params.items():
if "encoder.albert_layer" not in name:
continue

# TODO: Make these substitutions less ugly.

# Remove the prefix and rename.
name = re.sub(r"^encoder\.", "", name)

# Layer groups
name = re.sub(r"^albert_layer_groups\.", "groups.", name)

# Inner layers.
name = re.sub(r"\.albert_layers\.", ".group_layers.", name)

# Attention blocks.
name = re.sub(r"\.attention\.", ".mha.", name)
name = re.sub(r"\.mha\.LayerNorm", r".attn_residual_layer_norm", name)
name = re.sub(r"\.mha\.dense\.", r".mha.output.", name)

# Pointwise feed-forward layers.
name = re.sub(r"\.ffn\.", r".ffn.intermediate.", name)
name = re.sub(r"\.ffn_output\.", r".ffn.output.", name)
name = re.sub(
r"\.full_layer_layer_norm\.",
r".ffn_residual_layer_norm.",
name,
)

out[name] = parameter

for hf_name, curated_name in HF_KEY_TO_CURATED_KEY.items():
if hf_name in stripped_params:
out[curated_name] = stripped_params[hf_name]

return out
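
Note: the StrSub entries above pair a forward (Hugging Face to curated) pattern with a backward (curated to Hugging Face) pattern, which is what makes the new key conversion reversible. The StringTransform helpers are imported from ...util.string and are not shown in this excerpt; the sketch below only illustrates the forward/backward idea with plain re.sub, using the first ALBERT StrSub entry.

import re

# (pattern, replacement) pairs copied from the first StrSub entry above.
forward = (r"^encoder\.(embedding_|albert_layer)", r"\1")    # HF -> curated
backward = (r"^(embedding_|albert_layer)", r"encoder.\1")    # curated -> HF

hf_key = "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.weight"
curated_key = re.sub(forward[0], forward[1], hf_key)
# -> "albert_layer_groups.0.albert_layers.0.attention.dense.weight"
# Applying the backward pattern restores the original HF key.
assert re.sub(backward[0], backward[1], curated_key) == hf_key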
15 changes: 12 additions & 3 deletions curated_transformers/models/albert/encoder.py
@@ -11,9 +11,10 @@
TransformerEmbeddings,
)
from ..hf_hub import FromHFHub
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..module import EncoderModule
from ..output import ModelOutput
from ._hf import convert_hf_config, convert_hf_state_dict
from ._hf import HF_PARAM_KEY_TRANSFORMS, convert_hf_config
from .config import ALBERTConfig
from .layer_group import ALBERTLayerGroup

@@ -99,8 +100,16 @@ def forward(
return ModelOutput(all_outputs=[embeddings, *layer_outputs])

@classmethod
def convert_hf_state_dict(cls, params: Mapping[str, Tensor]):
return convert_hf_state_dict(params)
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def state_dict_to_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def from_hf_config(
113 changes: 45 additions & 68 deletions curated_transformers/models/bert/_hf.py
@@ -1,23 +1,51 @@
import re
from types import MappingProxyType
from typing import Any, Callable, Dict, Mapping, Tuple, Union

from torch import Tensor
from typing import Any, Callable, Dict, List, Tuple, Union

from ...layers.activations import Activation
from ..hf_hub import _process_hf_keys
from ...util.string import StringTransform, StrLStrip, StrRepl, StrSub, StrSubInv
from ..hf_hub.conversion import process_hf_keys
from .config import BERTConfig

HF_KEY_TO_CURATED_KEY = MappingProxyType(
{
"embeddings.word_embeddings.weight": "embeddings.piece_embeddings.weight",
"embeddings.token_type_embeddings.weight": "embeddings.type_embeddings.weight",
"embeddings.position_embeddings.weight": "embeddings.position_embeddings.weight",
"embeddings.LayerNorm.weight": "embeddings.embed_output_layer_norm.weight",
"embeddings.LayerNorm.bias": "embeddings.embed_output_layer_norm.bias",
}
)

# Order-dependent.
HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
# Old HF parameter names (one-way transforms).
StrSub((r"\.gamma$", ".weight"), backward=None),
StrSub((r"\.beta$", ".bias"), backward=None),
# Prefixes.
StrLStrip("bert.", reversible=False),
StrSub(
(r"^encoder\.(layer\.)", "\\1"),
(r"^(layer\.)", "encoder.\\1"),
),
# Layers.
StrSub((r"^layer", "layers"), (r"^layers", "layer")),
# Attention blocks.
StrSub(
(r"\.attention\.self\.(query|key|value)", ".mha.\\1"),
(r"\.mha\.(query|key|value)", ".attention.self.\\1"),
),
StrSubInv((r".attention.output.dense", ".mha.output")),
StrSubInv((r".attention.output.LayerNorm", ".attn_residual_layer_norm")),
# Pointwise feed-forward layers.
StrSubInv((r".intermediate.dense", ".ffn.intermediate")),
StrSub(
(r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"),
(r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"),
),
StrSub(
(r"(\.\d+)\.output\.dense", "\\1.ffn.output"),
(r"(\.\d+)\.ffn\.output", "\\1.output.dense"),
),
# Embeddings.
StrRepl("embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"),
StrRepl(
"embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
),
StrRepl(
"embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
),
StrRepl("embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"),
StrRepl("embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"),
]

HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = {
"attention_probs_dropout_prob": "attention_probs_dropout_prob",
@@ -35,61 +63,10 @@


def convert_hf_config(hf_config: Any) -> BERTConfig:
kwargs = _process_hf_keys("BERT", hf_config, HF_CONFIG_KEY_MAPPING)
kwargs = process_hf_keys("BERT", hf_config, HF_CONFIG_KEY_MAPPING)

return BERTConfig(
embedding_width=hf_config["hidden_size"],
model_max_length=hf_config["max_position_embeddings"],
**kwargs,
)


def convert_hf_state_dict(params: Mapping[str, Tensor]) -> Mapping[str, Tensor]:
out = {}

renamed_params = _rename_old_hf_names(params)

# Strip the `bert` prefix from BERT model parameters.
stripped_params = {re.sub(r"^bert\.", "", k): v for k, v in renamed_params.items()}

for name, parameter in stripped_params.items():
if "encoder.layer." not in name:
continue

# TODO: Make these substitutions less ugly.

# Remove the prefix and rename the internal 'layers' variable.
name = re.sub(r"^encoder\.", "", name)
name = re.sub(r"^layer", "layers", name)

# The HF model has one more level of indirection for the output layers in their
# attention heads and the feed-forward network layers.
name = re.sub(r"\.attention\.self\.(query|key|value)", r".mha.\1", name)
name = re.sub(r"\.attention\.(output)\.dense", r".mha.\1", name)
name = re.sub(
r"\.attention\.output\.LayerNorm", r".attn_residual_layer_norm", name
)
name = re.sub(r"\.(intermediate)\.dense", r".ffn.\1", name)
name = re.sub(
r"(\.\d+)\.output\.LayerNorm", r"\1.ffn_residual_layer_norm", name
)
name = re.sub(r"(\.\d+)\.(output)\.dense", r"\1.ffn.\2", name)

out[name] = parameter

for hf_name, curated_name in HF_KEY_TO_CURATED_KEY.items():
if hf_name in stripped_params:
out[curated_name] = stripped_params[hf_name]

return out


def _rename_old_hf_names(
params: Mapping[str, Tensor],
) -> Mapping[str, Tensor]:
out = {}
for name, parameter in params.items():
name = re.sub(r"\.gamma$", ".weight", name)
name = re.sub(r"\.beta$", ".bias", name)
out[name] = parameter
return out
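
As a worked illustration of why the BERT transform list above is order-dependent, the following sketch traces a typical HF parameter name through regex approximations of the forward direction only; the real conversion uses the StringTransform helpers rather than these hand-written patterns.

import re

# Regex approximations of the forward (HF -> curated) direction of
# HF_PARAM_KEY_TRANSFORMS above, applied in the same order.
forward_transforms = [
    (r"\.gamma$", ".weight"),                # old HF parameter names
    (r"\.beta$", ".bias"),
    (r"^bert\.", ""),                        # prefix strip
    (r"^encoder\.(layer\.)", r"\1"),
    (r"^layer", "layers"),
    (r"\.attention\.self\.(query|key|value)", r".mha.\1"),
    (r"\.attention\.output\.dense", ".mha.output"),
    (r"\.attention\.output\.LayerNorm", ".attn_residual_layer_norm"),
    (r"\.intermediate\.dense", ".ffn.intermediate"),
    (r"(\.\d+)\.output\.LayerNorm", r"\1.ffn_residual_layer_norm"),
    (r"(\.\d+)\.output\.dense", r"\1.ffn.output"),
]

name = "bert.encoder.layer.0.attention.output.LayerNorm.gamma"
for pattern, repl in forward_transforms:
    name = re.sub(pattern, repl, name)
print(name)  # layers.0.attn_residual_layer_norm.weight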
15 changes: 12 additions & 3 deletions curated_transformers/models/bert/encoder.py
@@ -16,8 +16,9 @@
TransformerLayerNorms,
)
from ..hf_hub import FromHFHub
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerEncoder
from ._hf import convert_hf_config, convert_hf_state_dict
from ._hf import HF_PARAM_KEY_TRANSFORMS, convert_hf_config
from .config import BERTConfig

# Only provided as typing.Self in Python 3.11+.
@@ -105,8 +106,16 @@ def __init__(self, config: BERTConfig, *, device: Optional[torch.device] = None)
)

@classmethod
def convert_hf_state_dict(cls, params: Mapping[str, Tensor]):
return convert_hf_state_dict(params)
def state_dict_from_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_from_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def state_dict_to_hf(
cls: Type[Self], params: Mapping[str, Tensor]
) -> Mapping[str, Tensor]:
return state_dict_to_hf(params, HF_PARAM_KEY_TRANSFORMS)

@classmethod
def from_hf_config(
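
Taken together, the new classmethods replace the one-way convert_hf_state_dict hook with a symmetric conversion pair. A minimal usage sketch, assuming the encoder class exported by curated_transformers.models.bert.encoder is BERTEncoder (the class definition is outside this excerpt) and using a toy parameter value:

import torch
from curated_transformers.models.bert.encoder import BERTEncoder  # assumed class name

# Toy HF-style parameter; a real state dict would come from a HF checkpoint.
hf_params = {
    "bert.encoder.layer.0.attention.self.query.weight": torch.zeros(768, 768),
}

curated = BERTEncoder.state_dict_from_hf(hf_params)
# Expected curated key: "layers.0.mha.query.weight"
hf_again = BERTEncoder.state_dict_to_hf(curated)
# Expected HF-compatible key: "encoder.layer.0.attention.self.query.weight"
# (the "bert." prefix strip is marked reversible=False, so it is not re-added).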