From e1c58c711ce944e726d8a8e76bbfd9c3167d0925 Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Fri, 13 Dec 2024 14:21:57 -0500 Subject: [PATCH 1/2] Load T5 encoder with 8-bit quantization on CUDA, fp16 on CPU --- ml_mdm/language_models/factory.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/ml_mdm/language_models/factory.py b/ml_mdm/language_models/factory.py index df8f838..3ff8fc1 100644 --- a/ml_mdm/language_models/factory.py +++ b/ml_mdm/language_models/factory.py @@ -2,7 +2,7 @@ # Copyright (C) 2024 Apple Inc. All rights reserved. import logging -from transformers import T5ForConditionalGeneration +from transformers import T5ForConditionalGeneration, BitsAndBytesConfig import torch import torch.nn as nn @@ -42,11 +42,11 @@ def load(self): class LanguageModel(nn.Module): - def __init__(self, args, model): + def __init__(self, args, model, device="cpu"): super().__init__() self.model = model self.embed_dim = model.embed_dim - self.device = "cpu" + self.device = device self.args = args # use pre-computed text embeddings. delete the language model! 
@@ -69,6 +69,7 @@ def forward(self, sample, tokenizer): ) else: sample_tokens = sample["tokens"] + sample_tokens = sample_tokens.to(self.device) if args.categorical_conditioning: lm_outputs = ( @@ -134,7 +135,19 @@ def create_lm(args, device: torch.device = "cuda"): ) # TODO (jack_carlson) this line is never reached else: tokenizer = create_tokenizer(args.vocab_file) - model = T5Encoder.from_pretrained(args.text_model) - model = LanguageModel(args, model).to(device) + if torch.cuda.is_available(): + model = T5Encoder.from_pretrained( + args.text_model, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True) + ) + model = LanguageModel(args, model, device="cuda") # .to() not supported for bitsandbytes `8-bit` quantized models + else: + model = T5Encoder.from_pretrained( + args.text_model, + device_map="cpu", + torch_dtype=torch.float16 + ) + model = LanguageModel(args, model, device="cpu") model.eval() return tokenizer, model From 52c0d19c4b4b425d47279fcbbfd12559705a93fd Mon Sep 17 00:00:00 2001 From: levinkhho <125771265+levinkhho@users.noreply.github.com> Date: Sat, 14 Dec 2024 15:30:39 -0500 Subject: [PATCH 2/2] Type-annotate the LanguageModel device parameter --- ml_mdm/language_models/factory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ml_mdm/language_models/factory.py b/ml_mdm/language_models/factory.py index 3ff8fc1..f55432e 100644 --- a/ml_mdm/language_models/factory.py +++ b/ml_mdm/language_models/factory.py @@ -42,7 +42,7 @@ def load(self): class LanguageModel(nn.Module): - def __init__(self, args, model, device="cpu"): + def __init__(self, args, model, device: torch.device = "cpu"): super().__init__() self.model = model self.embed_dim = model.embed_dim @@ -148,6 +148,6 @@ def create_lm(args, device: torch.device = "cuda"): device_map="cpu", torch_dtype=torch.float16 ) - model = LanguageModel(args, model, device="cpu") + model = LanguageModel(args, model, device="cpu") model.eval() return tokenizer, model