diff --git a/ml_mdm/language_models/factory.py b/ml_mdm/language_models/factory.py
index df8f838..f55432e 100644
--- a/ml_mdm/language_models/factory.py
+++ b/ml_mdm/language_models/factory.py
@@ -2,7 +2,7 @@
 # Copyright (C) 2024 Apple Inc. All rights reserved.
 import logging
 
-from transformers import T5ForConditionalGeneration
+from transformers import T5ForConditionalGeneration, BitsAndBytesConfig
 import torch
 import torch.nn as nn
 
@@ -42,11 +42,11 @@ def load(self):
 
 
 class LanguageModel(nn.Module):
-    def __init__(self, args, model):
+    def __init__(self, args, model, device: torch.device = "cpu"):
         super().__init__()
         self.model = model
         self.embed_dim = model.embed_dim
-        self.device = "cpu"
+        self.device = device
         self.args = args
 
         # use pre-computed text embeddings. delete the language model!
@@ -69,6 +69,7 @@ def forward(self, sample, tokenizer):
             )
         else:
             sample_tokens = sample["tokens"]
+            sample_tokens = sample_tokens.to(self.device)
 
         if args.categorical_conditioning:
             lm_outputs = (
@@ -134,7 +135,19 @@ def create_lm(args, device: torch.device = "cuda"):
         )  # TODO (jack_carlson) this line is never reached
     else:
        	tokenizer = create_tokenizer(args.vocab_file)
-        model = T5Encoder.from_pretrained(args.text_model)
-        model = LanguageModel(args, model).to(device)
+        if torch.cuda.is_available():
+            model = T5Encoder.from_pretrained(
+                args.text_model,
+                device_map="auto",
+                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+            )
+            model = LanguageModel(args, model, device="cuda")  # .to() is not supported for bitsandbytes 8-bit quantized models
+        else:
+            model = T5Encoder.from_pretrained(
+                args.text_model,
+                device_map="cpu",
+                torch_dtype=torch.float16,
+            )
+            model = LanguageModel(args, model, device="cpu")
     model.eval()
     return tokenizer, model
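
For reference, a minimal standalone sketch of the quantized loading path this diff introduces. It substitutes transformers' stock T5EncoderModel for the repository's T5Encoder wrapper and a hypothetical "t5-small" checkpoint for args.text_model, and assumes bitsandbytes and accelerate are installed:

    import torch
    from transformers import AutoTokenizer, BitsAndBytesConfig, T5EncoderModel

    model_name = "t5-small"  # hypothetical stand-in for args.text_model

    if torch.cuda.is_available():
        # 8-bit weights are placed directly on the GPU by device_map="auto";
        # calling .to() on the result raises an error, which is why the diff
        # threads the device through LanguageModel instead of moving the model.
        encoder = T5EncoderModel.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
    else:
        # CPU fallback: bitsandbytes is unavailable, so load unquantized.
        # fp32 is kept here because many CPU ops lack float16 kernels.
        encoder = T5EncoderModel.from_pretrained(model_name, device_map="cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Mirror the forward() change: move token tensors to the model's device.
    tokens = tokenizer("a photo of a corgi", return_tensors="pt").to(encoder.device)
    with torch.no_grad():
        embeddings = encoder(**tokens).last_hidden_state  # [1, seq_len, d_model]

Note that the diff's CPU branch requests torch_dtype=torch.float16; the sketch keeps the default fp32 because several CPU kernels do not implement half precision.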