Skip to content

Commit

Permalink
Llama support (#32)
Browse files Browse the repository at this point in the history
* chore(build): remove ruff warning when running `make style_check`

* refactor(modeling): import specific modeling only when required

* chore: import Llama modeling from transformers, to allow loading model

imported from transformers v4.40.1.

* feat: internal LLamaModelforCausalLM selected if possible

* feat(llama): sharding on o_proj

* feat(llama): sharding on q,k,v

* feat(llama): sharding on MLP Linears

* feat(llama): sharding on lm_head

* test: add slow test to verify sharded Llama3-8b can be loaded on TPU

* fix(generation): eos_token_id can be a list in configs

This essentially copies commit 8a4a98d2472b8e0180eb9bd4a1824f983e220811
from optimum-neuron, that fixed the same problem.

* test(tgi): added test to validate Llama3 8b on TGI

* doc(README): include Llama in list of TGI supported models

* refactor(generator): move var initialization out of a loop
  • Loading branch information
tengomucho committed May 3, 2024
1 parent c9937a9 commit 973655d
Show file tree
Hide file tree
Showing 8 changed files with 1,811 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ tpu-tgi-ie:

# Run code quality checks
style_check:
ruff .
ruff check .

style:
ruff check . --fix
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ working closely with Google and Google Cloud to make this a reality.

We currently support a few LLM models targeting text generation scenarios:
- Gemma (2b)
- Llama (soon)
- Llama (8b)
- Mistral (soon)


Expand Down
19 changes: 10 additions & 9 deletions optimum/tpu/generation/token_selector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
import logging
from typing import Optional
from typing import List, Optional, Union

import torch
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -42,15 +42,15 @@ def __init__(
mode: GenerationMode,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
eos_token_id: int,
eos_token_ids: Union[int,List[int]],
pad_token_id: int,
logits_warper: Optional[LogitsProcessorList] = None,
seed: Optional[int] = 0,
):
self.mode = mode
self.logits_processor = logits_processor
self.stopping_criteria = stopping_criteria
self.eos_token_id = eos_token_id
self.eos_token_ids = eos_token_ids
self.pad_token_id = pad_token_id
self.logits_warper = logits_warper
xm.set_rng_state(seed)
Expand Down Expand Up @@ -132,13 +132,14 @@ def create(
stopping_criteria = StoppingCriteriaList()
stopping_criteria = model._get_stopping_criteria(generation_config, stopping_criteria=stopping_criteria)

# The generation requires special tokens
eos_token_id = generation_config.eos_token_id
# This is not supposed to happen for any of the models we support
assert eos_token_id is not None and not isinstance(eos_token_id, list)
eos_token_id = generation_config.eos_token_id
assert eos_token_id is not None
# The generation requires special tokens
eos_token_ids = eos_token_id if isinstance(eos_token_id, list) else [eos_token_id]
if generation_config.pad_token_id is None:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_ids[0]} for open-end generation.")
generation_config.pad_token_id = eos_token_ids[0]

generation_mode = generation_config.get_generation_mode()
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
Expand All @@ -153,7 +154,7 @@ def create(
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
logits_warper=logits_warper,
eos_token_id=eos_token_id,
eos_token_ids=eos_token_ids,
pad_token_id=generation_config.pad_token_id,
seed=seed,
)
Expand Down
8 changes: 6 additions & 2 deletions optimum/tpu/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM

from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM


def config_name_to_class(pretrained_model_name_or_path: str):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
if config.model_type == "gemma":
from .modeling_gemma import TpuGemmaForCausalLM

return TpuGemmaForCausalLM
if config.model_type == "llama":
from .modeling_llama import LlamaForCausalLM

return LlamaForCausalLM
return BaseAutoModelForCausalLM


Expand Down
Loading

0 comments on commit 973655d

Please sign in to comment.