-
Notifications
You must be signed in to change notification settings - Fork 500
/
inference_multigpu_demo.py
207 lines (187 loc) · 8.51 KB
/
inference_multigpu_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: run batched inference on multiple GPUs with torchrun
usage:
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 inference_multigpu_demo.py --model_type bloom --base_model bigscience/bloom-560m
"""
import argparse
import json
import os
import torch
import torch.distributed as dist
from loguru import logger
from peft import PeftModel
from torch.nn import DataParallel
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
BloomForCausalLM,
BloomTokenizerFast,
LlamaForCausalLM,
GenerationConfig,
BitsAndBytesConfig,
)
from template import get_conv_template
# Lookup table for --model_type: each entry is the (model class, tokenizer class)
# pair whose `from_pretrained` is used to load the checkpoint.
MODEL_CLASSES = dict(
    bloom=(BloomForCausalLM, BloomTokenizerFast),
    chatglm=(AutoModel, AutoTokenizer),
    llama=(LlamaForCausalLM, AutoTokenizer),
    baichuan=(AutoModelForCausalLM, AutoTokenizer),
    auto=(AutoModelForCausalLM, AutoTokenizer),
)
class TextDataset(Dataset):
    """Thin map-style ``Dataset`` wrapper around an indexable collection.

    Args:
        data: Any object supporting ``len()`` and integer indexing; in this
            script it is a slice of the list of prompt strings.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return how many items the wrapped collection holds."""
        size = len(self.data)
        return size

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``, unchanged."""
        item = self.data[idx]
        return item
def _parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with the parsed options.

    Exits through ``parser.error`` when ``--load_in_8bit`` and ``--load_in_4bit``
    are both given, or when ``--batch_size`` is smaller than 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_type', default=None, type=str, required=True, choices=list(MODEL_CLASSES))
    parser.add_argument('--base_model', default=None, type=str, required=True)
    parser.add_argument('--lora_model', default="", type=str, help="If empty, perform inference on the base model")
    parser.add_argument('--tokenizer_path', default=None, type=str)
    parser.add_argument('--template_name', default="vicuna", type=str,
                        help="Prompt template name, eg: alpaca, vicuna, baichuan, chatglm2 etc.")
    parser.add_argument('--system_prompt', default="", type=str)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument('--temperature', type=float, default=0.7)
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument('--data_file', default=None, type=str, help="Predict file, one example per line")
    parser.add_argument('--output_file', default='./predictions_result.jsonl', type=str)
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    parser.add_argument('--load_in_8bit', action='store_true', help='Whether to load model in 8bit')
    parser.add_argument('--load_in_4bit', action='store_true', help='Whether to load model in 4bit')
    args = parser.parse_args()
    if args.load_in_8bit and args.load_in_4bit:
        parser.error("--load_in_8bit and --load_in_4bit are mutually exclusive")
    if args.batch_size < 1:
        # A zero/negative batch size would make the chunking range() below fail.
        parser.error("--batch_size must be >= 1")
    return args


