Parallel sharding #21

Merged
merged 22 commits into from Apr 10, 2024

Commits (22)
d75ba94  chore: update transformers dependency (tengomucho, Apr 5, 2024)
0ee7430  feat: import transformer's gemma modeling code (tengomucho, Apr 8, 2024)
ca88068  chore: rename model Gemma -> TpuGemma to prepare for changes (tengomucho, Apr 8, 2024)
a3de4d7  feat(DistributedModel): added config property (tengomucho, Apr 8, 2024)
80170a9  chore: rename test_parallel_proxy.py -> test_distributed_model.py (tengomucho, Apr 8, 2024)
9a9bcf8  fix: use AutoModelForCausalLM instead of TpuModelForCausalLM (tengomucho, Apr 8, 2024)
5bf6c70  feat: AutoModelForCausalLM will choose TpuGemmaForCausalLM if possible (tengomucho, Apr 8, 2024)
9dfb7b6  fix(TpuGemma): avoid using device_map when loading model (tengomucho, Apr 8, 2024)
ec3b752  feat(gemma): sharding o_proj (tengomucho, Apr 8, 2024)
a7d7c0b  feat(gemma): sharding on q_proj (tengomucho, Apr 8, 2024)
b6fe32e  feat(gemma): sharding on k and v proj (tengomucho, Apr 9, 2024)
e13d9ec  feat(gemma): sharding on mlp gate and up proj (tengomucho, Apr 9, 2024)
6cdede2  feat(gemma): sharding on mlp down proj (tengomucho, Apr 9, 2024)
cd99226  feat: model il loaded using pytorch_dtype from config (tengomucho, Apr 9, 2024)
550e1fb  fix: remove useless import (tengomucho, Apr 9, 2024)
2215595  feat(tests): added test showing gemma7b sharding and prefill works (tengomucho, Apr 9, 2024)
fe888a9  chore: config_name_to_class uses config.model_type now (tengomucho, Apr 10, 2024)
dbf11f7  fix: get_generation_mode is now a method of generation_config (tengomucho, Apr 10, 2024)
a96903b  fix(TGI server): fix slot.stopped changed after transformers update (tengomucho, Apr 10, 2024)
6e6b44e  fix(generator): fix sample generation again (tengomucho, Apr 10, 2024)
92e9e31  fix: better handle torch_dtype (tengomucho, Apr 10, 2024)
7901d91  fix: remove unused import (tengomucho, Apr 10, 2024)
4 changes: 2 additions & 2 deletions examples/text-generation/generation_gemma.py
@@ -7,7 +7,7 @@
import platform
from typing import List
import torch_xla.core.xla_model as xm
from optimum.tpu.modeling import TpuModelForCausalLM
from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import AutoTokenizer, StaticCache


@@ -56,7 +56,7 @@ def main():
model_id = "google/gemma-2b"
torch_dtype = torch.bfloat16

model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)

Member:
Do we need the torch_dtype=torch_dtype? It should be taken from the config, no?

Collaborator (author):
Well, it doesn't look like it works this way:

>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`. If you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu` instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████████████████| 2/2 [00:00<00:00,  2.65it/s]
>>> print(model.config.torch_dtype)
torch.bfloat16
>>> print(model.model.layers[0].self_attn.o_proj.weight.dtype)
torch.float32

device = model.device
model = model.eval()

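To make the dtype discussion in the thread above concrete, here is a minimal sketch, not taken from the PR, of the behaviour the reviewers describe: transformers loads checkpoint weights in the default PyTorch dtype (float32) unless torch_dtype is passed explicitly. The checkpoint name and the printed dtypes simply mirror the REPL session above.

```python
# Minimal sketch of the behaviour discussed in the review thread above.
# Assumes a plain transformers install; "google/gemma-2b" is just the example
# checkpoint from the thread, and the printed dtypes are the expected values.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("google/gemma-2b")
print(config.torch_dtype)  # torch.bfloat16, the dtype recorded in the config

# Without torch_dtype, weights are loaded in torch's default dtype (float32),
# regardless of what the config says.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
print(model.model.layers[0].self_attn.o_proj.weight.dtype)  # torch.float32

# Passing torch_dtype explicitly (config.torch_dtype or "auto") honours the config.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=config.torch_dtype)
print(model.model.layers[0].self_attn.o_proj.weight.dtype)  # torch.bfloat16
```

This is why the example keeps passing torch_dtype explicitly instead of relying on the config alone.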
25 changes: 21 additions & 4 deletions optimum/tpu/distributed_model.py
@@ -2,16 +2,17 @@
import torch
import os
from enum import Enum
from typing import Dict
from loguru import logger

os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch.multiprocessing as mp

from optimum.tpu.modeling import TpuModelForCausalLM
from typing import Dict
from loguru import logger
from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import PretrainedConfig, AutoConfig


class ModelCommand(Enum):
@@ -26,6 +27,14 @@ def __init__(self, manager: mp.Manager):
self.root_command = manager.list()
self.model_ready = manager.Event()
self.output_data = manager.Value(torch.Tensor, torch.tensor([]))
self.model_config = manager.Value(PretrainedConfig, None)

@property
def config(self):
while True:
config = self.model_config.get()
if config is not None:
return config

def send(self, command: ModelCommand, data: Dict = None):
# First wait until model is ready to receive commands
@@ -49,6 +58,7 @@ def __init__(self, root_mailbox: RootMailbox):
self.root_command = root_mailbox.root_command
self.model_ready = root_mailbox.model_ready
self.output_data = root_mailbox.output_data
self.model_config = root_mailbox.model_config

def receive(self):
self.root_bell.wait()
@@ -80,9 +90,12 @@ def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
)

# Model loading and sharding should happen here
model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=config.torch_dtype)

Member:
I have a hard time understanding why we need to do it this way. Aren't we just overriding the default behaviour with the default behaviour? @regisss do you know?

Contributor:
It seems the default is to load in fp32, whatever dtype is specified in the config: https://huggingface.slack.com/archives/C014N4749J9/p1712757959601599

Collaborator (author):
So I got some insights on the design here. It seems that transformers uses the default PyTorch dtype, i.e. torch.float32. I will probably need to change this code later, as it might not work for models whose weights were not trained in float32/bfloat16. I have already seen that we cannot use bf16 everywhere, because some operations cannot be performed in it (I saw this in a unit test with gpt2). It is probably a custom configuration we need to add to the model. I pushed a cleaner fix than this.

model = model.eval()
model.to(device)
if rank == 0:
mailbox.model_config.set(model.config)

def get_next_token(inputs):
# move inputs to device in a new dict to avoid conflicts
@@ -152,5 +165,9 @@ def leave(self):
logger.debug("Model loop finished")
self.mailbox = None

@property
def config(self):
return self.mailbox.config

def __del__(self):
self.leave()
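The config hand-off added to RootMailbox boils down to a simple pattern: the rank-0 worker publishes the loaded model's config through a manager-backed value, and the parent process polls that value until it is set. Below is a stripped-down sketch of the same idea using plain multiprocessing, with a dict standing in for the real PretrainedConfig; it is illustrative only and not the PR's code.

```python
# Simplified illustration of the RootMailbox.model_config hand-off:
# a worker publishes its "config" through a Manager Value and the parent polls it.
import time
import multiprocessing as mp


def worker(rank, shared_config):
    # Pretend to load a model; only rank 0 publishes the config, mirroring the
    # `if rank == 0: mailbox.model_config.set(model.config)` step in the diff.
    fake_config = {"model_type": "gemma", "torch_dtype": "bfloat16"}
    if rank == 0:
        shared_config.set(fake_config)


if __name__ == "__main__":
    manager = mp.Manager()
    # Same shape as RootMailbox.model_config: a proxy value that starts as None.
    shared_config = manager.Value(dict, None)

    p = mp.Process(target=worker, args=(0, shared_config))
    p.start()

    # Parent side: poll until the worker has published the config, like the
    # `config` property's loop (a small sleep is added here to avoid spinning hard).
    while shared_config.get() is None:
        time.sleep(0.01)
    print(shared_config.get())

    p.join()
```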
2 changes: 1 addition & 1 deletion optimum/tpu/generation/token_selector.py
@@ -138,7 +138,7 @@ def create(
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id

generation_mode = model._get_generation_mode(generation_config, None)
generation_mode = generation_config.get_generation_mode()
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
raise ValueError("Unsupported generation mode")

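For context on the one-line change above: the generation-mode lookup now lives on GenerationConfig itself rather than on a private model method. A small sketch of the new call follows; the exact transformers version and the GenerationMode import path are assumptions to adjust to your install.

```python
# Sketch of GenerationConfig.get_generation_mode(), the method the PR switches to.
# Assumes a recent transformers release; the GenerationMode import path may differ.
from transformers import GenerationConfig
from transformers.generation import GenerationMode

greedy = GenerationConfig(do_sample=False)
print(greedy.get_generation_mode())    # expected: GenerationMode.GREEDY_SEARCH

sampling = GenerationConfig(do_sample=True, top_k=50)
print(sampling.get_generation_mode())  # expected: GenerationMode.SAMPLE
```

TokenSelector still only accepts GREEDY_SEARCH and SAMPLE, so any other mode keeps raising ValueError.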
24 changes: 13 additions & 11 deletions optimum/tpu/modeling.py
@@ -19,12 +19,18 @@
from typing import Any

from loguru import logger
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM
from transformers.utils import is_accelerate_available
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM, AutoConfig

from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM


def config_name_to_class(pretrained_model_name_or_path: str):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
if config.model_type == "gemma":
return TpuGemmaForCausalLM
return BaseAutoModelForCausalLM


# TODO: For now TpuModelForCausalLM is just a shallow wrapper of
# AutoModelForCausalLM, later this could be replaced by a custom class.
class AutoModelForCausalLM(BaseAutoModelForCausalLM):

@classmethod
@@ -45,13 +51,9 @@ def from_pretrained(
logger.debug(f"Device set to: {device}")
else:
device = "xla"
if is_accelerate_available():
model = BaseAutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
)
else:
model = BaseAutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.to(device)
cls = config_name_to_class(pretrained_model_name_or_path)
model = cls.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.to(device)
# Update config with specific data)
if task is not None or getattr(model.config, "task", None) is None:
model.config.task = task
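As a usage illustration of the dispatch above, loading a checkpoint whose config.model_type is "gemma" through the wrapper is expected to yield the TPU-specific class, while any other model type falls back to the stock transformers auto class. This is a hypothetical session that assumes a working TPU/XLA environment and the Gemma checkpoint used elsewhere in the PR.

```python
# Hypothetical usage of optimum.tpu's AutoModelForCausalLM wrapper after this PR.
# Assumes a TPU/XLA environment, since from_pretrained moves the model to "xla".
from optimum.tpu.modeling import AutoModelForCausalLM

# config.model_type == "gemma" -> config_name_to_class picks TpuGemmaForCausalLM.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
print(type(model).__name__)  # expected: TpuGemmaForCausalLM

# Any other model_type falls back to transformers' AutoModelForCausalLM classes.
```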