Skip to content

Commit

Permalink
feat(Jetstream Pt): set torch default dtype to bfloat16
Browse files Browse the repository at this point in the history
This is what should be used on TPUs.
  • Loading branch information
tengomucho committed Sep 19, 2024
1 parent 6ca6f43 commit 93ca260
Showing 1 changed file with 1 addition and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, ma
logger.warning("Revision is not supported for JetStream/Pytorch engine, ignoring.")
logger.info("Loading model engine (this can take a few minutes).")
start = time.time()
torch.set_default_dtype(torch.bfloat16)
engine = create_engine(
model_path,
max_batch_size,
Expand Down

0 comments on commit 93ca260

Please sign in to comment.