Skip to content

Commit

Permalink
Support Gemma 2 loading from Hugging Face
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 719670366
talumbau authored and copybara-github committed Jan 25, 2025
1 parent 5ad0128 commit 50f279c
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions ai_edge_torch/generative/examples/gemma/gemma2.py
Original file line number Diff line number Diff line change
@@ -43,6 +43,22 @@
lm_head=None,
)

# Alternative checkpoint tensor-name layout, used as a fallback when loading
# with the primary naming scheme fails. Each "{}" placeholder sits where the
# per-layer index goes. Unlike the primary mapping, every tensor lives under a
# "model." prefix.
# NOTE(review): presumably this matches the Hugging Face Gemma 2 checkpoint
# layout (the change that introduced it is titled "Support Gemma 2 loading
# from Hugging Face") -- confirm against an actual checkpoint.
ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    # Model-level tensors.
    embedding="model.embed_tokens",
    final_norm="model.norm",
    # Per-layer attention projections.
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    # Per-layer feed-forward projections.
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    # Per-layer normalization, before and after attention and feed-forward.
    pre_attn_norm="model.layers.{}.input_layernorm",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
)


class Gemma2Block(attention.TransformerBlock):

@@ -281,9 +297,18 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:


def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds the Gemma 2 2B model and loads its weights from a checkpoint.

  Loading is first attempted with the primary `TENSOR_NAMES` mapping. If that
  attempt raises `KeyError`, loading is retried once with `ALT_TENSOR_NAMES`.
  NOTE(review): the `KeyError` presumably signals that the checkpoint uses a
  different tensor naming scheme -- confirm against the loader.

  Args:
    checkpoint_path: Path to the checkpoint to load weights from.
    **kwargs: Forwarded unchanged to `get_model_config_2b`.

  Returns:
    The model returned by `model_builder.build_decoder_only_model`.

  Raises:
    KeyError: If the retry with `ALT_TENSOR_NAMES` also raises `KeyError`;
      the error from the second attempt propagates to the caller.
  """
  try:
    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=get_model_config_2b(**kwargs),
        tensor_names=TENSOR_NAMES,
        model_class=Gemma2,
    )
  except KeyError:
    # Also attempt to load with an alternative naming scheme. The config is
    # rebuilt so the retry does not reuse an object from the failed attempt.
    return model_builder.build_decoder_only_model(
        checkpoint_path=checkpoint_path,
        config=get_model_config_2b(**kwargs),
        tensor_names=ALT_TENSOR_NAMES,
        model_class=Gemma2,
    )

0 comments on commit 50f279c

Please sign in to comment.