diff --git a/guidance/models/transformers/_transformers.py b/guidance/models/transformers/_transformers.py
index 38f46d336..e3aeda8e2 100644
--- a/guidance/models/transformers/_transformers.py
+++ b/guidance/models/transformers/_transformers.py
@@ -76,7 +76,7 @@ def _tokenizer(self, model, **kwargs):
         return tokenizer
 
 class TransformersEngine(Engine):
-    def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
+    def __init__(self, model, tokenizer, peft_model_id, compute_log_probs, **kwargs):
         # fill in default model value
         if model is None:
             model = os.environ.get("TRANSFORMERS_MODEL", None)
@@ -87,7 +87,7 @@ def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
             except:
                 pass
 
-        self.model_obj = self._model(model, **kwargs)
+        self.model_obj, orig_tokenizer = self._model_and_tokenizer(model, tokenizer, peft_model_id, **kwargs)
 
         if not isinstance(model, str):
             self.model = model.__class__.__name__
@@ -102,8 +102,8 @@ def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
             compute_log_probs=compute_log_probs
         )
 
-    def _model(self, model, **kwargs):
-        # intantiate the model if needed
+    def _model_and_tokenizer(self, model, tokenizer, peft_model_id, **kwargs):
+        # instantiate the model and tokenizer if needed
         if isinstance(model, str):
 
             # make sure transformers is installed
@@ -112,7 +112,26 @@ def _model(self, model, **kwargs):
             except:
                 raise Exception("Please install transformers with `pip install transformers` in order to use guidance.models.Transformers!")
+
+            # load the tokenizer from the model name when no tokenizer object was
+            # passed in, since we now return the tokenizer alongside the model
+            if tokenizer is None:
+                tokenizer = transformers.AutoTokenizer.from_pretrained(model, **kwargs)
+
             model = transformers.AutoModelForCausalLM.from_pretrained(model, **kwargs)
-        return model
+
+            if peft_model_id is not None:
+                try:
+                    # imported lazily so that peft stays optional at module import time
+                    from peft import PeftModel
+                    model = PeftModel.from_pretrained(model, peft_model_id)
+                except ImportError:
+                    print("Cannot load the peft module, please install it with `pip install peft` or `pip install git+https://github.com/huggingface/peft`")
+                except Exception as e: # fall through for any other failure while applying the adapter
+                    print(f"Exception while applying the peft model:\n{e}")
+
+        assert tokenizer is not None, "You must give a tokenizer object when you provide a model object (as opposed to just a model name)!"
+
+        return model, tokenizer
 
     def _joint_tokenize(self, token_ids):
         # first_decode = self.tokenizer._orig_tokenizer.decode(token_ids)
@@ -190,10 +209,10 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
 
 class Transformers(Model):
-    def __init__(self, model=None, tokenizer=None, echo=True, compute_log_probs=False, **kwargs):
+    def __init__(self, model=None, tokenizer=None, peft_model_id=None, echo=True, compute_log_probs=False, **kwargs):
         '''Build a new Transformers model object that represents a model in a given state.'''
         super().__init__(
-            TransformersEngine(model, tokenizer, compute_log_probs, **kwargs),
+            TransformersEngine(model, tokenizer, peft_model_id, compute_log_probs, **kwargs),
             echo=echo
         )
diff --git a/setup.py b/setup.py
index ed52f23ca..45985495d 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,8 @@ def find_version(*file_paths):
         "pyformlang",
         "protobuf",
         "fastapi",
-        "uvicorn"
+        "uvicorn",
+        "peft"
     ],
     extras_require={
         'docs': [
diff --git a/tests/models/test_peft.py b/tests/models/test_peft.py
new file mode 100644
index 000000000..90d80935d
--- /dev/null
+++ b/tests/models/test_peft.py
@@ -0,0 +1,24 @@
+from transformers import AutoModelForCausalLM
+
+
+def test_peft():
+    # import peft explicitly so that only a missing install is reported as
+    # "not installed"; any other failure surfaces as a normal test error
+    try:
+        from peft import LoraConfig, TaskType, get_peft_model
+    except ImportError:
+        raise Exception("Sorry, peft is not installed")
+
+    lora_config = LoraConfig(
+        r=16,
+        target_modules=["q_proj", "v_proj"],
+        task_type=TaskType.CAUSAL_LM,
+        lora_alpha=32,
+        lora_dropout=0.05
+    )
+    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
+
+    lora_model = get_peft_model(model, lora_config)
+    lora_model.print_trainable_parameters()
+
+    print("Running PEFT is successful!")
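Usage sketch for the new parameter (not part of the diff): with this change, a PEFT/LoRA adapter can be applied by passing `peft_model_id` to `guidance.models.Transformers`. The base model name "facebook/opt-350m" and the adapter id "my-org/opt-350m-lora" below are placeholders, and the sketch assumes peft is installed alongside guidance.

    from guidance import models, gen

    # the adapter is applied on top of the base model inside
    # TransformersEngine._model_and_tokenizer (see the diff above)
    lm = models.Transformers(
        "facebook/opt-350m",                   # placeholder base model name
        peft_model_id="my-org/opt-350m-lora",  # hypothetical adapter repo or local path
    )

    # the adapted model is then used like any other guidance Transformers model
    lm += "The capital of France is " + gen(max_tokens=5)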