Transformers inference support implementation (#2)
* add inference_hf script
* update based on codacy
iMountTai authored Jan 29, 2024
1 parent f7abd09 commit 4805b48
Showing 1 changed file with 216 additions and 0 deletions.
216 changes: 216 additions & 0 deletions scripts/inference/inference_hf.py
@@ -0,0 +1,216 @@
import argparse
import json, os

TEMPLATE = (
"[INST] {instruction} [/INST]"
)
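# With --with_prompt, each raw instruction is wrapped in this Llama-2/Mixtral-style
# chat template before generation, e.g. "[INST] 你好 [/INST]".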

parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--tokenizer_path', default=None, type=str)
parser.add_argument('--data_file', default=None, type=str, help="A file that contains instructions (one instruction per line)")
parser.add_argument('--with_prompt', action='store_true', help="wrap the input with the prompt automatically")
parser.add_argument('--interactive', action='store_true', help="run in the instruction mode (single-turn)")
parser.add_argument('--predictions_file', default='./predictions.json', type=str)
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
parser.add_argument('--load_in_8bit', action='store_true', help="Load the LLM in 8-bit mode")
parser.add_argument('--load_in_4bit', action='store_true', help="Load the LLM in 4-bit mode")
parser.add_argument("--use_vllm", action='store_true', help="Use vLLM as the back-end LLM service")
parser.add_argument('--use_flash_attention_2', action='store_true', help="Use FlashAttention-2 in place of the default Mixtral attention")
args = parser.parse_args()


if args.use_vllm:
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("vLLM currently does not support quantization; please use fp16 (default) or do not pass --use_vllm.")
    if args.only_cpu:
        raise ValueError("vLLM requires GPUs with compute capability of at least 7.0. To run on CPU only, do not pass --use_vllm.")
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments.")
if args.only_cpu is True:
    args.gpus = ""
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("Quantization is unavailable on CPU.")

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer
from transformers import GenerationConfig
from transformers import BitsAndBytesConfig
if args.use_vllm:
from vllm import LLM, SamplingParams

import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)


if args.use_vllm:
    generation_config = dict(
        temperature=0.2,
        top_k=40,
        top_p=0.9,
        max_tokens=400,
        presence_penalty=1.0,
    )
else:
    generation_config = GenerationConfig(
        temperature=0.2,
        top_k=40,
        top_p=0.9,
        do_sample=True,
        num_beams=1,
        repetition_penalty=1.1,
        max_new_tokens=400
    )

sample_data = ["为什么要减少污染,保护环境?"]  # "Why should we reduce pollution and protect the environment?"

def generate_prompt(instruction):
    return TEMPLATE.format_map({'instruction': instruction})


if __name__ == '__main__':
    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')
    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model

    if args.use_vllm:
        model = LLM(model=args.base_model,
                    tokenizer=args.tokenizer_path,
                    tokenizer_mode='slow',
                    tensor_parallel_size=len(args.gpus.split(','))
                    )
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
    else:
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
        if args.load_in_4bit or args.load_in_8bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=args.load_in_4bit,
                load_in_8bit=args.load_in_8bit,
                bnb_4bit_compute_dtype=load_type,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

        model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=load_type,
            low_cpu_mem_usage=True,
            device_map='auto',
            load_in_4bit=args.load_in_4bit,
            load_in_8bit=args.load_in_8bit,
            quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
            attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa"
        )
        if device == torch.device('cpu'):
            model.float()
        model.eval()

    # test data
    if args.data_file is None:
        examples = sample_data
    else:
        with open(args.data_file, 'r') as f:
            examples = [line.strip() for line in f.readlines()]
        print("first 10 examples:")
        for example in examples[:10]:
            print(example)

    with torch.no_grad():
        if args.interactive:
            print("Start inference with instruction mode.")

            print('=' * 85)
            print("+ 该模式下仅支持单轮问答,无多轮对话能力。\n"
                  "+ 如要进行多轮对话,请使用llama.cpp")
            print('-' * 85)
            print("+ This mode only supports single-turn QA.\n"
                  "+ If you want to experience multi-turn dialogue, please use llama.cpp")
            print('=' * 85)

            while True:
                raw_input_text = input("Input:")
                if len(raw_input_text.strip()) == 0:
                    break
                if args.with_prompt:
                    input_text = generate_prompt(instruction=raw_input_text)
                else:
                    input_text = raw_input_text

                if args.use_vllm:
                    output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)
                    response = output[0].outputs[0].text
                else:
                    inputs = tokenizer(input_text, return_tensors="pt")  # add_special_tokens=False ?
                    generation_output = model.generate(
                        input_ids=inputs["input_ids"].to(device),
                        attention_mask=inputs['attention_mask'].to(device),
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.eos_token_id,
                        generation_config=generation_config
                    )
                    s = generation_output[0]
                    output = tokenizer.decode(s, skip_special_tokens=True)
                    if args.with_prompt:
                        response = output.split("[/INST]")[-1].strip()
                    else:
                        response = output
                print("Response: ", response)
                print("\n")
        else:
            print("Start inference.")
            results = []
            if args.use_vllm:
                if args.with_prompt is True:
                    inputs = [generate_prompt(example) for example in examples]
                else:
                    inputs = examples
                outputs = model.generate(inputs, SamplingParams(**generation_config))

                for index, (example, output) in enumerate(zip(examples, outputs)):
                    response = output.outputs[0].text
                    print(f"======={index}=======")
                    print(f"Input: {example}\n")
                    print(f"Output: {response}\n")
                    results.append({"Input": example, "Output": response})
            else:
                for index, example in enumerate(examples):
                    if args.with_prompt:
                        input_text = generate_prompt(instruction=example)
                    else:
                        input_text = example
                    inputs = tokenizer(input_text, return_tensors="pt")  # add_special_tokens=False ?
                    generation_output = model.generate(
                        input_ids=inputs["input_ids"].to(device),
                        attention_mask=inputs['attention_mask'].to(device),
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.eos_token_id,
                        generation_config=generation_config
                    )
                    s = generation_output[0]
                    output = tokenizer.decode(s, skip_special_tokens=True)
                    if args.with_prompt:
                        response = output.split("[/INST]")[1].strip()
                    else:
                        response = output
                    print(f"======={index}=======")
                    print(f"Input: {example}\n")
                    print(f"Output: {response}\n")

                    results.append({"Input": input_text, "Output": response})

            dirname = os.path.dirname(args.predictions_file)
            os.makedirs(dirname, exist_ok=True)
            with open(args.predictions_file, 'w') as f:
                json.dump(results, f, ensure_ascii=False, indent=2)
            if args.use_vllm:
                with open(dirname + '/generation_config.json', 'w') as f:
                    json.dump(generation_config, f, ensure_ascii=False, indent=2)
            else:
                generation_config.save_pretrained('./')
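
For reference, a typical way to invoke the script once it is in place; the model path and data file below are placeholders, and every flag shown is defined by the argparse block above:

    # interactive single-turn chat with the instruction template
    python scripts/inference/inference_hf.py \
        --base_model /path/to/hf_model \
        --with_prompt \
        --interactive

    # batch inference over a file with one instruction per line
    python scripts/inference/inference_hf.py \
        --base_model /path/to/hf_model \
        --data_file /path/to/instructions.txt \
        --with_prompt \
        --predictions_file ./predictions.json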
