fix: better handle torch_dtype
bfloat16 will be set by default for Gemma models; other models will still
load in float32 by default.
tengomucho committed Apr 10, 2024
1 parent 6e6b44e commit 92e9e31
Showing 2 changed files with 10 additions and 2 deletions.
3 changes: 1 addition & 2 deletions optimum/tpu/distributed_model.py
@@ -90,8 +90,7 @@ def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
)

# Model loading and sharding should happen here
-    config = AutoConfig.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=config.torch_dtype)
+    model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.eval()
model.to(device)
if rank == 0:
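
For context on the distributed_model.py change above: dropping the explicit torch_dtype means AutoModelForCausalLM falls back to the transformers default, which loads weights in float32 unless a dtype is requested. A minimal sketch of that default behavior (the model id is only illustrative):

import torch
from transformers import AutoModelForCausalLM

# With no torch_dtype argument, transformers loads the weights in float32,
# regardless of the dtype recorded in the checkpoint config.
model = AutoModelForCausalLM.from_pretrained("gpt2")
assert model.dtype == torch.float32
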
9 changes: 9 additions & 0 deletions optimum/tpu/modeling_gemma.py
@@ -1344,3 +1344,12 @@ def _reorder_cache(past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past

+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Unless specified otherwise, the model weights type will be bfloat16
+        torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16)
+        # forward to base implementation
+        return super().from_pretrained(
+            pretrained_model_name_or_path, *model_args, torch_dtype=torch_dtype, **kwargs
+        )
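
The override above makes bfloat16 the default weight dtype for the Gemma modeling class while still honoring an explicit request. A rough usage sketch, assuming the class defined in optimum/tpu/modeling_gemma.py is named GemmaForCausalLM (the model id is only illustrative):

import torch
from optimum.tpu.modeling_gemma import GemmaForCausalLM  # class name assumed

# No torch_dtype given: the override injects torch.bfloat16
model = GemmaForCausalLM.from_pretrained("google/gemma-2b")
assert model.dtype == torch.bfloat16

# An explicit torch_dtype still wins, because kwargs.pop keeps the caller's value
model_fp32 = GemmaForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float32)
assert model_fp32.dtype == torch.float32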
