Add StringTransformations to expose factory methods for `StringTransform` subclasses
shadeMe committed Sep 26, 2023
1 parent dff38bd commit ed22421
Showing 9 changed files with 235 additions and 142 deletions.
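
The `StringTransformations` class itself is added in curated_transformers/util/string.py, whose diff is not expanded in this view. Below is a minimal, self-contained sketch of the pattern this commit introduces: the factory names and signatures are inferred from the call sites in the diffs, while the method names (`apply`/`revert`) and all bodies are assumptions rather than the repository's actual implementation.

import re
from abc import ABC, abstractmethod
from typing import Optional, Tuple


class StringTransform(ABC):
    """A string-to-string mapping that may also support the reverse mapping."""

    @abstractmethod
    def apply(self, string: str) -> str:
        ...

    @abstractmethod
    def revert(self, string: str) -> str:
        ...


class StringSubRegEx(StringTransform):
    """Regex substitution with an optional backward substitution."""

    def __init__(self, forward: Tuple[str, str], backward: Optional[Tuple[str, str]]):
        self.forward = forward
        self.backward = backward

    def apply(self, string: str) -> str:
        pattern, replacement = self.forward
        return re.sub(pattern, replacement, string)

    def revert(self, string: str) -> str:
        if self.backward is None:
            raise ValueError("Transform is not reversible")
        pattern, replacement = self.backward
        return re.sub(pattern, replacement, string)


class StringTransformations:
    """Factory methods for StringTransform subclasses."""

    @staticmethod
    def regex_sub(
        forward: Tuple[str, str], backward: Optional[Tuple[str, str]]
    ) -> StringTransform:
        return StringSubRegEx(forward, backward)

    @staticmethod
    def regex_sub_invertible(forward: Tuple[str, str]) -> StringTransform:
        # Revert by running the substitution pair in the opposite direction.
        return StringSubRegEx(forward, (forward[1], forward[0]))

    @staticmethod
    def remove_prefix(prefix: str, reversible: bool = True) -> StringTransform:
        backward = ("^", prefix) if reversible else None
        return StringSubRegEx(("^" + re.escape(prefix), ""), backward)

    @staticmethod
    def replace(substring: str, replacement: str) -> StringTransform:
        return StringSubRegEx(
            (re.escape(substring), replacement), (re.escape(replacement), substring)
        )

Under these assumptions, the transform lists in the files below behave exactly as before; only the construction syntax changes from direct subclass instantiation to factory calls.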
50 changes: 26 additions & 24 deletions curated_transformers/models/albert/_hf.py
@@ -1,57 +1,59 @@
 from typing import Any, Callable, Dict, List, Tuple, Union

 from ...layers.activations import Activation
-from ...util.string import (
-    StringRemovePrefix,
-    StringReplace,
-    StringSubInvertible,
-    StringSubRegEx,
-    StringTransform,
-)
+from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import process_hf_keys
 from .config import ALBERTConfig

 # Order-dependent.
 HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
     # Prefixes.
-    StringRemovePrefix("albert.", reversible=False),
-    StringSubRegEx(
+    StringTransformations.remove_prefix("albert.", reversible=False),
+    StringTransformations.regex_sub(
         (r"^encoder\.(embedding_|albert_layer)", "\\1"),
         (r"^(embedding_|albert_layer)", "encoder.\\1"),
     ),
     # Layer groups
-    StringSubRegEx(
+    StringTransformations.regex_sub(
         (r"^albert_layer_groups\.", "groups."), (r"^groups\.", "albert_layer_groups.")
     ),
     # Inner layers.
-    StringSubInvertible((".albert_layers.", ".group_layers.")),
+    StringTransformations.regex_sub_invertible((".albert_layers.", ".group_layers.")),
     # Attention blocks.
-    StringSubInvertible((".attention.", ".mha.")),
-    StringSubInvertible((".mha.LayerNorm", ".attn_residual_layer_norm")),
-    StringSubInvertible((".mha.dense", ".mha.output")),
+    StringTransformations.regex_sub_invertible((".attention.", ".mha.")),
+    StringTransformations.regex_sub_invertible(
+        (".mha.LayerNorm", ".attn_residual_layer_norm")
+    ),
+    StringTransformations.regex_sub_invertible((".mha.dense", ".mha.output")),
     # Pointwise feed-forward layers.
-    StringSubInvertible((".ffn.", ".ffn.intermediate.")),
-    StringSubInvertible((".ffn_output.", ".ffn.output.")),
-    StringSubInvertible((".full_layer_layer_norm.", ".ffn_residual_layer_norm.")),
+    StringTransformations.regex_sub_invertible((".ffn.", ".ffn.intermediate.")),
+    StringTransformations.regex_sub_invertible((".ffn_output.", ".ffn.output.")),
+    StringTransformations.regex_sub_invertible(
+        (".full_layer_layer_norm.", ".ffn_residual_layer_norm.")
+    ),
     # Embeddings.
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
     ),
     # Embedding projection.
-    StringReplace("embedding_hidden_mapping_in.weight", "embeddings.projection.weight"),
-    StringReplace("embedding_hidden_mapping_in.bias", "embeddings.projection.bias"),
+    StringTransformations.replace(
+        "embedding_hidden_mapping_in.weight", "embeddings.projection.weight"
+    ),
+    StringTransformations.replace(
+        "embedding_hidden_mapping_in.bias", "embeddings.projection.bias"
+    ),
 ]

 HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = {
46 changes: 23 additions & 23 deletions curated_transformers/models/bert/_hf.py
@@ -1,60 +1,60 @@
 from typing import Any, Callable, Dict, List, Tuple, Union

 from ...layers.activations import Activation
-from ...util.string import (
-    StringRemovePrefix,
-    StringReplace,
-    StringSubInvertible,
-    StringSubRegEx,
-    StringTransform,
-)
+from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import process_hf_keys
 from .config import BERTConfig

 # Order-dependent.
 HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
     # Old HF parameter names (one-way transforms).
-    StringSubRegEx((r"\.gamma$", ".weight"), backward=None),
-    StringSubRegEx((r"\.beta$", ".bias"), backward=None),
+    StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None),
+    StringTransformations.regex_sub((r"\.beta$", ".bias"), backward=None),
     # Prefixes.
-    StringRemovePrefix("bert.", reversible=False),
-    StringSubRegEx(
+    StringTransformations.remove_prefix("bert.", reversible=False),
+    StringTransformations.regex_sub(
         (r"^encoder\.(layer\.)", "\\1"),
         (r"^(layer\.)", "encoder.\\1"),
     ),
     # Layers.
-    StringSubRegEx((r"^layer", "layers"), (r"^layers", "layer")),
+    StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")),
     # Attention blocks.
-    StringSubRegEx(
+    StringTransformations.regex_sub(
         (r"\.attention\.self\.(query|key|value)", ".mha.\\1"),
         (r"\.mha\.(query|key|value)", ".attention.self.\\1"),
     ),
-    StringSubInvertible((r".attention.output.dense", ".mha.output")),
-    StringSubInvertible((r".attention.output.LayerNorm", ".attn_residual_layer_norm")),
+    StringTransformations.regex_sub_invertible(
+        (r".attention.output.dense", ".mha.output")
+    ),
+    StringTransformations.regex_sub_invertible(
+        (r".attention.output.LayerNorm", ".attn_residual_layer_norm")
+    ),
     # Pointwise feed-forward layers.
-    StringSubInvertible((r".intermediate.dense", ".ffn.intermediate")),
-    StringSubRegEx(
+    StringTransformations.regex_sub_invertible(
+        (r".intermediate.dense", ".ffn.intermediate")
+    ),
+    StringTransformations.regex_sub(
         (r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"),
         (r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"),
     ),
-    StringSubRegEx(
+    StringTransformations.regex_sub(
         (r"(\.\d+)\.output\.dense", "\\1.ffn.output"),
         (r"(\.\d+)\.ffn\.output", "\\1.output.dense"),
     ),
     # Embeddings.
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight"
     ),
-    StringReplace(
+    StringTransformations.replace(
         "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias"
     ),
 ]
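
The two `backward=None` transforms at the top of this list are one-way by design: legacy HF parameter names (`.gamma`, `.beta`) are normalized when a checkpoint is loaded, but nothing maps back to them on export. A hypothetical example, following the sketch after the commit summary above:

t = StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None)
t.apply("encoder.layer.0.attention.output.LayerNorm.gamma")
# -> "encoder.layer.0.attention.output.LayerNorm.weight"
t.revert("encoder.layer.0.attention.output.LayerNorm.weight")  # not reversible; raises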
47 changes: 24 additions & 23 deletions curated_transformers/models/falcon/_hf.py
@@ -1,11 +1,6 @@
 from typing import Any, Callable, Dict, List, Tuple, Union

-from ...util.string import (
-    StringRemovePrefix,
-    StringSubInvertible,
-    StringSubRegEx,
-    StringTransform,
-)
+from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import process_hf_keys
 from .config import FalconConfig

@@ -16,32 +11,38 @@
 # Order-dependent.
 COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
-    StringSubRegEx((r"^h\.", "layers."), (r"^layers\.", "h.")),
-    StringSubInvertible((r"decoder.h.", "decoder.layers.")),
+    StringTransformations.regex_sub((r"^h\.", "layers."), (r"^layers\.", "h.")),
+    StringTransformations.regex_sub_invertible((r"decoder.h.", "decoder.layers.")),
     # Attention blocks.
-    StringSubInvertible((r".self_attention", ".mha")),
-    StringSubInvertible((r".mha.query_key_value", ".mha.input")),
-    StringSubInvertible((r".mha.dense", ".mha.output")),
+    StringTransformations.regex_sub_invertible((r".self_attention", ".mha")),
+    StringTransformations.regex_sub_invertible((r".mha.query_key_value", ".mha.input")),
+    StringTransformations.regex_sub_invertible((r".mha.dense", ".mha.output")),
     # Pointwise feedforward.
-    StringSubInvertible((r".mlp", ".ffn")),
-    StringSubInvertible((r".dense_h_to_4h", ".intermediate")),
-    StringSubInvertible((r".ffn.dense_4h_to_h", ".ffn.output")),
+    StringTransformations.regex_sub_invertible((r".mlp", ".ffn")),
+    StringTransformations.regex_sub_invertible((r".dense_h_to_4h", ".intermediate")),
+    StringTransformations.regex_sub_invertible((r".ffn.dense_4h_to_h", ".ffn.output")),
     # Layer norms.
-    StringSubInvertible((r".input_layernorm", ".attn_layer_norm")),
-    StringSubInvertible((r".ln_attn", ".attn_input_layer_norm")),
-    StringSubInvertible((r".post_attention_layernorm", ".ffn_layer_norm")),
-    StringSubInvertible((r".ln_mlp", ".ffn_input_layer_norm")),
-    StringSubInvertible((r"ln_f.", "output_layer_norm.")),
+    StringTransformations.regex_sub_invertible(
+        (r".input_layernorm", ".attn_layer_norm")
+    ),
+    StringTransformations.regex_sub_invertible((r".ln_attn", ".attn_input_layer_norm")),
+    StringTransformations.regex_sub_invertible(
+        (r".post_attention_layernorm", ".ffn_layer_norm")
+    ),
+    StringTransformations.regex_sub_invertible((r".ln_mlp", ".ffn_input_layer_norm")),
+    StringTransformations.regex_sub_invertible((r"ln_f.", "output_layer_norm.")),
     # Embeddings.
-    StringSubInvertible((r"word_embeddings.", "embeddings.piece_embeddings.")),
-    StringSubInvertible((r"lm_head.", "output_embeddings.")),
+    StringTransformations.regex_sub_invertible(
+        (r"word_embeddings.", "embeddings.piece_embeddings.")
+    ),
+    StringTransformations.regex_sub_invertible((r"lm_head.", "output_embeddings.")),
 ]

 DECODER_HF_PARAM_KEY_TRANSFORMS = [
-    StringRemovePrefix("transformer.", reversible=False)
+    StringTransformations.remove_prefix("transformer.", reversible=False)
 ] + COMMON_HF_PARAM_KEY_TRANSFORMS
 CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = [
-    StringSubInvertible((r"transformer.", "decoder.")),
+    StringTransformations.regex_sub_invertible((r"transformer.", "decoder.")),
 ] + COMMON_HF_PARAM_KEY_TRANSFORMS

 HF_CONFIG_KEY_MAPPING_REFINED_WEB_MODEL: Dict[str, Union[str, Tuple[str, Callable]]] = {
36 changes: 22 additions & 14 deletions curated_transformers/models/gpt_neox/_hf.py
@@ -1,7 +1,7 @@
 from typing import Any, Callable, Dict, List, Tuple, Union

 from ...layers.activations import Activation
-from ...util.string import StringRemovePrefix, StringSubInvertible, StringTransform
+from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import process_hf_keys
 from .config import GPTNeoXConfig

@@ -11,26 +11,34 @@
 # Order-dependent.
 COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
-    StringSubInvertible((r"gpt_neox", "decoder")),
+    StringTransformations.regex_sub_invertible((r"gpt_neox", "decoder")),
     # Attention blocks.
-    StringSubInvertible((r".attention", ".mha")),
-    StringSubInvertible((r".mha.query_key_value", ".mha.input")),
-    StringSubInvertible((r".mha.dense", ".mha.output")),
+    StringTransformations.regex_sub_invertible((r".attention", ".mha")),
+    StringTransformations.regex_sub_invertible((r".mha.query_key_value", ".mha.input")),
+    StringTransformations.regex_sub_invertible((r".mha.dense", ".mha.output")),
     # Pointwise feedforward.
-    StringSubInvertible((r".mlp", ".ffn")),
-    StringSubInvertible((r".dense_h_to_4h", ".intermediate")),
-    StringSubInvertible((r".ffn.dense_4h_to_h", ".ffn.output")),
+    StringTransformations.regex_sub_invertible((r".mlp", ".ffn")),
+    StringTransformations.regex_sub_invertible((r".dense_h_to_4h", ".intermediate")),
+    StringTransformations.regex_sub_invertible((r".ffn.dense_4h_to_h", ".ffn.output")),
     # Layer norms.
-    StringSubInvertible((r".input_layernorm", ".attn_input_layer_norm")),
-    StringSubInvertible((r".post_attention_layernorm", ".ffn_input_layer_norm")),
-    StringSubInvertible((r"final_layer_norm.", "output_layer_norm.")),
+    StringTransformations.regex_sub_invertible(
+        (r".input_layernorm", ".attn_input_layer_norm")
+    ),
+    StringTransformations.regex_sub_invertible(
+        (r".post_attention_layernorm", ".ffn_input_layer_norm")
+    ),
+    StringTransformations.regex_sub_invertible(
+        (r"final_layer_norm.", "output_layer_norm.")
+    ),
     # Embeddings.
-    StringSubInvertible((r"embed_in.", "embeddings.piece_embeddings.")),
-    StringSubInvertible((r"embed_out.", "output_embeddings.")),
+    StringTransformations.regex_sub_invertible(
+        (r"embed_in.", "embeddings.piece_embeddings.")
+    ),
+    StringTransformations.regex_sub_invertible((r"embed_out.", "output_embeddings.")),
 ]

 DECODER_HF_PARAM_KEY_TRANSFORMS = [
-    StringRemovePrefix("gpt_neox.", reversible=False)
+    StringTransformations.remove_prefix("gpt_neox.", reversible=False)
 ] + COMMON_HF_PARAM_KEY_TRANSFORMS
 CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = COMMON_HF_PARAM_KEY_TRANSFORMS
2 changes: 1 addition & 1 deletion curated_transformers/models/hf_hub/conversion.py
@@ -2,7 +2,7 @@

 from torch import Tensor

-from ...util.string import StringTransform
+from ...util.string import StringTransform, StringTransformations


 def process_hf_keys(
45 changes: 23 additions & 22 deletions curated_transformers/models/llama/_hf.py
@@ -1,12 +1,7 @@
 from typing import Any, Callable, Dict, List, Tuple, Union

 from ...layers.activations import Activation
-from ...util.string import (
-    StringRemovePrefix,
-    StringSubInvertible,
-    StringSubRegEx,
-    StringTransform,
-)
+from ...util.string import StringTransform, StringTransformations
 from ..hf_hub.conversion import process_hf_keys
 from .config import LlamaConfig

@@ -17,33 +12,39 @@
 # Order-dependent.
 COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [
     # Attention blocks.
-    StringSubInvertible((r".self_attn", ".mha")),
-    StringSubInvertible((r".q_proj", ".query")),
-    StringSubInvertible((r".k_proj", ".key")),
-    StringSubInvertible((r".v_proj", ".value")),
-    StringSubInvertible((r".o_proj", ".output")),
+    StringTransformations.regex_sub_invertible((r".self_attn", ".mha")),
+    StringTransformations.regex_sub_invertible((r".q_proj", ".query")),
+    StringTransformations.regex_sub_invertible((r".k_proj", ".key")),
+    StringTransformations.regex_sub_invertible((r".v_proj", ".value")),
+    StringTransformations.regex_sub_invertible((r".o_proj", ".output")),
     # Pointwise feedforward
-    StringSubInvertible((r".mlp", ".ffn")),
-    StringSubInvertible((r".up_proj", ".intermediate")),
-    StringSubInvertible((r"ffn.down_proj", "ffn.output")),
-    StringSubInvertible((r".gate_proj", ".gate")),
+    StringTransformations.regex_sub_invertible((r".mlp", ".ffn")),
+    StringTransformations.regex_sub_invertible((r".up_proj", ".intermediate")),
+    StringTransformations.regex_sub_invertible((r"ffn.down_proj", "ffn.output")),
+    StringTransformations.regex_sub_invertible((r".gate_proj", ".gate")),
     # RMS norms
-    StringSubInvertible((r".input_layernorm", ".attn_input_layer_norm")),
-    StringSubInvertible((r".post_attention_layernorm", ".ffn_input_layer_norm")),
-    StringSubRegEx(
+    StringTransformations.regex_sub_invertible(
+        (r".input_layernorm", ".attn_input_layer_norm")
+    ),
+    StringTransformations.regex_sub_invertible(
+        (r".post_attention_layernorm", ".ffn_input_layer_norm")
+    ),
+    StringTransformations.regex_sub(
         (r"^(decoder\.)?norm\.", "\\1output_layer_norm."),
         (r"^(decoder\.)?output_layer_norm\.", "\\1norm."),
     ),
     # Embeddings
-    StringSubInvertible((r"embed_tokens.", "embeddings.piece_embeddings.")),
-    StringSubInvertible((r"lm_head.", "output_embeddings.")),
+    StringTransformations.regex_sub_invertible(
+        (r"embed_tokens.", "embeddings.piece_embeddings.")
+    ),
+    StringTransformations.regex_sub_invertible((r"lm_head.", "output_embeddings.")),
 ]

 DECODER_HF_PARAM_KEY_TRANSFORMS = [
-    StringRemovePrefix("model.", reversible=False)
+    StringTransformations.remove_prefix("model.", reversible=False)
 ] + COMMON_HF_PARAM_KEY_TRANSFORMS
 CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = [
-    StringSubInvertible((r"model.", "decoder."))
+    StringTransformations.regex_sub_invertible((r"model.", "decoder."))
 ] + COMMON_HF_PARAM_KEY_TRANSFORMS

 HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = {
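
For an end-to-end picture, this is how one Llama checkpoint key would presumably travel through DECODER_HF_PARAM_KEY_TRANSFORMS above. The driving loop is an assumption, since process_hf_keys is truncated in this view, but the expected result follows from applying each transform in order:

key = "model.layers.0.self_attn.q_proj.weight"
for transform in DECODER_HF_PARAM_KEY_TRANSFORMS:
    key = transform.apply(key)
print(key)  # layers.0.mha.query.weight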