def _load_model_and_tokenizer(args, local_rank):
    """Load the tokenizer and the (optionally quantized / LoRA-wrapped) model onto one GPU.

    Args:
        args: Parsed CLI options. ``args.tokenizer_path`` is filled in from
            ``args.base_model`` when it was not given.
        local_rank: CUDA device index on this node; the whole model is placed on it.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        # Batched generation pads prompts; some tokenizers (e.g. LLaMA's) have no pad
        # token, so reuse EOS. The attention mask keeps padding out of attention.
        tokenizer.pad_token = tokenizer.eos_token

    load_type = 'auto'
    quantization_config = None
    if args.load_in_8bit or args.load_in_4bit:
        # BitsAndBytesConfig resolves a string compute dtype with getattr(torch, name),
        # so 'auto' is invalid there; use a concrete dtype the GPU supports.
        compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=args.load_in_8bit,
            load_in_4bit=args.load_in_4bit,
            bnb_4bit_compute_dtype=compute_dtype,
        )
    # Quantization is requested only via quantization_config: transformers raises if
    # load_in_8bit/load_in_4bit kwargs are passed together with a quantization_config.
    base_model = model_class.from_pretrained(
        args.base_model,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        device_map={"": local_rank},
        trust_remote_code=True,
        quantization_config=quantization_config,
    )
    try:
        base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
    except OSError:
        logger.info("Failed to load generation config, use default.")

    if args.resize_emb:
        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        logger.info(f"Vocab of the base model: {model_vocab_size}")
        logger.info(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            logger.info("Resize model embeddings to fit tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)

    if args.lora_model:
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type,
                                          device_map={"": local_rank})
        logger.info("Loaded lora model")
    else:
        model = base_model
    model.eval()
    return model, tokenizer


def _load_examples(data_file):
    """Return the prompts to run: built-in samples, or the non-empty lines of ``data_file``."""
    if data_file is None:
        return [
            "介绍下北京",
            "乙肝和丙肝的区别?",
            "失眠怎么办?",
            '用一句话描述地球为什么是独一无二的。',
            "Tell me about alpacas.",
            "Tell me about the president of Mexico in 2019.",
            "hello.",
        ]
    with open(data_file, 'r', encoding='utf-8') as f:
        # Skip blank lines so they do not turn into empty prompts.
        return [line.strip() for line in f if line.strip()]


def _generate_batch(model, tokenizer, prompt_template, texts, system_prompt, device, generation_kwargs):
    """Generate one decoded response per entry of ``texts`` as a single padded batch.

    Returns:
        List of decoded strings containing only the newly generated tokens.
    """
    prompted_texts = [prompt_template.get_prompt(messages=[[s, '']], system_prompt=system_prompt) for s in texts]
    logger.debug(f'device: {device}, inputs size:{len(prompted_texts)}, top3: {prompted_texts[:3]}')
    encoded = tokenizer(prompted_texts, return_tensors="pt", padding=True)
    input_ids = encoded['input_ids'].to(device)
    # With left padding, generate() needs the mask to ignore the pad positions.
    attention_mask = encoded.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **generation_kwargs)
    # All prompts share the padded length, so the continuation starts at the same column.
    prompt_len = input_ids.shape[1]
    generated_outputs = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
    logger.debug(f'device: {device}, outputs size:{len(generated_outputs)}, top3: {generated_outputs[:3]}')
    return generated_outputs


def _merge_gathered(gathered):
    """Flatten per-rank ``(index, input, output)`` lists into ``(input, output)`` pairs in index order."""
    merged = sorted((item for rank_items in gathered for item in rank_items), key=lambda item: item[0])
    return [(text, response) for _, text, response in merged]


def main():
    """Shard prompts across torchrun workers, generate responses, and write them in input order on rank 0."""
    args = _parse_args()
    logger.info(args)
    # Check for a GPU before any CUDA/NCCL call, so this error is actually the one raised.
    if not torch.cuda.is_available():
        raise ValueError("No GPU available, this script is only for GPU inference.")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    # LOCAL_RANK is only unique within a node; sharding and the single writer must use the global rank.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = rank == 0
    logger.info(f"rank: {rank}, local_rank: {local_rank}, world_size: {world_size}")

    model, tokenizer = _load_model_and_tokenizer(args, local_rank)
    logger.info(tokenizer)

    examples = _load_examples(args.data_file)
    logger.info(f"first 10 examples: {examples[:10]}")
    prompt_template = get_conv_template(args.template_name)

    do_sample = args.temperature > 0.0
    generation_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=args.repetition_penalty,
    )
    if do_sample:
        # temperature is only meaningful when sampling.
        generation_kwargs['temperature'] = args.temperature

    if is_main and os.path.exists(args.output_file):
        os.remove(args.output_file)

    # Results are gathered and appended to disk every `write_batch_size` prompts.
    write_batch_size = args.batch_size * world_size * 10
    chunks = [examples[i: i + write_batch_size] for i in range(0, len(examples), write_batch_size)]
    count = 0
    for chunk in tqdm(chunks, desc="Generating outputs", disable=not is_main):
        # Round-robin shard without padding (DistributedSampler would repeat samples to even
        # out ranks, duplicating rows). Indices travel with the results so rank 0 can
        # restore the original input order.
        local_indices = list(range(rank, len(chunk), world_size))
        local_results = []
        for start in range(0, len(local_indices), args.batch_size):
            batch_indices = local_indices[start: start + args.batch_size]
            texts = [chunk[i] for i in batch_indices]
            responses = _generate_batch(
                model, tokenizer, prompt_template, texts, args.system_prompt, local_rank, generation_kwargs
            )
            local_results.extend(zip(batch_indices, texts, responses))
        gathered = [None] * world_size
        dist.all_gather_object(gathered, local_results)
        # Write responses only on the main process
        if is_main:
            ordered = _merge_gathered(gathered)
            logger.debug(f"all_responses size:{len(ordered)}, top5: {[r for _, r in ordered[:5]]}")
            with open(args.output_file, 'a', encoding='utf-8') as f:
                for text, response in ordered:
                    json.dump({"Input": text, "Output": response}, f, ensure_ascii=False)
                    f.write('\n')
            count += len(ordered)
    if is_main:
        logger.info(f'save to {args.output_file}, total count: {count}')
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    # Script entry point; torchrun starts one copy of this script per worker.
    main()