Llama support #32

Merged 13 commits on May 3, 2024
2 changes: 1 addition & 1 deletion Makefile
@@ -51,7 +51,7 @@ tpu-tgi:

 # Run code quality checks
 style_check:
-	ruff .
+	ruff check .
 
 style:
 	ruff check . --fix
2 changes: 1 addition & 1 deletion README.md
@@ -23,7 +23,7 @@ working closely with Google and Google Cloud to make this a reality.

 We currently support a few LLM models targeting text generation scenarios:
 - Gemma (2b)
-- Llama (soon)
+- Llama (8b)
 - Mistral (soon)


19 changes: 10 additions & 9 deletions optimum/tpu/generation/token_selector.py
@@ -1,6 +1,6 @@
 import copy
 import logging
-from typing import Optional
+from typing import List, Optional, Union
 
 import torch
 import torch_xla.core.xla_model as xm
@@ -42,15 +42,15 @@ def __init__(
         mode: GenerationMode,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
-        eos_token_id: int,
+        eos_token_ids: Union[int,List[int]],
         pad_token_id: int,
         logits_warper: Optional[LogitsProcessorList] = None,
         seed: Optional[int] = 0,
     ):
         self.mode = mode
         self.logits_processor = logits_processor
         self.stopping_criteria = stopping_criteria
-        self.eos_token_id = eos_token_id
+        self.eos_token_ids = eos_token_ids
         self.pad_token_id = pad_token_id
         self.logits_warper = logits_warper
         xm.set_rng_state(seed)
@@ -132,13 +132,14 @@ def create(
         stopping_criteria = StoppingCriteriaList()
         stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=stopping_criteria)
 
-        # The generation requires special tokens
-        eos_token_id = generation_config.eos_token_id
-        # This is not supposed to happen for any of the models we support
-        assert eos_token_id is not None and not isinstance(eos_token_id, list)
+        eos_token_id = generation_config.eos_token_id
+        assert eos_token_id is not None
+        # The generation requires special tokens
+        eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
         if generation_config.pad_token_id is None:
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
-            generation_config.pad_token_id = eos_token_id
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-end generation.")
+            generation_config.pad_token_id = eos_token_ids[0]
 
         generation_mode = generation_config.get_generation_mode()
         if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
@@ -153,7 +154,7 @@ def create(
             logits_processor=logits_processor,
             stopping_criteria=stopping_criteria,
             logits_warper=logits_warper,
-            eos_token_id=eos_token_id,
+            eos_token_ids=eos_token_ids,
             pad_token_id=generation_config.pad_token_id,
             seed=seed,
         )
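The key change in token_selector.py is that the selector now carries a list of EOS token ids (eos_token_ids) instead of a single integer, which matters for checkpoints whose generation config lists several end-of-sequence tokens, as recent Llama instruct models do. Below is a minimal sketch, not code from this PR, of how a finished-sequence check can be written against such a list; the helper names and the example id values are made up for illustration.

from typing import List, Union

import torch


def as_eos_list(eos_token_id: Union[int, List[int]]) -> List[int]:
    # Normalize a single id or a list of ids into a list, mirroring the PR's
    # `eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]`.
    return eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]


def is_finished(next_tokens: torch.Tensor, eos_token_ids: List[int]) -> torch.Tensor:
    # Boolean mask over the batch: True where the newly selected token is any EOS id.
    eos = torch.tensor(eos_token_ids, device=next_tokens.device)
    return torch.isin(next_tokens, eos)


# Example with made-up ids: tokens 7 and 9 both terminate generation.
next_tokens = torch.tensor([7, 3, 9])
print(is_finished(next_tokens, as_eos_list([7, 9])))  # tensor([ True, False,  True])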
8 changes: 6 additions & 2 deletions optimum/tpu/modeling.py
@@ -22,13 +22,17 @@
 from transformers import AutoConfig
 from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM
 
-from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM
-
 
 def config_name_to_class(pretrained_model_name_or_path: str):
     config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
     if config.model_type == "gemma":
+        from .modeling_gemma import TpuGemmaForCausalLM
+
         return TpuGemmaForCausalLM
+    if config.model_type == "llama":
+        from .modeling_llama import LlamaForCausalLM
+
+        return LlamaForCausalLM
     return BaseAutoModelForCausalLM


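config_name_to_class reads the checkpoint's AutoConfig and lazily imports the TPU implementation matching config.model_type ("gemma" or, with this PR, "llama"), falling back to transformers' AutoModelForCausalLM for anything else. A hedged usage sketch follows: the checkpoint id is only an example (and gated on the Hub), and it assumes the returned class exposes the usual transformers from_pretrained interface.

from optimum.tpu.modeling import config_name_to_class

# Example checkpoint id (an assumption for illustration); any config whose
# model_type is "llama" dispatches to the TPU Llama class added by this PR,
# while unsupported model types fall back to AutoModelForCausalLM.
model_id = "meta-llama/Meta-Llama-3-8B"
model_cls = config_name_to_class(model_id)
model = model_cls.from_pretrained(model_id)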