Update guidance with support for PEFT #656

Open · wants to merge 6 commits into base: main
29 changes: 22 additions & 7 deletions guidance/models/transformers/_transformers.py
@@ -1,4 +1,6 @@
import os
+from peft import PeftModel


try:
import torch
@@ -76,7 +78,7 @@ def _tokenizer(self, model, **kwargs):
return tokenizer

class TransformersEngine(Engine):
-    def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
+    def __init__(self, model, tokenizer, peft_model_id, compute_log_probs, **kwargs):
# fill in default model value
if model is None:
model = os.environ.get("TRANSFORMERS_MODEL", None)
@@ -87,7 +89,7 @@ def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
except:
pass

-        self.model_obj = self._model(model, **kwargs)
+        self.model_obj, orig_tokenizer = self._model_and_tokenizer(model, tokenizer, peft_model_id, **kwargs)

if not isinstance(model, str):
self.model = model.__class__.__name__
@@ -102,8 +104,10 @@ def __init__(self, model, tokenizer, compute_log_probs, **kwargs):
compute_log_probs=compute_log_probs
)

-    def _model(self, model, **kwargs):
-        # intantiate the model if needed
+    def _model_and_tokenizer(self, model, tokenizer, peft_model_id, **kwargs):
+
+        # instantiate the model and tokenizer if needed
if isinstance(model, str):

# make sure transformers is installed
@@ -112,7 +116,18 @@ def _model(self, model, **kwargs):
except:
raise Exception("Please install transformers with `pip install transformers` in order to use guidance.models.Transformers!")
            model = transformers.AutoModelForCausalLM.from_pretrained(model, **kwargs)
-        return model
+            if peft_model_id is not None:
+                try:
+                    model = PeftModel.from_pretrained(model, peft_model_id)
+                except ImportError:
+                    print("Cannot load the peft module; please install it with 'pip install peft' or 'pip install git+https://github.com/huggingface/peft'")
+                except Exception as e:  # fall through on any other failure
+                    print(f"Exception while applying the peft model:\n{e}")
+
+        assert tokenizer is not None, "You must give a tokenizer object when you provide a model object (as opposed to just a model name)!"
+
+        return model, tokenizer


def _joint_tokenize(self, token_ids):
# first_decode = self.tokenizer._orig_tokenizer.decode(token_ids)
@@ -190,10 +205,10 @@ def get_logits(self, token_ids, forced_bytes, current_temp):


class Transformers(Model):
-    def __init__(self, model=None, tokenizer=None, echo=True, compute_log_probs=False, **kwargs):
+    def __init__(self, model=None, tokenizer=None, peft_model_id=None, echo=True, compute_log_probs=False, **kwargs):
'''Build a new Transformers model object that represents a model in a given state.'''
super().__init__(
-            TransformersEngine(model, tokenizer, compute_log_probs, **kwargs),
+            TransformersEngine(model, tokenizer, peft_model_id, compute_log_probs, **kwargs),
echo=echo
)

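For reference, a minimal usage sketch of the new peft_model_id argument added above (an untested illustration: the adapter id below is a hypothetical placeholder, and the tokenizer is assumed to be derived from the model name as usual):

import guidance
from guidance import gen

# Load the base model and stack a PEFT/LoRA adapter on top of it;
# "your-org/your-lora-adapter" is a placeholder, not a real adapter id.
lm = guidance.models.Transformers(
    "facebook/opt-350m",
    peft_model_id="your-org/your-lora-adapter",
)

# generation then runs through the adapted weights
lm += "The capital of France is " + gen(max_tokens=5)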
3 changes: 2 additions & 1 deletion setup.py
@@ -41,7 +41,8 @@ def find_version(*file_paths):
"pyformlang",
"protobuf",
"fastapi",
"uvicorn"
"uvicorn",
"peft"
],
extras_require={
'docs': [
22 changes: 22 additions & 0 deletions tests/models/test_peft.py
@@ -0,0 +1,22 @@
from transformers import AutoModelForCausalLM


def test_peft():
    try:
        from peft import LoraConfig, TaskType, get_peft_model
    except ImportError:
        raise Exception("Sorry, peft is not installed")

    lora_config = LoraConfig(
        r=16,
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=32,
        lora_dropout=0.05
    )
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

    # wrap the base model with the LoRA adapter
    lora_model = get_peft_model(model, lora_config)
    lora_model.print_trainable_parameters()

    print("Running PEFT is successful!")
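
A possible end-to-end check of the new loading path, sketched below, would round-trip a freshly created adapter through peft_model_id; the save/load round-trip and the test name are assumptions for illustration, not part of this diff:

import tempfile

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

import guidance


def test_peft_roundtrip():
    # build a small LoRA adapter over the same base model as above
    base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
    config = LoraConfig(
        r=16,
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
    )
    lora_model = get_peft_model(base, config)

    with tempfile.TemporaryDirectory() as adapter_dir:
        # save just the adapter, then reload it through the new argument
        lora_model.save_pretrained(adapter_dir)
        lm = guidance.models.Transformers("facebook/opt-350m", peft_model_id=adapter_dir)
        assert lm is not None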