Transformers inference supports implementation (#2)
* add inference_hf script
* update based on codacy
Showing 1 changed file with 216 additions and 0 deletions.
@@ -0,0 +1,216 @@
import argparse
import json, os
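# Prompt template in the Mistral/Mixtral instruct format; generate_prompt() below fills in the user instruction.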
TEMPLATE = (
    "[INST] {instruction} [/INST]"
)

parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--tokenizer_path', default=None, type=str)
parser.add_argument('--data_file', default=None, type=str, help="A file that contains instructions (one instruction per line)")
parser.add_argument('--with_prompt', action='store_true', help="wrap the input with the prompt automatically")
parser.add_argument('--interactive', action='store_true', help="run in the instruction mode (single-turn)")
parser.add_argument('--predictions_file', default='./predictions.json', type=str)
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
parser.add_argument('--load_in_8bit', action='store_true', help="Load the LLM in the 8bit mode")
parser.add_argument('--load_in_4bit', action='store_true', help="Load the LLM in the 4bit mode")
parser.add_argument("--use_vllm", action='store_true', help="Use vLLM as back-end LLM service.")
parser.add_argument('--use_flash_attention_2', action='store_true', help="Use flash attention to replace the Mixtral attention")
args = parser.parse_args()
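# Illustrative invocation (assuming the script is saved as inference_hf.py; paths are placeholders):
#   python inference_hf.py --base_model /path/to/model --with_prompt --interactive --load_in_4bit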
if args.use_vllm:
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("vLLM currently does not support quantization, please use fp16 (default) or do not pass --use_vllm.")
    if args.only_cpu:
        raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. If you want to run only on CPU, please do not pass --use_vllm.")
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments.")
if args.only_cpu is True:
    args.gpus = ""
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("Quantization is unavailable on CPU.")
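# Restrict visible GPUs before torch is imported so the CUDA device mask takes effect.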
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import GenerationConfig
from transformers import BitsAndBytesConfig
if args.use_vllm:
    from vllm import LLM, SamplingParams

import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
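# Generation settings: a plain dict (later unpacked into vLLM SamplingParams) when --use_vllm is set,
# otherwise a Transformers GenerationConfig passed to model.generate().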
if args.use_vllm:
    generation_config = dict(
        temperature=0.2,
        top_k=40,
        top_p=0.9,
        max_tokens=400,
        presence_penalty=1.0,
    )
else:
    generation_config = GenerationConfig(
        temperature=0.2,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        num_beams=1,
        repetition_penalty=1.1,
        max_new_tokens=400
    )
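# Fallback instruction used when no --data_file is given; the Chinese sample asks
# "Why should we reduce pollution and protect the environment?"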
sample_data = ["为什么要减少污染,保护环境?"]

def generate_prompt(instruction):
    return TEMPLATE.format_map({'instruction': instruction})


if __name__ == '__main__':
    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')
    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model
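    # Backend selection: either a vLLM engine with tensor parallelism across the visible GPUs,
    # or a Hugging Face model with optional bitsandbytes quantization.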
    if args.use_vllm:
        model = LLM(model=args.base_model,
                    tokenizer=args.tokenizer_path,
                    tokenizer_mode='slow',
                    tensor_parallel_size=len(args.gpus.split(','))
                    )
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
    else:
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
        if args.load_in_4bit or args.load_in_8bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=args.load_in_4bit,
                load_in_8bit=args.load_in_8bit,
                bnb_4bit_compute_dtype=load_type,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
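        # Load the HF model; quantization_config is only set when 4-bit/8-bit is requested,
        # and attn_implementation switches between FlashAttention-2 and SDPA.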
        model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=load_type,
            low_cpu_mem_usage=True,
            device_map='auto',
            load_in_4bit=args.load_in_4bit,
            load_in_8bit=args.load_in_8bit,
            quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
            attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
        )
        if device==torch.device('cpu'):
            model.float()
        model.eval()

    # test data
    if args.data_file is None:
        examples = sample_data
    else:
        with open(args.data_file,'r') as f:
            examples = [line.strip() for line in f.readlines()]
        print("first 10 examples:")
        for example in examples[:10]:
            print(example)
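    # Interactive mode runs a single-turn REPL; otherwise all examples are processed in batch
    # and the results are written to --predictions_file.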
    with torch.no_grad():
        if args.interactive:
            print("Start inference with instruction mode.")

            print('='*85)
            print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n"
                  "+ 如要进行多轮对话,请使用llama.cpp")
            print('-'*85)
            print("+ This mode only supports single-turn QA.\n"
                  "+ If you want to experience multi-turn dialogue, please use llama.cpp")
            print('='*85)

            while True:
                raw_input_text = input("Input:")
                if len(raw_input_text.strip())==0:
                    break
                if args.with_prompt:
                    input_text = generate_prompt(instruction=raw_input_text)
                else:
                    input_text = raw_input_text

                if args.use_vllm:
                    output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)
                    response = output[0].outputs[0].text
                else:
                    inputs = tokenizer(input_text,return_tensors="pt")  #add_special_tokens=False ?
                    generation_output = model.generate(
                        input_ids = inputs["input_ids"].to(device),
                        attention_mask = inputs['attention_mask'].to(device),
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.eos_token_id,
                        generation_config = generation_config
                    )
                    s = generation_output[0]
                    output = tokenizer.decode(s,skip_special_tokens=True)
                    if args.with_prompt:
                        response = output.split("[/INST]")[-1].strip()
                    else:
                        response = output
                print("Response: ",response)
                print("\n")
        else:
            print("Start inference.")
            results = []
            if args.use_vllm:
                if args.with_prompt is True:
                    inputs = [generate_prompt(example) for example in examples]
                else:
                    inputs = examples
                outputs = model.generate(inputs, SamplingParams(**generation_config))

                for index, (example, output) in enumerate(zip(examples, outputs)):
                    response = output.outputs[0].text
                    print(f"======={index}=======")
                    print(f"Input: {example}\n")
                    print(f"Output: {response}\n")
                    results.append({"Input":example,"Output":response})
            else:
                for index, example in enumerate(examples):
                    if args.with_prompt:
                        input_text = generate_prompt(instruction=example)
                    else:
                        input_text = example
                    inputs = tokenizer(input_text,return_tensors="pt")  #add_special_tokens=False ?
                    generation_output = model.generate(
                        input_ids = inputs["input_ids"].to(device),
                        attention_mask = inputs['attention_mask'].to(device),
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.eos_token_id,
                        generation_config = generation_config
                    )
                    s = generation_output[0]
                    output = tokenizer.decode(s,skip_special_tokens=True)
                    if args.with_prompt:
                        response = output.split("[/INST]")[1].strip()
                    else:
                        response = output
                    print(f"======={index}=======")
                    print(f"Input: {example}\n")
                    print(f"Output: {response}\n")

                    results.append({"Input":input_text,"Output":response})
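            # Persist batch predictions as JSON, together with the generation settings used.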
            dirname = os.path.dirname(args.predictions_file)
            os.makedirs(dirname,exist_ok=True)
            with open(args.predictions_file,'w') as f:
                json.dump(results,f,ensure_ascii=False,indent=2)
            if args.use_vllm:
                with open(dirname+'/generation_config.json','w') as f:
                    json.dump(generation_config,f,ensure_ascii=False,indent=2)
            else:
                generation_config.save_pretrained('./')