diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py index 9bd6f62..eeb56e8 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -671,6 +671,7 @@ def from_pretrained(cls, model_path: str, revision: str, max_batch_size: int, ma logger.warning("Revision is not supported for JetStream/Pytorch engine, ignoring.") logger.info("Loading model engine (this can take a few minutes).") start = time.time() + torch.set_default_dtype(torch.bfloat16) engine = create_engine( model_path, max_batch_size,