# merge.py -- forked from artidoro/qlora.
# Merges a PEFT adapter into its base model and saves the merged checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os
import argparse
def get_args():
    """Parse the command-line options for the merge script.

    Returns:
        argparse.Namespace with the attributes ``base``, ``peft``, ``out``
        (all strings) and ``push`` (bool, ``False`` unless ``--push`` is given).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            an unknown option is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Merge a PEFT adapter into its base model and save the result."
    )
    # All three paths are consumed unconditionally by main(), so fail early with
    # a clear usage message instead of passing None into the model loaders.
    parser.add_argument(
        "--base",
        type=str,
        required=True,
        help="Name or path of the base model to load.",
    )
    parser.add_argument(
        "--peft",
        type=str,
        required=True,
        help="Name or path of the PEFT adapter to merge into the base model.",
    )
    parser.add_argument(
        "--out",
        type=str,
        required=True,
        help="Output location for the merged model and tokenizer.",
    )
    parser.add_argument(
        "--push",
        action="store_true",
        help="Also upload the merged model and tokenizer to the Hub.",
    )
    return parser.parse_args()
def main():
    """Merge a PEFT adapter into its base model and save the merged result.

    Steps, driven by the command-line options from ``get_args()``:

    1. Load the base causal-LM given by ``--base`` in bfloat16.
    2. Wrap it with the adapter given by ``--peft``.
    3. Call ``merge_and_unload()`` to obtain a plain model with the adapter
       merged in.
    4. Save the merged model (safetensors, shards of at most 4GB) and the
       base model's tokenizer to ``--out``.
    5. If ``--push`` is set, upload both to the Hub using ``--out`` as the
       repository id.
    """
    args = get_args()

    print(f"Loading base model: {args.base}")
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base,
        return_dict=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    print(f"Loading PEFT: {args.peft}")
    model = PeftModel.from_pretrained(base_model, args.peft)

    print("Running merge_and_unload")
    model = model.merge_and_unload()

    # The tokenizer is loaded from the base checkpoint, not from the adapter.
    tokenizer = AutoTokenizer.from_pretrained(args.base)

    model.save_pretrained(args.out, safe_serialization=True, max_shard_size="4GB")
    tokenizer.save_pretrained(args.out)

    if args.push:
        print("Saving to hub ...")
        # The upload method on transformers models and tokenizers is
        # `push_to_hub`; calling `.push(...)` raises AttributeError.
        # NOTE(review): with use_temp_dir=False the local folder named by
        # args.out is presumably reused as the working directory -- confirm
        # against the installed transformers version.
        model.push_to_hub(args.out, use_temp_dir=False)
        tokenizer.push_to_hub(args.out, use_temp_dir=False)
# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()