Skip to content

Commit

Permalink
chore: config_name_to_class uses config.model_type now
Browse files Browse the repository at this point in the history
  • Loading branch information
tengomucho committed Apr 10, 2024
1 parent 2215595 commit fe888a9
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions optimum/tpu/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@
from typing import Any

from loguru import logger
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM, AutoConfig

from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM


def config_name_to_class(pretrained_model_name_or_path: str):
if "gemma" in pretrained_model_name_or_path:
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
if config.model_type == "gemma":
return TpuGemmaForCausalLM
return BaseAutoModelForCausalLM


class AutoModelForCausalLM(BaseAutoModelForCausalLM):

@classmethod
Expand Down

0 comments on commit fe888a9

Please sign in to comment.