diff --git a/curated_transformers/models/albert/_hf.py b/curated_transformers/models/albert/_hf.py index 7b9f7fdd..8fc0d705 100644 --- a/curated_transformers/models/albert/_hf.py +++ b/curated_transformers/models/albert/_hf.py @@ -1,57 +1,59 @@ from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ...util.string import ( - StringRemovePrefix, - StringReplace, - StringSubInvertible, - StringSubRegEx, - StringTransform, -) +from ...util.string import StringTransform, StringTransformations from ..hf_hub.conversion import process_hf_keys from .config import ALBERTConfig # Order-dependent. HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ # Prefixes. - StringRemovePrefix("albert.", reversible=False), - StringSubRegEx( + StringTransformations.remove_prefix("albert.", reversible=False), + StringTransformations.regex_sub( (r"^encoder\.(embedding_|albert_layer)", "\\1"), (r"^(embedding_|albert_layer)", "encoder.\\1"), ), # Layer groups - StringSubRegEx( + StringTransformations.regex_sub( (r"^albert_layer_groups\.", "groups."), (r"^groups\.", "albert_layer_groups.") ), # Inner layers. - StringSubInvertible((".albert_layers.", ".group_layers.")), + StringTransformations.regex_sub_invertible((".albert_layers.", ".group_layers.")), # Attention blocks. - StringSubInvertible((".attention.", ".mha.")), - StringSubInvertible((".mha.LayerNorm", ".attn_residual_layer_norm")), - StringSubInvertible((".mha.dense", ".mha.output")), + StringTransformations.regex_sub_invertible((".attention.", ".mha.")), + StringTransformations.regex_sub_invertible( + (".mha.LayerNorm", ".attn_residual_layer_norm") + ), + StringTransformations.regex_sub_invertible((".mha.dense", ".mha.output")), # Pointwise feed-forward layers. 
- StringSubInvertible((".ffn.", ".ffn.intermediate.")), - StringSubInvertible((".ffn_output.", ".ffn.output.")), - StringSubInvertible((".full_layer_layer_norm.", ".ffn_residual_layer_norm.")), + StringTransformations.regex_sub_invertible((".ffn.", ".ffn.intermediate.")), + StringTransformations.regex_sub_invertible((".ffn_output.", ".ffn.output.")), + StringTransformations.regex_sub_invertible( + (".full_layer_layer_norm.", ".ffn_residual_layer_norm.") + ), # Embeddings. - StringReplace( + StringTransformations.replace( "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias" ), # Embedding projection. 
- StringReplace("embedding_hidden_mapping_in.weight", "embeddings.projection.weight"), - StringReplace("embedding_hidden_mapping_in.bias", "embeddings.projection.bias"), + StringTransformations.replace( + "embedding_hidden_mapping_in.weight", "embeddings.projection.weight" + ), + StringTransformations.replace( + "embedding_hidden_mapping_in.bias", "embeddings.projection.bias" + ), ] HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { diff --git a/curated_transformers/models/bert/_hf.py b/curated_transformers/models/bert/_hf.py index 082b2fe1..ab482c94 100644 --- a/curated_transformers/models/bert/_hf.py +++ b/curated_transformers/models/bert/_hf.py @@ -1,60 +1,60 @@ from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ...util.string import ( - StringRemovePrefix, - StringReplace, - StringSubInvertible, - StringSubRegEx, - StringTransform, -) +from ...util.string import StringTransform, StringTransformations from ..hf_hub.conversion import process_hf_keys from .config import BERTConfig # Order-dependent. HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ # Old HF parameter names (one-way transforms). - StringSubRegEx((r"\.gamma$", ".weight"), backward=None), - StringSubRegEx((r"\.beta$", ".bias"), backward=None), + StringTransformations.regex_sub((r"\.gamma$", ".weight"), backward=None), + StringTransformations.regex_sub((r"\.beta$", ".bias"), backward=None), # Prefixes. - StringRemovePrefix("bert.", reversible=False), - StringSubRegEx( + StringTransformations.remove_prefix("bert.", reversible=False), + StringTransformations.regex_sub( (r"^encoder\.(layer\.)", "\\1"), (r"^(layer\.)", "encoder.\\1"), ), # Layers. - StringSubRegEx((r"^layer", "layers"), (r"^layers", "layer")), + StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")), # Attention blocks. 
- StringSubRegEx( + StringTransformations.regex_sub( (r"\.attention\.self\.(query|key|value)", ".mha.\\1"), (r"\.mha\.(query|key|value)", ".attention.self.\\1"), ), - StringSubInvertible((r".attention.output.dense", ".mha.output")), - StringSubInvertible((r".attention.output.LayerNorm", ".attn_residual_layer_norm")), + StringTransformations.regex_sub_invertible( + (r".attention.output.dense", ".mha.output") + ), + StringTransformations.regex_sub_invertible( + (r".attention.output.LayerNorm", ".attn_residual_layer_norm") + ), # Pointwise feed-forward layers. - StringSubInvertible((r".intermediate.dense", ".ffn.intermediate")), - StringSubRegEx( + StringTransformations.regex_sub_invertible( + (r".intermediate.dense", ".ffn.intermediate") + ), + StringTransformations.regex_sub( (r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"), (r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"), ), - StringSubRegEx( + StringTransformations.regex_sub( (r"(\.\d+)\.output\.dense", "\\1.ffn.output"), (r"(\.\d+)\.ffn\.output", "\\1.output.dense"), ), # Embeddings. 
- StringReplace( + StringTransformations.replace( "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias" ), ] diff --git a/curated_transformers/models/falcon/_hf.py b/curated_transformers/models/falcon/_hf.py index 0d257bf6..7113b821 100644 --- a/curated_transformers/models/falcon/_hf.py +++ b/curated_transformers/models/falcon/_hf.py @@ -1,11 +1,6 @@ from typing import Any, Callable, Dict, List, Tuple, Union -from ...util.string import ( - StringRemovePrefix, - StringSubInvertible, - StringSubRegEx, - StringTransform, -) +from ...util.string import StringTransform, StringTransformations from ..hf_hub.conversion import process_hf_keys from .config import FalconConfig @@ -16,32 +11,38 @@ # Order-dependent. COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ - StringSubRegEx((r"^h\.", "layers."), (r"^layers\.", "h.")), - StringSubInvertible((r"decoder.h.", "decoder.layers.")), + StringTransformations.regex_sub((r"^h\.", "layers."), (r"^layers\.", "h.")), + StringTransformations.regex_sub_invertible((r"decoder.h.", "decoder.layers.")), # Attention blocks. 
- StringSubInvertible((r".self_attention", ".mha")), - StringSubInvertible((r".mha.query_key_value", ".mha.input")), - StringSubInvertible((r".mha.dense", ".mha.output")), + StringTransformations.regex_sub_invertible((r".self_attention", ".mha")), + StringTransformations.regex_sub_invertible((r".mha.query_key_value", ".mha.input")), + StringTransformations.regex_sub_invertible((r".mha.dense", ".mha.output")), # Pointwise feedforward. - StringSubInvertible((r".mlp", ".ffn")), - StringSubInvertible((r".dense_h_to_4h", ".intermediate")), - StringSubInvertible((r".ffn.dense_4h_to_h", ".ffn.output")), + StringTransformations.regex_sub_invertible((r".mlp", ".ffn")), + StringTransformations.regex_sub_invertible((r".dense_h_to_4h", ".intermediate")), + StringTransformations.regex_sub_invertible((r".ffn.dense_4h_to_h", ".ffn.output")), # Layer norms. - StringSubInvertible((r".input_layernorm", ".attn_layer_norm")), - StringSubInvertible((r".ln_attn", ".attn_input_layer_norm")), - StringSubInvertible((r".post_attention_layernorm", ".ffn_layer_norm")), - StringSubInvertible((r".ln_mlp", ".ffn_input_layer_norm")), - StringSubInvertible((r"ln_f.", "output_layer_norm.")), + StringTransformations.regex_sub_invertible( + (r".input_layernorm", ".attn_layer_norm") + ), + StringTransformations.regex_sub_invertible((r".ln_attn", ".attn_input_layer_norm")), + StringTransformations.regex_sub_invertible( + (r".post_attention_layernorm", ".ffn_layer_norm") + ), + StringTransformations.regex_sub_invertible((r".ln_mlp", ".ffn_input_layer_norm")), + StringTransformations.regex_sub_invertible((r"ln_f.", "output_layer_norm.")), # Embeddings. 
- StringSubInvertible((r"word_embeddings.", "embeddings.piece_embeddings.")), - StringSubInvertible((r"lm_head.", "output_embeddings.")), + StringTransformations.regex_sub_invertible( + (r"word_embeddings.", "embeddings.piece_embeddings.") + ), + StringTransformations.regex_sub_invertible((r"lm_head.", "output_embeddings.")), ] DECODER_HF_PARAM_KEY_TRANSFORMS = [ - StringRemovePrefix("transformer.", reversible=False) + StringTransformations.remove_prefix("transformer.", reversible=False) ] + COMMON_HF_PARAM_KEY_TRANSFORMS CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = [ - StringSubInvertible((r"transformer.", "decoder.")), + StringTransformations.regex_sub_invertible((r"transformer.", "decoder.")), ] + COMMON_HF_PARAM_KEY_TRANSFORMS HF_CONFIG_KEY_MAPPING_REFINED_WEB_MODEL: Dict[str, Union[str, Tuple[str, Callable]]] = { diff --git a/curated_transformers/models/gpt_neox/_hf.py b/curated_transformers/models/gpt_neox/_hf.py index 79377c62..0dc34167 100644 --- a/curated_transformers/models/gpt_neox/_hf.py +++ b/curated_transformers/models/gpt_neox/_hf.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ...util.string import StringRemovePrefix, StringSubInvertible, StringTransform +from ...util.string import StringTransform, StringTransformations from ..hf_hub.conversion import process_hf_keys from .config import GPTNeoXConfig @@ -11,26 +11,34 @@ # Order-dependent. COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ - StringSubInvertible((r"gpt_neox", "decoder")), + StringTransformations.regex_sub_invertible((r"gpt_neox", "decoder")), # Attention blocks. 
- StringSubInvertible((r".attention", ".mha")), - StringSubInvertible((r".mha.query_key_value", ".mha.input")), - StringSubInvertible((r".mha.dense", ".mha.output")), + StringTransformations.regex_sub_invertible((r".attention", ".mha")), + StringTransformations.regex_sub_invertible((r".mha.query_key_value", ".mha.input")), + StringTransformations.regex_sub_invertible((r".mha.dense", ".mha.output")), # Pointwise feedforward. - StringSubInvertible((r".mlp", ".ffn")), - StringSubInvertible((r".dense_h_to_4h", ".intermediate")), - StringSubInvertible((r".ffn.dense_4h_to_h", ".ffn.output")), + StringTransformations.regex_sub_invertible((r".mlp", ".ffn")), + StringTransformations.regex_sub_invertible((r".dense_h_to_4h", ".intermediate")), + StringTransformations.regex_sub_invertible((r".ffn.dense_4h_to_h", ".ffn.output")), # Layer norms. - StringSubInvertible((r".input_layernorm", ".attn_input_layer_norm")), - StringSubInvertible((r".post_attention_layernorm", ".ffn_input_layer_norm")), - StringSubInvertible((r"final_layer_norm.", "output_layer_norm.")), + StringTransformations.regex_sub_invertible( + (r".input_layernorm", ".attn_input_layer_norm") + ), + StringTransformations.regex_sub_invertible( + (r".post_attention_layernorm", ".ffn_input_layer_norm") + ), + StringTransformations.regex_sub_invertible( + (r"final_layer_norm.", "output_layer_norm.") + ), # Embeddings. 
- StringSubInvertible((r"embed_in.", "embeddings.piece_embeddings.")), - StringSubInvertible((r"embed_out.", "output_embeddings.")), + StringTransformations.regex_sub_invertible( + (r"embed_in.", "embeddings.piece_embeddings.") + ), + StringTransformations.regex_sub_invertible((r"embed_out.", "output_embeddings.")), ] DECODER_HF_PARAM_KEY_TRANSFORMS = [ - StringRemovePrefix("gpt_neox.", reversible=False) + StringTransformations.remove_prefix("gpt_neox.", reversible=False) ] + COMMON_HF_PARAM_KEY_TRANSFORMS CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = COMMON_HF_PARAM_KEY_TRANSFORMS diff --git a/curated_transformers/models/hf_hub/conversion.py b/curated_transformers/models/hf_hub/conversion.py index 5241353a..0b3258a2 100644 --- a/curated_transformers/models/hf_hub/conversion.py +++ b/curated_transformers/models/hf_hub/conversion.py @@ -2,7 +2,7 @@ from torch import Tensor -from ...util.string import StringTransform +from ...util.string import StringTransform, StringTransformations def process_hf_keys( diff --git a/curated_transformers/models/llama/_hf.py b/curated_transformers/models/llama/_hf.py index 1dcb0e52..ebe7cd31 100644 --- a/curated_transformers/models/llama/_hf.py +++ b/curated_transformers/models/llama/_hf.py @@ -1,12 +1,7 @@ from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ...util.string import ( - StringRemovePrefix, - StringSubInvertible, - StringSubRegEx, - StringTransform, -) +from ...util.string import StringTransform, StringTransformations from ..hf_hub.conversion import process_hf_keys from .config import LlamaConfig @@ -17,33 +12,39 @@ # Order-dependent. COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ # Attention blocks. 
- StringSubInvertible((r".self_attn", ".mha")), - StringSubInvertible((r".q_proj", ".query")), - StringSubInvertible((r".k_proj", ".key")), - StringSubInvertible((r".v_proj", ".value")), - StringSubInvertible((r".o_proj", ".output")), + StringTransformations.regex_sub_invertible((r".self_attn", ".mha")), + StringTransformations.regex_sub_invertible((r".q_proj", ".query")), + StringTransformations.regex_sub_invertible((r".k_proj", ".key")), + StringTransformations.regex_sub_invertible((r".v_proj", ".value")), + StringTransformations.regex_sub_invertible((r".o_proj", ".output")), # Pointwise feedforward - StringSubInvertible((r".mlp", ".ffn")), - StringSubInvertible((r".up_proj", ".intermediate")), - StringSubInvertible((r"ffn.down_proj", "ffn.output")), - StringSubInvertible((r".gate_proj", ".gate")), + StringTransformations.regex_sub_invertible((r".mlp", ".ffn")), + StringTransformations.regex_sub_invertible((r".up_proj", ".intermediate")), + StringTransformations.regex_sub_invertible((r"ffn.down_proj", "ffn.output")), + StringTransformations.regex_sub_invertible((r".gate_proj", ".gate")), # RMS norms - StringSubInvertible((r".input_layernorm", ".attn_input_layer_norm")), - StringSubInvertible((r".post_attention_layernorm", ".ffn_input_layer_norm")), - StringSubRegEx( + StringTransformations.regex_sub_invertible( + (r".input_layernorm", ".attn_input_layer_norm") + ), + StringTransformations.regex_sub_invertible( + (r".post_attention_layernorm", ".ffn_input_layer_norm") + ), + StringTransformations.regex_sub( (r"^(decoder\.)?norm\.", "\\1output_layer_norm."), (r"^(decoder\.)?output_layer_norm\.", "\\1norm."), ), # Embeddings - StringSubInvertible((r"embed_tokens.", "embeddings.piece_embeddings.")), - StringSubInvertible((r"lm_head.", "output_embeddings.")), + StringTransformations.regex_sub_invertible( + (r"embed_tokens.", "embeddings.piece_embeddings.") + ), + StringTransformations.regex_sub_invertible((r"lm_head.", "output_embeddings.")), ] 
DECODER_HF_PARAM_KEY_TRANSFORMS = [ - StringRemovePrefix("model.", reversible=False) + StringTransformations.remove_prefix("model.", reversible=False) ] + COMMON_HF_PARAM_KEY_TRANSFORMS CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = [ - StringSubInvertible((r"model.", "decoder.")) + StringTransformations.regex_sub_invertible((r"model.", "decoder.")) ] + COMMON_HF_PARAM_KEY_TRANSFORMS HF_CONFIG_KEY_MAPPING: Dict[str, Union[str, Tuple[str, Callable]]] = { diff --git a/curated_transformers/models/mpt/_hf.py b/curated_transformers/models/mpt/_hf.py index 28966cd8..2928dc65 100644 --- a/curated_transformers/models/mpt/_hf.py +++ b/curated_transformers/models/mpt/_hf.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, List, Tuple, Union -from ...util.string import StringRemovePrefix, StringSubInvertible, StringTransform +from ...util.string import StringTransform, StringTransformations from ..hf_hub.conversion import process_hf_keys from .config import MPTConfig @@ -10,26 +10,28 @@ # Order-dependent. COMMON_HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ - StringSubInvertible((r"transformer", "decoder")), - StringSubInvertible((r"blocks", "layers")), + StringTransformations.regex_sub_invertible((r"transformer", "decoder")), + StringTransformations.regex_sub_invertible((r"blocks", "layers")), # Attention blocks. - StringSubInvertible((r".attn", ".mha")), - StringSubInvertible((r".Wqkv", ".input")), - StringSubInvertible((r".out_proj", ".output")), + StringTransformations.regex_sub_invertible((r".attn", ".mha")), + StringTransformations.regex_sub_invertible((r".Wqkv", ".input")), + StringTransformations.regex_sub_invertible((r".out_proj", ".output")), # Pointwise feedforward. - StringSubInvertible((r".up_proj", ".intermediate")), - StringSubInvertible((r"ffn.down_proj", "ffn.output")), + StringTransformations.regex_sub_invertible((r".up_proj", ".intermediate")), + StringTransformations.regex_sub_invertible((r"ffn.down_proj", "ffn.output")), # Layer norms. 
- StringSubInvertible((r".norm_1", ".attn_input_layer_norm")), - StringSubInvertible((r".norm_2", ".ffn_input_layer_norm")), - StringSubInvertible((r"norm_f.", "output_layer_norm.")), + StringTransformations.regex_sub_invertible((r".norm_1", ".attn_input_layer_norm")), + StringTransformations.regex_sub_invertible((r".norm_2", ".ffn_input_layer_norm")), + StringTransformations.regex_sub_invertible((r"norm_f.", "output_layer_norm.")), # Embeddings. - StringSubInvertible((r"wte.", "embeddings.piece_embeddings.")), + StringTransformations.regex_sub_invertible( + (r"wte.", "embeddings.piece_embeddings.") + ), ] DECODER_HF_PARAM_KEY_TRANSFORMS = [ - StringRemovePrefix("transformer.", reversible=False) + StringTransformations.remove_prefix("transformer.", reversible=False) ] + COMMON_HF_PARAM_KEY_TRANSFORMS CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS = COMMON_HF_PARAM_KEY_TRANSFORMS diff --git a/curated_transformers/models/roberta/_hf.py b/curated_transformers/models/roberta/_hf.py index ab578df5..ea1b08a2 100644 --- a/curated_transformers/models/roberta/_hf.py +++ b/curated_transformers/models/roberta/_hf.py @@ -1,57 +1,57 @@ from typing import Any, Callable, Dict, List, Tuple, Union from ...layers.activations import Activation -from ...util.string import ( - StringRemovePrefix, - StringReplace, - StringSubInvertible, - StringSubRegEx, - StringTransform, -) +from ...util.string import StringTransform, StringTransformations from ..hf_hub.conversion import process_hf_keys from .config import RoBERTaConfig # Order-dependent. HF_PARAM_KEY_TRANSFORMS: List[StringTransform] = [ # Prefixes. - StringRemovePrefix("roberta.", reversible=False), - StringSubRegEx( + StringTransformations.remove_prefix("roberta.", reversible=False), + StringTransformations.regex_sub( (r"^encoder\.(layer\.)", "\\1"), (r"^(layer\.)", "encoder.\\1"), ), # Layers. 
- StringSubRegEx((r"^layer", "layers"), (r"^layers", "layer")), + StringTransformations.regex_sub((r"^layer", "layers"), (r"^layers", "layer")), # Attention blocks. - StringSubRegEx( + StringTransformations.regex_sub( (r"\.attention\.self\.(query|key|value)", ".mha.\\1"), (r"\.mha\.(query|key|value)", ".attention.self.\\1"), ), - StringSubInvertible((r".attention.output.dense", ".mha.output")), - StringSubInvertible((r".attention.output.LayerNorm", ".attn_residual_layer_norm")), + StringTransformations.regex_sub_invertible( + (r".attention.output.dense", ".mha.output") + ), + StringTransformations.regex_sub_invertible( + (r".attention.output.LayerNorm", ".attn_residual_layer_norm") + ), # Pointwise feed-forward layers. - StringSubInvertible((r".intermediate.dense", ".ffn.intermediate")), - StringSubRegEx( + StringTransformations.regex_sub_invertible( + (r".intermediate.dense", ".ffn.intermediate") + ), + StringTransformations.regex_sub( (r"(\.\d+)\.output\.LayerNorm", "\\1.ffn_residual_layer_norm"), (r"(\.\d+)\.ffn_residual_layer_norm", "\\1.output.LayerNorm"), ), - StringSubRegEx( + StringTransformations.regex_sub( (r"(\.\d+)\.output\.dense", "\\1.ffn.output"), (r"(\.\d+)\.ffn\.output", "\\1.output.dense"), ), # Embeddings. 
- StringReplace( + StringTransformations.replace( "embeddings.word_embeddings.weight", "embeddings.piece_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.token_type_embeddings.weight", "embeddings.type_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.LayerNorm.weight", "embeddings.embed_output_layer_norm.weight" ), - StringReplace( + StringTransformations.replace( "embeddings.LayerNorm.bias", "embeddings.embed_output_layer_norm.bias" ), ] diff --git a/curated_transformers/util/string.py b/curated_transformers/util/string.py index 9c16acf2..34df4626 100644 --- a/curated_transformers/util/string.py +++ b/curated_transformers/util/string.py @@ -46,6 +46,79 @@ def revert(self, string: str) -> str: return string +class StringTransformations: + """ + Provides factory methods for different string transformations. + """ + + @staticmethod + def regex_sub( + forward: Tuple[str, str], backward: Optional[Tuple[str, str]] + ) -> StringTransform: + """ + Factory method to construct a string substitution transform + using regular expressions. + + :param forward: + Tuple where the first string is a RegEx pattern + and the second the replacement. + + This operation is performed when the :meth:`.apply` + method is invoked. + :param backward: + Optional tuple where the first string is a RegEx pattern + and the second the replacement. + + This operation is performed when the :meth:`.revert` + method is invoked. If ``None``, it is a no-op. + """ + return StringSubRegEx(forward, backward) + + @staticmethod + def regex_sub_invertible(forward: Tuple[str, str]) -> StringTransform: + """ + Factory method to construct a string substitution transform + using regular expressions whose backward transformation can + be automatically derived from the forward transformation. 
+ + :param forward: + Tuple where the first string is the string to match + and the second the replacement, neither of which + can contain RegEx meta-characters. + """ + return StringSubInvertible(forward) + + @staticmethod + def replace( + replacee: str, replacement: str, *, reversible: bool = True + ) -> StringTransform: + """ + Factory method to construct a string replacement transform. + + :param replacee: + The full string to be replaced. + :param replacement: + The replacement string. + :param reversible: + If the reverse transformation is to + be performed. + """ + return StringReplace(replacee, replacement, reversible=reversible) + + @staticmethod + def remove_prefix(prefix: str, *, reversible: bool = True) -> StringTransform: + """ + Factory method to construct a string prefix removal transform. + + :param prefix: + Prefix to be removed. + :param reversible: + If the reverse transformation is to + be performed. + """ + return StringRemovePrefix(prefix, reversible=reversible) + + class StringSubRegEx(StringTransform): """ Substitute a substring with another string using @@ -117,6 +190,9 @@ def __init__(self, replacee: str, replacement: str, *, reversible: bool = True): The full string to be replaced. :param replacement: The replacement string. + :param reversible: + If the reverse transformation is to + be performed. """ super().__init__(reversible) self.replacee = replacee @@ -145,7 +221,10 @@ def __init__(self, prefix: str, *, reversible: bool = True): Construct a reversible left strip. :param prefix: - Prefix to be stripped. + Prefix to be removed. + :param reversible: + If the reverse transformation is to + be performed. """ super().__init__(reversible)