Parallel sharding #21
Changes from 20 commits
@@ -2,16 +2,17 @@
 import torch
 import os
 from enum import Enum
+from typing import Dict
+from loguru import logger

 os.environ["PJRT_DEVICE"] = "TPU"

 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch.multiprocessing as mp

-from optimum.tpu.modeling import TpuModelForCausalLM
-from typing import Dict
-from loguru import logger
+from optimum.tpu.modeling import AutoModelForCausalLM
+from transformers import PretrainedConfig, AutoConfig


 class ModelCommand(Enum):

@@ -26,6 +27,14 @@ def __init__(self, manager: mp.Manager):
         self.root_command = manager.list()
         self.model_ready = manager.Event()
         self.output_data = manager.Value(torch.Tensor, torch.tensor([]))
+        self.model_config = manager.Value(PretrainedConfig, None)
+
+    @property
+    def config(self):
+        while True:
+            config = self.model_config.get()
+            if config is not None:
+                return config

     def send(self, command: ModelCommand, data: Dict = None):
         # First wait until model is ready to receive commands

@@ -49,6 +58,7 @@ def __init__(self, root_mailbox: RootMailbox):
         self.root_command = root_mailbox.root_command
         self.model_ready = root_mailbox.model_ready
         self.output_data = root_mailbox.output_data
+        self.model_config = root_mailbox.model_config

     def receive(self):
         self.root_bell.wait()

@@ -80,9 +90,12 @@ def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
     )

     # Model loading and sharding should happen here
-    model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
+    config = AutoConfig.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=config.torch_dtype)
Review comments on the line above:

I have a hard time understanding why we need to do it this way. Aren't we just overriding the default behaviour with the default behaviour? @regisss do you know?

It seems the default is to load in fp32, whatever dtype is specified in the config: https://huggingface.slack.com/archives/C014N4749J9/p1712757959601599

So I got some insights on the design for this. It seems that transformers uses the default PyTorch dtype, i.e. torch.float32, unless torch_dtype is passed explicitly.
     model = model.eval()
     model.to(device)
+    if rank == 0:
+        mailbox.model_config.set(model.config)

     def get_next_token(inputs):
         # move inputs to device in a new dict to avoid conflicts

@@ -152,5 +165,9 @@ def leave(self):
         logger.debug("Model loop finished")
         self.mailbox = None

+    @property
+    def config(self):
+        return self.mailbox.config
+
     def __del__(self):
         self.leave()
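For reference, the shared-config handshake introduced in the diff above (a manager-held value set by rank 0 and polled by the root process through a config property) can be sketched in isolation. The snippet below is a minimal illustration of that pattern, not code from this PR; SimpleConfig and worker are hypothetical stand-ins for the transformers config and the XLA worker function.

import multiprocessing as mp
import time


class SimpleConfig:
    # Hypothetical stand-in for transformers.PretrainedConfig.
    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size


def worker(rank: int, shared_config):
    # Stand-in for _mp_fn: only rank 0 publishes the config,
    # mirroring the `if rank == 0` guard in the diff above.
    if rank == 0:
        time.sleep(0.5)  # pretend the model takes a while to load
        shared_config.set(SimpleConfig(vocab_size=32000))


if __name__ == "__main__":
    manager = mp.Manager()
    # Same idea as manager.Value(PretrainedConfig, None) in RootMailbox.
    shared_config = manager.Value(SimpleConfig, None)

    procs = [mp.Process(target=worker, args=(r, shared_config)) for r in range(2)]
    for p in procs:
        p.start()

    # Root side: poll until rank 0 has published the config,
    # like the RootMailbox.config property does.
    while True:
        config = shared_config.get()
        if config is not None:
            break
        time.sleep(0.01)

    print("got config, vocab_size =", config.vocab_size)
    for p in procs:
        p.join()

The RootMailbox.config property in the diff polls in the same way; an alternative would be to signal a second manager.Event() once the value is set, which avoids the busy-wait.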
Review comments:

Do we need the torch_dtype=torch_dtype? It should be taken from the config, no?

Well, it doesn't look like it works this way:
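For context on the thread above, here is a minimal sketch of the dtype behaviour being discussed. It is my own illustration, not the snippet the reviewer was referring to; it uses plain transformers on CPU rather than optimum.tpu.modeling, and "gpt2" is just an example checkpoint whose config may not declare a torch_dtype at all.

from transformers import AutoConfig, AutoModelForCausalLM

model_id = "gpt2"  # example checkpoint; any causal LM id works
config = AutoConfig.from_pretrained(model_id)
print("dtype declared in config:", config.torch_dtype)  # may be None for older configs

# Default path: weights are loaded in torch.get_default_dtype() (float32),
# regardless of what the config declares.
model_default = AutoModelForCausalLM.from_pretrained(model_id)
print("default load:", next(model_default.parameters()).dtype)  # torch.float32

# What this PR does: read the config first and pass its dtype explicitly.
# If config.torch_dtype is None, this falls back to the same float32 default.
model_explicit = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=config.torch_dtype)
print("explicit load:", next(model_explicit.parameters()).dtype)

# torch_dtype="auto" is another way to defer to the checkpoint/config dtype.
model_auto = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
print("auto load:", next(model_auto.parameters()).dtype)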