From ce2e3bcdcdcf199e0e1775fbbab09aa1d42c9664 Mon Sep 17 00:00:00 2001 From: Omkar Kabde Date: Tue, 4 Mar 2025 10:29:52 +0530 Subject: [PATCH 1/4] Added RoBERTa converter --- .../src/utils/transformers/convert_roberta.py | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 keras_hub/src/utils/transformers/convert_roberta.py diff --git a/keras_hub/src/utils/transformers/convert_roberta.py b/keras_hub/src/utils/transformers/convert_roberta.py new file mode 100644 index 0000000000..1946b7d2ea --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_roberta.py @@ -0,0 +1,136 @@ +import numpy as np + +from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone +from keras_hub.src.utils.preset_utils import HF_TOKENIZER_CONFIG_FILE +from keras_hub.src.utils.preset_utils import get_file +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = RobertaBackbone + + +def convert_backbone_config(transformers_config): + return { + "vocabulary_size": transformers_config["vocab_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_heads": transformers_config["num_attention_heads"], + "hidden_dim": transformers_config["hidden_size"], + "intermediate_dim": transformers_config["intermediate_size"], + } + + +def convert_weights(backbone, loader, transformers_config): + # Embedding layer + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="roberta.embeddings.word_embeddings.weight", + ) + loader.port_weight( + keras_variable=backbone.get_layer("position_embedding").position_embeddings, + hf_weight_key="roberta.embeddings.position_embeddings.weight", + ) + # Roberta does not use segment embeddings + loader.port_weight( + keras_variable=backbone.get_layer("embeddings_layer_norm").beta, + hf_weight_key="roberta.embeddings.LayerNorm.beta", + ) + loader.port_weight( + keras_variable=backbone.get_layer("embeddings_layer_norm").gamma, + hf_weight_key="roberta.embeddings.LayerNorm.gamma", + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + # Attention blocks + for i in range(backbone.num_layers): + block = backbone.get_layer(f"transformer_layer_{i}") + attn = block._self_attention_layer + hf_prefix = "roberta.encoder.layer." 
+ # Attention layers + loader.port_weight( + keras_variable=attn.query_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.self.query.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.query_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.self.query.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + loader.port_weight( + keras_variable=attn.key_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.self.key.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.key_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.self.key.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + loader.port_weight( + keras_variable=attn.value_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.self.value.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.value_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.self.value.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + loader.port_weight( + keras_variable=attn.output_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.attention.output.dense.weight", + hook_fn=transpose_and_reshape, + ) + loader.port_weight( + keras_variable=attn.output_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.attention.output.dense.bias", + hook_fn=lambda hf_tensor, shape: np.reshape(hf_tensor, shape), + ) + # Attention layer norm. + loader.port_weight( + keras_variable=block._self_attention_layer_norm.beta, + hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.beta", + ) + loader.port_weight( + keras_variable=block._self_attention_layer_norm.gamma, + hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.gamma", + ) + # MLP layers + loader.port_weight( + keras_variable=block._feedforward_intermediate_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.intermediate.dense.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=block._feedforward_intermediate_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.intermediate.dense.bias", + ) + loader.port_weight( + keras_variable=block._feedforward_output_dense.kernel, + hf_weight_key=f"{hf_prefix}{i}.output.dense.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=block._feedforward_output_dense.bias, + hf_weight_key=f"{hf_prefix}{i}.output.dense.bias", + ) + # Output layer norm. 
+        loader.port_weight(
+            keras_variable=block._feedforward_layer_norm.beta,
+            hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.beta",
+        )
+        loader.port_weight(
+            keras_variable=block._feedforward_layer_norm.gamma,
+            hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.gamma",
+        )
+    # Roberta does not use a pooler layer
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    transformers_config = load_json(preset, HF_TOKENIZER_CONFIG_FILE)
+    return cls(
+        get_file(preset, "vocab.txt"),
+        lowercase=transformers_config["do_lower_case"],
+        **kwargs,
+    )

From 627ad1bfcc6255d116626d83e9b6c07d3d1fb013 Mon Sep 17 00:00:00 2001
From: Omkar Kabde
Date: Tue, 4 Mar 2025 10:30:21 +0530
Subject: [PATCH 2/4] Add RoBERTa converter tests

---
 .../transformers/convert_roberta_test.py      | 31 +++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 keras_hub/src/utils/transformers/convert_roberta_test.py

diff --git a/keras_hub/src/utils/transformers/convert_roberta_test.py b/keras_hub/src/utils/transformers/convert_roberta_test.py
new file mode 100644
index 0000000000..03653765dc
--- /dev/null
+++ b/keras_hub/src/utils/transformers/convert_roberta_test.py
@@ -0,0 +1,31 @@
+import pytest
+
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
+from keras_hub.src.models.roberta.roberta_text_classifier import RobertaTextClassifier
+from keras_hub.src.models.text_classifier import TextClassifier
+from keras_hub.src.tests.test_case import TestCase
+
+
+class TestTask(TestCase):
+    @pytest.mark.large
+    def test_convert_tiny_preset(self):
+        model = RobertaTextClassifier.from_preset("hf://FacebookAI/roberta-base", num_classes=2)
+        prompt = "That movie was terrible."
+        model.predict([prompt])
+
+    @pytest.mark.large
+    def test_class_detection(self):
+        model = TextClassifier.from_preset(
+            "hf://FacebookAI/roberta-base",
+            num_classes=2,
+            load_weights=False,
+        )
+        self.assertIsInstance(model, RobertaTextClassifier)
+        model = Backbone.from_preset(
+            "hf://FacebookAI/roberta-base",
+            load_weights=False,
+        )
+        self.assertIsInstance(model, RobertaBackbone)
+
+    # TODO: compare numerics with huggingface model

From 2035ba3b129435e0e7a114ec6c08352f7786a5f2 Mon Sep 17 00:00:00 2001
From: Omkar Kabde
Date: Tue, 4 Mar 2025 10:34:49 +0530
Subject: [PATCH 3/4] Added RoBERTa to preset_loader

---
 keras_hub/src/utils/transformers/preset_loader.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py
index a3c46f4cf8..65143c86e9 100644
--- a/keras_hub/src/utils/transformers/preset_loader.py
+++ b/keras_hub/src/utils/transformers/preset_loader.py
@@ -12,6 +12,7 @@
 from keras_hub.src.utils.transformers import convert_llama3
 from keras_hub.src.utils.transformers import convert_mistral
 from keras_hub.src.utils.transformers import convert_pali_gemma
+from keras_hub.src.utils.transformers import convert_roberta
 from keras_hub.src.utils.transformers import convert_vit
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader

@@ -39,6 +40,8 @@ def __init__(self, preset, config):
             self.converter = convert_mistral
         elif model_type == "paligemma":
             self.converter = convert_pali_gemma
+        elif model_type == "roberta":
+            self.converter = convert_roberta
         elif model_type == "vit":
             self.converter = convert_vit
         else:

From 269fbd7590812b20ae0af19ea7760af1bf0a2d51 Mon Sep 17 00:00:00 2001
From: Omkar Kabde
Date: Tue, 4 Mar 2025 11:20:45 +0530
Subject: [PATCH 4/4] Fix key names
---
 .../src/utils/transformers/convert_roberta.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/keras_hub/src/utils/transformers/convert_roberta.py b/keras_hub/src/utils/transformers/convert_roberta.py
index 1946b7d2ea..17d015eba2 100644
--- a/keras_hub/src/utils/transformers/convert_roberta.py
+++ b/keras_hub/src/utils/transformers/convert_roberta.py
@@ -21,21 +21,23 @@ def convert_backbone_config(transformers_config):
 def convert_weights(backbone, loader, transformers_config):
     # Embedding layer
     loader.port_weight(
-        keras_variable=backbone.get_layer("token_embedding").embeddings,
+        keras_variable=backbone.get_layer("embeddings").token_embedding.embeddings,
         hf_weight_key="roberta.embeddings.word_embeddings.weight",
     )
     loader.port_weight(
-        keras_variable=backbone.get_layer("position_embedding").position_embeddings,
+        keras_variable=backbone.get_layer("embeddings").position_embedding.position_embeddings,
         hf_weight_key="roberta.embeddings.position_embeddings.weight",
+        hook_fn=lambda hf_tensor, _: hf_tensor[2:],  # Skip the 2 padding-offset rows
     )
+
     # Roberta does not use segment embeddings
     loader.port_weight(
         keras_variable=backbone.get_layer("embeddings_layer_norm").beta,
-        hf_weight_key="roberta.embeddings.LayerNorm.beta",
+        hf_weight_key="roberta.embeddings.LayerNorm.bias",
     )
     loader.port_weight(
         keras_variable=backbone.get_layer("embeddings_layer_norm").gamma,
-        hf_weight_key="roberta.embeddings.LayerNorm.gamma",
+        hf_weight_key="roberta.embeddings.LayerNorm.weight",
     )

     def transpose_and_reshape(x, shape):
@@ -90,11 +92,11 @@ def transpose_and_reshape(x, shape):
         # Attention layer norm.
         loader.port_weight(
             keras_variable=block._self_attention_layer_norm.beta,
-            hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.beta",
+            hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.bias",
         )
         loader.port_weight(
             keras_variable=block._self_attention_layer_norm.gamma,
-            hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.gamma",
+            hf_weight_key=f"{hf_prefix}{i}.attention.output.LayerNorm.weight",
         )
         # MLP layers
         loader.port_weight(
@@ -118,11 +120,11 @@ def transpose_and_reshape(x, shape):
         # Output layer norm.
         loader.port_weight(
             keras_variable=block._feedforward_layer_norm.beta,
-            hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.beta",
+            hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.bias",
         )
         loader.port_weight(
             keras_variable=block._feedforward_layer_norm.gamma,
-            hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.gamma",
+            hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.weight",
         )
     # Roberta does not use a pooler layer
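
A note on the renames in PATCH 4/4: PyTorch checkpoints expose LayerNorm parameters as ".weight" and ".bias", while the Keras LayerNormalization variables are named gamma and beta, which is why only the hf_weight_key side changes and the keras_variable side keeps .gamma/.beta. The exact checkpoint names can be confirmed with a quick check like the one below (a sketch; it assumes the transformers package and access to the FacebookAI/roberta-base repo):

from transformers import RobertaForMaskedLM

# Loading the MLM task model reproduces the "roberta."-prefixed parameter
# names that the converter's hf_weight_key values use.
model = RobertaForMaskedLM.from_pretrained("FacebookAI/roberta-base")
for name in model.state_dict():
    if name.startswith("roberta.embeddings.LayerNorm"):
        print(name)  # roberta.embeddings.LayerNorm.weight / .bias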
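
The transpose_and_reshape hook in convert_roberta.py exists because Hugging Face stores each attention projection as a fused (out_features, in_features) matrix, while the Keras attention layers keep a per-head kernel. A minimal numpy sketch of that mapping, assuming roberta-base dimensions (hidden size 768, 12 heads, head size 64):

import numpy as np

hidden_dim, num_heads, head_dim = 768, 12, 64

# A HF query/key/value projection is stored as (out_features, in_features).
hf_kernel = np.random.rand(hidden_dim, hidden_dim).astype("float32")

# The Keras per-head kernel is (hidden_dim, num_heads, head_dim): transpose
# to (in, out), then split the output axis per head. This is exactly what
# transpose_and_reshape does with the shape of the target Keras variable.
keras_shape = (hidden_dim, num_heads, head_dim)
keras_kernel = np.reshape(np.transpose(hf_kernel), keras_shape)
assert keras_kernel.shape == keras_shape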
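
One caveat on convert_tokenizer in PATCH 1/4: RoBERTa checkpoints on the Hub ship a byte-level BPE tokenizer (vocab.json plus merges.txt) rather than a WordPiece vocab.txt, and do_lower_case is a BERT-style option that a RoBERTa tokenizer_config.json is unlikely to carry. A hedged sketch of a BPE-based conversion, assuming keras_hub's RobertaTokenizer accepts vocabulary and merges file paths:

from keras_hub.src.utils.preset_utils import get_file


def convert_tokenizer(cls, preset, **kwargs):
    # Rebuild the byte-pair-encoding tokenizer from the vocab.json /
    # merges.txt pair in the Hugging Face repo (sketch; assumes the
    # RobertaTokenizer constructor takes vocabulary and merges paths).
    return cls(
        vocabulary=get_file(preset, "vocab.json"),
        merges=get_file(preset, "merges.txt"),
        **kwargs,
    )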
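
The test file leaves a TODO to compare numerics with the Hugging Face model. A sketch of what that check could look like, feeding identical token ids to both models so tokenizer differences stay out of the comparison; the tolerance and the exact calls below are assumptions for illustration, not part of the patches:

import keras
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

import keras_hub

# Reference hidden states from the Hugging Face implementation.
hf_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
hf_model = AutoModel.from_pretrained("FacebookAI/roberta-base")
hf_inputs = hf_tokenizer(["That movie was terrible."], return_tensors="pt")
with torch.no_grad():
    hf_hidden = hf_model(**hf_inputs).last_hidden_state.numpy()

# Hidden states from the converted keras_hub backbone on the same token ids.
backbone = keras_hub.models.RobertaBackbone.from_preset(
    "hf://FacebookAI/roberta-base"
)
keras_hidden = backbone(
    {
        "token_ids": hf_inputs["input_ids"].numpy(),
        "padding_mask": hf_inputs["attention_mask"].numpy(),
    }
)
# Tolerance is illustrative; float32 conversions usually land well inside it.
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(keras_hidden), hf_hidden, atol=1e-3
)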