forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predictor.py
1007 lines (870 loc) · 43.1 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import os
import sys
import time
from abc import abstractmethod
from dataclasses import dataclass, field
from threading import Thread
from typing import List, Optional
import numpy as np
import paddle
import paddle.distributed.fleet.base.topology as tp
from paddle.distributed import fleet
from utils import (
dybatch_preprocess,
get_alibi_slopes,
get_infer_model_path,
get_prefix_tuning_params,
init_chat_template,
load_real_time_tokens,
)
from paddlenlp.generation import GenerationConfig, TextIteratorStreamer
from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
from paddlenlp.taskflow.utils import static_mode_guard
from paddlenlp.trainer import PdArgumentParser
from paddlenlp.transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
ChatGLMv2Tokenizer,
LlamaTokenizer,
PretrainedModel,
PretrainedTokenizer,
)
from paddlenlp.utils.import_utils import import_module, is_paddlenlp_ops_available
from paddlenlp.utils.log import logger
@dataclass
class PredictorArgument:
    """Options controlling model loading, decoding and the predictor backend.

    Each field's ``metadata["help"]`` carries its user-facing description.
    """

    model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
    model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"})
    src_length: int = field(default=1024, metadata={"help": "The max length of source text."})
    max_length: int = field(default=2048, metadata={"help": "the max length for decoding."})
    top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"})
    top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"})
    temperature: float = field(default=0.95, metadata={"help": "temperature parameter for generation"})
    repetition_penalty: float = field(default=1.0, metadata={"help": "repetition penalty parameter for generation"})
    device: str = field(default="gpu", metadata={"help": "Device"})
    dtype: str = field(default=None, metadata={"help": "Model dtype"})
    lora_path: str = field(default=None, metadata={"help": "The directory of LoRA parameters. Default to None"})
    export_precache: bool = field(default=False, metadata={"help": "whether use prefix weight to do infer"})
    prefix_path: str = field(
        default=None, metadata={"help": "The directory of Prefix Tuning parameters. Default to None"}
    )
    decode_strategy: str = field(
        default="sampling",
        metadata={
            "help": "the decoding strategy of generation, which should be one of ['sampling', 'greedy_search', 'beam_search']. Default to sampling"
        },
    )
    mode: str = field(
        default="dynamic", metadata={"help": "the type of predictor, it should be one of [dynamic, static]"}
    )
    inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"})
    quant_type: str = field(
        default=None,
        metadata={"help": "Quantization type. Supported values: a8w8, weight_only_int4, weight_only_int8"},
    )
    batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
    benchmark: bool = field(
        default=False,
        metadata={
            "help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
        },
    )
    chat_template: str = field(
        default=None,
        metadata={
            # Adjacent literals are concatenated, so each fragment ends with a space.
            "help": "the path of `chat_template.json` file to handle multi-rounds conversation. "
            "If is None(do not set --chat_template argument), it will use the default `chat_template.json`; "
            "If is equal with `model_name_or_path`, it will use the default loading; "
            "If is directory, it will find the `chat_template.json` under the directory; If is file, it will load it. "
            "If is none string, it will not use chat_template.json."
        },
    )

    @property
    def total_max_length(self):
        """Total sequence budget: source length plus the maximum number of generated tokens."""
        return self.src_length + self.max_length
@dataclass
class ModelArgument:
    """Options selecting the model family and the input/output files."""

    model_type: str = field(
        default=None,
        metadata={
            "help": "the type of the model, which can be one of ['gpt-3', 'ernie-3.5-se', 'llama-img2txt', 'opt-img2txt']"
        },
    )
    data_file: str = field(default=None, metadata={"help": "data file directory"})
    output_file: str = field(default="output.json", metadata={"help": "predict result file directory"})
def batchfy_text(texts, batch_size):
    """Split ``texts`` into consecutive chunks of at most ``batch_size`` items.

    Args:
        texts: a sliceable sequence (e.g. a list of prompts).
        batch_size: maximum number of items per chunk; must be positive.

    Returns:
        list: the chunks in order; the last one may be shorter. Empty input
        yields an empty list.

    Raises:
        ValueError: if ``batch_size`` is not positive (such a value would
            otherwise never advance through ``texts``).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    # Slicing past the end is clipped, so the final chunk is simply shorter.
    return [texts[start : start + batch_size] for start in range(0, len(texts), batch_size)]
def init_dist_env():
    """Return ``(tensor_parallel_rank, tensor_parallel_degree)`` for this process.

    The degree is the distributed world size. With a single process the global
    rank is returned unchanged. Otherwise the fleet hybrid-parallel group is
    created on first use (dp/pp/sharding degrees of 1, mp degree = world size),
    and the rank is that group's model-parallel rank.
    """
    world_size = paddle.distributed.get_world_size()
    global_rank = paddle.distributed.get_rank()
    if world_size <= 1:
        return global_rank, world_size

    # refer to: https://github.com/PaddlePaddle/Paddle/blob/4abea956ee852ce52791a1e08fa92ed4d3be150d/python/paddle/distributed/fleet/fleet.py#L298C23-L298C45
    group = tp._HYBRID_PARALLEL_GROUP
    if group is None:
        # No hybrid group yet: initialise fleet with pure model parallelism.
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": world_size,
            "pp_degree": 1,
            "sharding_degree": 1,
        }
        fleet.init(is_collective=True, strategy=strategy)
        group = fleet.get_hybrid_communicate_group()
    return group.get_model_parallel_rank(), world_size
def get_eos_token_id(
    tokenizer: PretrainedTokenizer, generation_config: Optional[GenerationConfig] = None
) -> int | List[int]:
    """Get the eos_token_id used to stop generation.

    ``generation_config.eos_token_id`` wins when it is set; otherwise the
    tokenizer's ``eos_token_id`` is returned.

    Args:
        tokenizer: tokenizer supplying the fallback ``eos_token_id``.
        generation_config: optional generation config that may override it.

    Returns:
        int | List[int]: eos_token_id to stop the generation
    """
    if generation_config is not None and generation_config.eos_token_id is not None:
        return generation_config.eos_token_id
    return tokenizer.eos_token_id
class BasePredictor:
    """Common tokenize -> infer -> decode pipeline shared by the predictors.

    Subclasses implement ``_infer`` and may override ``_preprocess`` /
    ``_postprocess``.
    """

    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None):
        """Load the model config, tokenizer and optional generation config.

        Args:
            config: predictor options; ``config.model_name_or_path`` is used to
                load the model config, tokenizer and generation config.
            tokenizer: a pre-built tokenizer; when ``None`` one is loaded with
                ``padding_side="left"``.
        """
        self.model_config = AutoConfig.from_pretrained(config.model_name_or_path)
        self.config: PredictorArgument = config
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path, padding_side="left")
        self.tokenizer = tokenizer
        self.return_tensors = "pd"

        # Initialise the distributed environment once and mirror it onto the model config.
        self.tensor_parallel_rank, self.tensor_parallel_degree = init_dist_env()
        self.model_config.tensor_parallel_rank = self.tensor_parallel_rank
        self.model_config.tensor_parallel_degree = self.tensor_parallel_degree

        try:
            self.generation_config = GenerationConfig.from_pretrained(config.model_name_or_path)
        except Exception:
            # A missing generation config is non-fatal; eos ids then come from the tokenizer.
            logger.warning(
                "Can't find generation config, so it will not use generation_config field in the model config"
            )
            self.generation_config = None

    def _preprocess(self, source):
        """Tokenize ``source`` (a string or list of strings) into model inputs.

        When the tokenizer has a chat template, each sentence is rendered through
        it first. Inputs are left-truncated to ``config.src_length`` and padded.
        """
        if self.tokenizer.chat_template is not None:
            source = [source] if isinstance(source, str) else source
            source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]

        tokenized_source = self.tokenizer(
            source,
            max_length=self.config.src_length,
            truncation=True,
            truncation_side="left",
            return_tensors=self.return_tensors,
            padding=True,
            # when use chat_template, it should not add special tokens
            # chatglm2 prefix-tokens can not be tokenized into ids
            add_special_tokens=self.tokenizer.chat_template is None or isinstance(self.tokenizer, ChatGLMv2Tokenizer),
        )
        return tokenized_source

    @abstractmethod
    def _infer(self, inputs):
        """Run the model on tokenized ``inputs``; implemented by subclasses."""
        raise NotImplementedError

    def _postprocess(self, predictions):
        """Decode generated token ids into strings, dropping special tokens."""
        decoded_predictions = self.tokenizer.batch_decode(
            predictions, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return decoded_predictions

    def predict(self, input_texts: str | list[str]):
        """Tokenize, run inference and decode ``input_texts``."""
        tokenized_source = self._preprocess(input_texts)
        predictions = self._infer(tokenized_source)
        decoded_predictions = self._postprocess(predictions)
        return decoded_predictions
class DygraphPredictor(BasePredictor):
    """Predictor that runs a dygraph ``PretrainedModel`` through ``model.generate``.

    Optionally wraps the model with LoRA weights (``config.lora_path``) and/or
    prefix-tuning weights (``config.prefix_path``).
    """

    def __init__(
        self, config: PredictorArgument, model: PretrainedModel = None, tokenizer: PretrainedTokenizer = None
    ):
        """Build (or reuse) the model and apply LoRA / prefix-tuning wrappers.

        Args:
            config: predictor options.
            model: an already-loaded model; when ``None`` it is loaded from
                ``config.model_name_or_path`` with the resolved dtype.
            tokenizer: optional pre-built tokenizer.

        Raises:
            ValueError: if no dtype can be resolved (no LoRA/prefix path and
                ``config.dtype`` is None).
        """
        super().__init__(config, tokenizer)
        self.model = model

        # dtype precedence: LoRA config > prefix config > explicit config.dtype
        if config.lora_path is not None:
            lora_config = LoRAConfig.from_pretrained(config.lora_path)
            dtype = lora_config.dtype
            lora_config.merge_weights = True
        elif config.prefix_path is not None:
            prefix_config = PrefixConfig.from_pretrained(config.prefix_path)
            dtype = prefix_config.dtype
        elif config.dtype is not None:
            dtype = config.dtype
        else:
            raise ValueError("Please specific the model dtype.")

        if self.model is None:
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_name_or_path,
                dtype=dtype,
                tensor_parallel_degree=self.tensor_parallel_degree,
                tensor_parallel_rank=self.tensor_parallel_rank,
            )

        if config.lora_path is not None:
            self.model = LoRAModel.from_pretrained(
                model=self.model, lora_path=config.lora_path, lora_config=lora_config
            )
        if config.prefix_path is not None:
            prefix_tuning_params = get_prefix_tuning_params(self.model)
            self.model = PrefixModelForCausalLM.from_pretrained(
                model=self.model,
                prefix_path=config.prefix_path,
                postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
            )
        self.model.eval()

    def _generation_kwargs(self, decode_strategy: str) -> dict:
        """Return the ``model.generate`` keyword arguments shared by ``_infer`` and ``stream_predict``."""
        return dict(
            max_new_tokens=self.config.max_length,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=get_eos_token_id(self.tokenizer, self.generation_config),
            pad_token_id=self.tokenizer.pad_token_id,
            decode_strategy=decode_strategy,
            temperature=self.config.temperature,
            top_k=self.config.top_k,
            top_p=self.config.top_p,
            repetition_penalty=self.config.repetition_penalty,
        )

    @paddle.no_grad()
    def _infer(self, inputs: dict[str, paddle.Tensor]):
        """Generate from tokenized ``inputs`` with the configured decode strategy."""
        result = self.model.generate(**inputs, **self._generation_kwargs(self.config.decode_strategy))
        # Only the first element of the generate output is kept; _postprocess decodes it as token ids.
        return result[0]

    def stream_predict(self, inputs: str | list[str]):
        """Start generation for raw text ``inputs`` in a background thread.

        Args:
            inputs: text (or list of texts); it is tokenized via ``_preprocess``.

        Returns:
            TextIteratorStreamer: the streamer handed to ``model.generate``;
            iterate it to consume the generated text.
        """
        text_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
        input_features = self._preprocess(inputs)

        # With top_k == 1 and top_p == 1.0 sampling reduces to greedy decoding.
        if self.config.top_k == 1 and self.config.top_p == 1.0:
            decode_strategy = "greedy_search"
        else:
            decode_strategy = self.config.decode_strategy

        generation_kwargs = dict(
            **input_features,
            streamer=text_streamer,
            **self._generation_kwargs(decode_strategy),
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        return text_streamer
class StaticGraphPredictor(BasePredictor):
    """Predictor backed by an exported static graph (``<model_prefix>.pdmodel`` / ``.pdiparams``)."""

    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None):
        """Create a ``paddle.inference`` predictor for the exported model.

        Args:
            config: predictor options; ``model_name_or_path`` is the directory
                holding the exported files and ``device`` selects gpu/cpu.
            tokenizer: optional pre-built tokenizer.
        """
        super().__init__(config, tokenizer)

        params_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".pdiparams")
        model_path = os.path.join(self.config.model_name_or_path, self.config.model_prefix + ".pdmodel")
        inference_config = paddle.inference.Config(model_path, params_path)

        if self.config.device == "gpu":
            # set GPU configs accordingly; use the GPU assigned by the launcher
            # (FLAGS_selected_gpus, as StaticInferencePredictor does), defaulting to 0.
            device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
            inference_config.enable_use_gpu(100, device_id)
        elif self.config.device == "cpu":
            # set CPU configs accordingly,
            # such as enable_mkldnn, set_cpu_math_library_num_threads
            inference_config.disable_gpu()
        inference_config.disable_glog_info()
        inference_config.enable_new_executor()

        with static_mode_guard():
            self.predictor = paddle.inference.create_predictor(inference_config)

        # The static predictor consumes numpy arrays rather than paddle tensors.
        self.return_tensors = "np"

    def _preprocess(self, input_text: str | list[str]):
        """Tokenize ``input_text`` and attach the generation hyper-parameters as numpy scalars."""
        inputs = super()._preprocess(input_text)
        inputs["max_new_tokens"] = np.array(self.config.max_length, dtype="int64")
        inputs["top_p"] = np.array(self.config.top_p, dtype="float32")
        inputs["temperature"] = np.array(self.config.temperature, dtype="float32")
        inputs["top_k"] = np.array(self.config.top_k, dtype="int64")
        inputs["repetition_penalty"] = np.array(self.config.repetition_penalty, dtype="float32")
        return inputs

    def _infer(self, inputs: dict[str, np.ndarray]):
        """Feed every graph input by name, run, and return the first output as nested lists."""
        for name in self.predictor.get_input_names():
            self.predictor.get_input_handle(name).copy_from_cpu(inputs[name])

        self.predictor.run()
        output_names = self.predictor.get_output_names()
        output_handle = self.predictor.get_output_handle(output_names[0])
        results = output_handle.copy_to_cpu()
        # the first result is decoding_ids
        decoded_ids = results.tolist()
        return decoded_ids
class InferencePredictorMixin:
    """Shared buffers and pre/post-processing for the fused "inference model" predictors.

    The host class must set ``self.cache_kvs_shape`` and run
    ``BasePredictor.__init__`` (which provides ``self.model_config``,
    ``self.config`` and ``self.tokenizer``) before calling this ``__init__``.
    """

    def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
        """Allocate KV caches, masks and (optionally) prefix caches.

        Args:
            config: predictor options (batch size, lengths, dtype, precache flags).
            tokenizer: unused here; kept for a uniform constructor signature.
        """
        self.architectures = self.model_config.architectures[0].lower()
        # Fall back to the model config's dtype (not the config object itself) when none is given.
        self.dtype = config.dtype or self.model_config.dtype
        self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape]
        self.num_layers, self.num_attention_heads, self.head_dim = (
            len(self.cache_kvs),
            self.cache_kvs[0].shape[-3],
            self.cache_kvs[0].shape[-1],
        )
        self.pre_ids = paddle.full([config.batch_size, config.total_max_length], -1, dtype="int64")

        if "chatglm" in self.architectures:
            self.attention_mask = paddle.ones(
                shape=(config.batch_size, 1, config.total_max_length, config.total_max_length),
                dtype=self.dtype,
            )
            self.tgt_pos = paddle.ones(
                shape=[config.batch_size, 2, 1],
                dtype="int64",
            )
        else:
            self.attention_mask = paddle.zeros(
                shape=(config.batch_size, 1, config.total_max_length, config.total_max_length),
                dtype=self.dtype,
            )

        self.tgt_generation_mask = paddle.zeros(
            shape=[config.batch_size, 1, 1, config.total_max_length],
            dtype=self.dtype,
        )
        self.arange_tensor_encoder = paddle.zeros(
            shape=(config.batch_size, 1, config.total_max_length), dtype=self.dtype
        )

        if config.export_precache:
            if config.prefix_path:
                prefix_cache = (
                    paddle.to_tensor(np.load(f"{config.prefix_path}/pre_caches.npy")).astype(self.dtype).unsqueeze(2)
                )
                prefix_cache = paddle.expand(
                    prefix_cache,
                    [
                        self.num_layers,
                        2,
                        config.batch_size,
                        self.num_attention_heads,
                        prefix_cache.shape[-2],
                        self.head_dim,
                    ],
                )
            else:
                # No trained prefix: use a zero placeholder cache of length 128.
                prefix_cache = paddle.zeros(
                    [self.num_layers, 2, config.batch_size, self.num_attention_heads, 128, self.head_dim],
                    dtype=self.dtype,
                )
            # One cache tensor per layer.
            self.pre_caches = [item.squeeze_(0) for item in paddle.split(prefix_cache, self.num_layers, axis=0)]

        try:
            self.generation_config = GenerationConfig.from_pretrained(config.model_name_or_path)
        except Exception:
            logger.warning(
                "Can't find generation config, so it will not use generation_config field in the model config"
            )
            self.generation_config = None

    def _postprocess(self, predictions):
        """Decode generated tokens on rank 0; other ranks return None.

        ``predictions`` is ignored: tokens are read back via ``load_real_time_tokens()``.
        """
        if paddle.distributed.get_rank() == 0:
            tokens: np.ndarray = load_real_time_tokens()
            decoded_predictions = self.tokenizer.batch_decode(
                tokens.tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            return decoded_predictions
        else:
            return None

    def _preprocess(self, source):
        """Tokenize ``source`` and fill the preallocated attention / generation masks.

        Returns:
            dict: the ``dybatch_preprocess`` outputs plus ``pre_ids``,
            ``attention_mask``, ``tgt_generation_mask`` and, when a prefix cache
            is in use, the prefix caches (a list under ``pre_caches`` in dynamic
            mode, numpy arrays under ``pre_caches_{i}`` otherwise).
        """
        self.attention_mask[:] = 0
        self.tgt_generation_mask[:] = 0
        pre_caches_length = 0 if not self.config.export_precache else self.pre_caches[0].shape[-2]

        if self.tokenizer.chat_template is not None:
            source = [source] if isinstance(source, str) else source
            source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source]

        inputs = dybatch_preprocess(
            self.tokenizer,
            source,
            self.config.src_length,
            self.config.max_length,
            self.architectures,
            top_p=self.config.top_p,
            temperature=self.config.temperature,
            eos_token_id=get_eos_token_id(self.tokenizer, self.generation_config),
            benchmark=self.config.benchmark,
            pre_caches_length=pre_caches_length,
        )

        # Masks are built in ``self.dtype`` (the dtype they were allocated with),
        # so they stay consistent when ``config.dtype`` is None or bfloat16.
        if "chatglmforcausallm" == self.architectures.lower():
            if inputs["input_ids"].shape[0] < self.config.batch_size:
                self.tgt_pos = self.tgt_pos[: inputs["input_ids"].shape[0]]
            for i in range(inputs["input_ids"].shape[0]):
                length = inputs["seq_len_encoder"][i][0]
                self.attention_mask[i, 0, :length, :length] = 1
                self.attention_mask[i, 0, : length - 1, length - 1] = 0
                self.tgt_pos[i, 0, 0] = paddle.to_tensor([length], dtype="int64")

                if pre_caches_length > 0:
                    prefix_attention_mask = paddle.ones(
                        [1, length, pre_caches_length], dtype=self.attention_mask.dtype
                    )
                    post_attention_mask = paddle.ones(
                        shape=(length, length), dtype=self.attention_mask.dtype
                    ).unsqueeze_(axis=0)
                    post_attention_mask[0, : length - 1, length - 1] = 0
                    self.attention_mask[i, 0, :length, : length + pre_caches_length] = paddle.concat(
                        [prefix_attention_mask, post_attention_mask], axis=2
                    )

                if self.config.prefix_path is None:
                    self.tgt_generation_mask[i, 0, 0, pre_caches_length : length + pre_caches_length] = paddle.ones(
                        shape=[1, length], dtype=self.dtype
                    )
                else:
                    self.tgt_generation_mask[i, 0, 0, : length + pre_caches_length] = paddle.ones(
                        shape=[1, length + pre_caches_length], dtype=self.dtype
                    )

            inputs["tgt_pos"] = self.tgt_pos
        elif "bloom" in self.architectures:
            for i in range(inputs["input_ids"].shape[0]):
                length = inputs["seq_len_encoder"][i][0]
                self.attention_mask[i, :, :length, :length] = paddle.tril(
                    paddle.ones(shape=(length, length), dtype=self.dtype)
                )
                if pre_caches_length > 0:
                    if self.config.prefix_path is None:
                        prefix_attention_mask = paddle.zeros([1, length, pre_caches_length], dtype=self.dtype)
                    else:
                        prefix_attention_mask = paddle.ones([1, length, pre_caches_length], dtype=self.dtype)
                    post_attention_mask = paddle.tril(
                        paddle.ones(shape=(length, length), dtype=self.dtype)
                    ).unsqueeze_(axis=0)
                    self.attention_mask[i, :, :length, : length + pre_caches_length] = paddle.concat(
                        [prefix_attention_mask, post_attention_mask], axis=2
                    )

                self.arange_tensor_encoder[i, :, : length + pre_caches_length] = paddle.arange(
                    length + pre_caches_length
                ).astype(self.dtype)
                self.tgt_generation_mask[i, :, 0, : length + pre_caches_length] = paddle.ones(
                    shape=[1, length + pre_caches_length], dtype=self.dtype
                )

            inputs["tgt_pos"] = inputs["tgt_pos"] + pre_caches_length

            # alibi encoder
            alibi_slopes = get_alibi_slopes(self.model_config.n_head)
            inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32")
            alibi = alibi_slopes[..., None] * self.arange_tensor_encoder
            alibi = alibi[:, :, None, :]

            if self.model_config.tensor_parallel_degree > 1:
                # Keep only the heads owned by this tensor-parallel rank.
                block_size = self.model_config.n_head // self.model_config.tensor_parallel_degree
                head_start = self.model_config.tensor_parallel_rank * block_size
                head_end = (self.model_config.tensor_parallel_rank + 1) * block_size
                alibi = alibi[:, head_start:head_end]
                # The last axis comes from arange_tensor_encoder, i.e. total_max_length
                # (assuming get_alibi_slopes returns one slope per head).
                alibi = alibi.reshape([self.config.batch_size, block_size, 1, self.config.total_max_length])
                inputs["position_ids"] = inputs["position_ids"][head_start:head_end]

            alibi_encoder = alibi.expand(
                [
                    self.config.batch_size,
                    self.model_config.n_head // self.model_config.tensor_parallel_degree,
                    self.config.total_max_length,
                    self.config.total_max_length,
                ]
            )
            alibi_decoder = alibi.expand(
                [
                    self.config.batch_size,
                    self.model_config.n_head // self.model_config.tensor_parallel_degree,
                    1,
                    self.config.total_max_length,
                ]
            )
            # Convert the 0/1 masks into additive masks (masked positions -> dtype min) plus alibi bias.
            self.attention_mask = (
                alibi_encoder + (1 - self.attention_mask) * paddle.finfo(self.attention_mask.dtype).min
            )
            self.tgt_generation_mask = (
                alibi_decoder + (1 - self.tgt_generation_mask) * paddle.finfo(self.tgt_generation_mask.dtype).min
            )
        else:
            for i in range(inputs["input_ids"].shape[0]):
                length = inputs["seq_len_encoder"][i][0]
                self.attention_mask[i, 0, :length, :length] = paddle.tril(
                    paddle.ones(shape=(length, length), dtype=self.dtype)
                )

                if pre_caches_length > 0:
                    if self.config.prefix_path is None:
                        prefix_attention_mask = paddle.zeros(
                            [1, length, pre_caches_length], dtype=self.attention_mask.dtype
                        )
                    else:
                        prefix_attention_mask = paddle.ones(
                            [1, length, pre_caches_length], dtype=self.attention_mask.dtype
                        )
                    post_attention_mask = paddle.tril(
                        paddle.ones(shape=(length, length), dtype=self.attention_mask.dtype)
                    ).unsqueeze_(axis=0)
                    self.attention_mask[i, 0, :length, : length + pre_caches_length] = paddle.concat(
                        [prefix_attention_mask, post_attention_mask], axis=2
                    )

                if self.config.prefix_path is None:
                    self.tgt_generation_mask[i, 0, 0, pre_caches_length : length + pre_caches_length] = paddle.ones(
                        shape=[1, length], dtype=self.dtype
                    )
                else:
                    self.tgt_generation_mask[i, 0, 0, : length + pre_caches_length] = paddle.ones(
                        shape=[1, length + pre_caches_length], dtype=self.dtype
                    )

        inputs["pre_ids"] = self.pre_ids
        inputs["attention_mask"] = self.attention_mask
        inputs["tgt_generation_mask"] = self.tgt_generation_mask

        if pre_caches_length > 0:
            if self.config.mode == "dynamic":
                inputs["pre_caches"] = self.pre_caches
            else:
                for i in range(len(self.pre_caches)):
                    inputs["pre_caches_{}".format(i)] = self.pre_caches[i].numpy()

        return inputs
class StaticInferencePredictor(InferencePredictorMixin, BasePredictor):
    """Runs an exported fused inference model through ``paddle.inference``."""

    def __init__(
        self,
        config: PredictorArgument,
        cache_kvs_shape: list[list[int]],
        tokenizer: PretrainedTokenizer = None,
    ):
        # The mixin allocates the KV caches from this shape, so it must be set first.
        self.cache_kvs_shape = cache_kvs_shape
        BasePredictor.__init__(self, config, tokenizer)
        InferencePredictorMixin.__init__(self, config, tokenizer)
        self.predictor = self._create_predictor(config)

    def _create_predictor(self, predictor_args: PredictorArgument):
        """Register the custom ops and build a GPU (optionally distributed) inference predictor.

        Raises:
            ValueError: if the ``paddlenlp_ops`` package is not available.
        """
        if not is_paddlenlp_ops_available():
            raise ValueError(
                "you should install the paddlenlp ops to run inference predictor, "
                "https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
            )

        # register the custom ops
        for op_module in (
            "paddlenlp_ops.encode_rotary_qk",
            "paddlenlp_ops.get_padding_offset",
            "paddlenlp_ops.qkv_transpose_split",
            "paddlenlp_ops.rebuild_padding",
            "paddlenlp_ops.transpose_remove_padding",
            "paddlenlp_ops.write_cache_kv",
        ):
            import_module(op_module)

        model_path_prefix = get_infer_model_path(predictor_args.model_name_or_path, predictor_args.model_prefix)
        inference_config = paddle.inference.Config(model_path_prefix + ".pdmodel", model_path_prefix + ".pdiparams")
        inference_config.switch_ir_optim(True)
        # remove `gpu_cpu_map_matmul_v2_to_matmul_pass` to avoid mapping matmul_v2 -> matmul op
        if predictor_args.dtype == "bfloat16":
            inference_config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")

        gpu_id = int(os.environ.get("FLAGS_selected_gpus", 0))
        inference_config.enable_use_gpu(100, gpu_id)
        inference_config.enable_new_executor()

        if self.tensor_parallel_degree > 1:
            endpoints = fleet.worker_endpoints()
            my_endpoint = endpoints[self.tensor_parallel_rank]
            dist_config = inference_config.dist_config()
            dist_config.set_ranks(self.tensor_parallel_degree, self.tensor_parallel_rank)
            dist_config.set_endpoints(endpoints, my_endpoint)
            dist_config.enable_dist_model(True)
            dist_config.set_comm_init_config(os.path.join(predictor_args.model_name_or_path, "rank_mapping.csv"))
            inference_config.set_dist_config(dist_config)

        return paddle.inference.create_predictor(inference_config)

    @paddle.no_grad()
    def _infer(self, inputs):
        """Bind inputs, KV caches and ``pre_ids`` to the predictor, then run it.

        Results are not returned here; ``_postprocess`` reads them back itself.
        """
        for name, value in inputs.items():
            handle = self.predictor.get_input_handle(name)
            if "mask" in name or "position" in name:
                # Masks / positions are bound via share_external_data instead of a host copy.
                handle.share_external_data(value)
                continue
            if paddle.is_tensor(value):
                value = value.numpy()
            handle.copy_from_cpu(value)

        for layer_idx, cache in enumerate(self.cache_kvs):
            self.predictor.get_input_handle("cache_kvs_" + str(layer_idx)).share_external_data(cache)

        self.predictor.get_input_handle("pre_ids").share_external_data(self.pre_ids)

        self.predictor.run()
class DygraphInferencePredictor(InferencePredictorMixin, BasePredictor):
    """Runs a dygraph fused inference model via ``model.generate`` with preallocated KV caches."""

    def __init__(
        self,
        config: PredictorArgument,
        model: PretrainedModel = None,
        tokenizer: PretrainedTokenizer = None,
    ):
        """Derive the KV-cache shape from ``model`` and initialise the shared buffers.

        Raises:
            ValueError: if ``model`` is None (the cache shape comes from it).
        """
        if model is None:
            raise ValueError("`model` is required to build a DygraphInferencePredictor")
        # The mixin allocates the KV caches from this shape, so it must be set first.
        self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size, config.total_max_length)
        BasePredictor.__init__(self, config, tokenizer)
        InferencePredictorMixin.__init__(self, config, tokenizer)
        self.model = model

    @paddle.no_grad()
    def _infer(self, inputs: dict[str, paddle.Tensor]):
        """Convert non-tensor inputs to tensors (in place) and call ``model.generate``.

        List values are converted element-wise; elements that already are
        tensors are passed through unchanged. Returns None; ``_postprocess``
        fetches the generated tokens itself.
        """
        for key, value in inputs.items():
            if paddle.is_tensor(value):
                continue
            if isinstance(value, list):
                inputs[key] = [item if paddle.is_tensor(item) else paddle.to_tensor(item) for item in value]
            else:
                inputs[key] = paddle.to_tensor(value)

        inputs["cache_kvs"] = self.cache_kvs
        self.model.generate(
            **inputs,
        )
        return None
def create_predictor(
predictor_args: PredictorArgument,
model_args: ModelArgument,
tensor_parallel_degree: int = 1,
tensor_parallel_rank: int = 0,
):
tokenizer = AutoTokenizer.from_pretrained(predictor_args.model_name_or_path)
# init chat_template for tokenizer
init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template)
# TODO(wj-Mcat): fix llama tokenzier pad_token bug
if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.unk_token
# update config parameter for inference predictor
if predictor_args.decode_strategy == "greedy_search":
predictor_args.top_p = 0.0
predictor_args.temperature = 1.0
tensor_parallel_rank, tensor_parallel_degree = init_dist_env()
if not predictor_args.inference_model:
if predictor_args.mode == "dynamic":
if model_args.model_type == "gpt-3":
sys.path.append("./gpt-3")
from modeling import GPTForCausalLM
model = GPTForCausalLM.from_pretrained(
predictor_args.model_name_or_path,
dtype=predictor_args.dtype,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
)
elif model_args.model_type == "ernie-3.5-se":
sys.path.append("./ernie-3.5-se")
from modeling import Ernie35ForCausalLM
tensor_parallel_degree = paddle.distributed.get_world_size()
tensor_parallel_rank = paddle.distributed.get_rank()
model = Ernie35ForCausalLM.from_pretrained(
predictor_args.model_name_or_path,
dtype=predictor_args.dtype,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
)
else:
model = AutoModelForCausalLM.from_pretrained(
predictor_args.model_name_or_path,
dtype=predictor_args.dtype,
tensor_parallel_degree=tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
)
predictor = DygraphPredictor(predictor_args, model=model, tokenizer=tokenizer)
elif predictor_args.mode == "static":
predictor = StaticGraphPredictor(predictor_args, tokenizer=tokenizer)
else:
raise ValueError("the `mode` should be one of [dynamic, static]")
else:
if predictor_args.mode == "dynamic":
# TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank
config.weight_only_quant_bits = -1
config.quant_type = None
config.model_name_or_path = ""
if predictor_args.quant_type is not None and predictor_args.quant_type.startswith("weight_only_int"):
weight_only_quant_bits = int(predictor_args.quant_type[-1])
config.weight_only_quant_bits = weight_only_quant_bits
config.quant_type = predictor_args.quant_type
if config.quantization_config.quant_type is not None and "a8w8" in config.quantization_config.quant_type:
config.model_name_or_path = predictor_args.model_name_or_path
config.quant_type = config.quantization_config.quant_type
# Turn on GEMM int8 kernel tuning
paddle.base.core.enable_autotune()
paddle.base.core.update_autotune_status()
if "llama" in config.architectures[0].lower():
if model_args.model_type == "llama-img2txt":
# we use llama for img2txt.
from paddlenlp.experimental.transformers import (
LlamaForMiniGPT4InferenceModel as LlamaInferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel as LlamaInferenceModel,
)
model = LlamaInferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
elif "opt" in config.architectures[0].lower():
if model_args.model_type == "opt-img2txt":
# we use opt for img2txt.
from paddlenlp.experimental.transformers import (
OPTForBlip2InferenceModel as OPTInferenceModel,
)
else:
from paddlenlp.experimental.transformers import (
OPTForCausalLMInferenceModel as OPTInferenceModel,
)
model = OPTInferenceModel.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel as Model,
)
model = Model.from_pretrained(
predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
)
model.eval()
elif "chatglmforcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMForCausalLMInferenceModel,
)
model = ChatGLMForCausalLMInferenceModel.from_pretrained(
predictor_args.model_name_or_path,
config=config,
dtype=predictor_args.dtype,
)
model.eval()
elif "bloom" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
BloomForCausalLMInferenceModel,
)
model = BloomForCausalLMInferenceModel.from_pretrained(
predictor_args.model_name_or_path,
config=config,
dtype=predictor_args.dtype,
)
cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
model.eval()
elif "gpt" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
GPTForCausalLMInferenceModel,
)
model = GPTForCausalLMInferenceModel.from_pretrained(
predictor_args.model_name_or_path,
config=config,
dtype=predictor_args.dtype,
)
model.eval()
else:
raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
elif predictor_args.mode == "static":
config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
if "llama" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
LlamaForCausalLMInferenceModel,
)
cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmv2forcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMv2ForCausalLMInferenceModel,
)
cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "chatglmforcausallm" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
ChatGLMForCausalLMInferenceModel,
)
cache_kvs_shape = ChatGLMForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "bloom" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
BloomForCausalLMInferenceModel,
)
cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "gpt" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
GPTForCausalLMInferenceModel,
)
cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
else:
raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
else:
raise ValueError("the `mode` should be one of [dynamic, static]")
return predictor
def predict():
    """Run batched generation over a set of prompts and dump the results as JSON lines.

    Arguments are parsed from the command line into ``PredictorArgument`` and
    ``ModelArgument`` dataclasses. When more than one process is launched, fleet is
    initialised for pure model (tensor) parallelism with ``mp_degree`` equal to the
    world size.

    Inputs come from ``model_args.data_file`` (one JSON object per line with a
    ``"src"`` key and an optional ``"tgt"`` key). If no data file is given, two
    built-in demo prompts are used. Every rank runs ``predictor.predict`` on every
    batch; only the writing rank prints the results and writes them to
    ``model_args.output_file`` as ``{"src", "tgt", "output"}`` JSON lines. If
    ``predictor_args.benchmark`` is set, ``benchmark`` is run afterwards.
    """
    import contextlib

    parser = PdArgumentParser((PredictorArgument, ModelArgument))
    predictor_args, model_args = parser.parse_args_into_dataclasses()

    paddle.set_device(predictor_args.device)
    paddle.set_default_dtype(predictor_args.dtype)

    tensor_parallel_degree = paddle.distributed.get_world_size()
    if tensor_parallel_degree > 1:
        # All processes form a single model-parallel group; no data/pipeline/sharding parallelism.
        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": tensor_parallel_degree,
            "pp_degree": 1,
            "sharding_degree": 1,
        }
        fleet.init(is_collective=True, strategy=strategy)

    predictor = create_predictor(predictor_args, model_args)

    source_texts = []
    target_texts = []
    if model_args.data_file:
        with open(model_args.data_file, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
                if not line.strip():
                    continue
                example = json.loads(line)
                source_texts.append(example["src"])
                # A missing target gets the same empty placeholder as the built-in demo prompts.
                target_texts.append(example.get("tgt", ""))
    else:
        # Built-in demo prompts used when no data file is provided.
        source_texts = ["解释一下“温故而知新”", "你好,请问你是谁?"]
        target_texts = ["", ""]

    batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size)
    batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size)

    # Only the writing rank touches the output file; the other ranks must not create or
    # truncate it. All ranks still run predict() below, since the forward pass is
    # collective across the tensor-parallel group.
    is_writer = not predictor.tensor_parallel_rank > 0
    output_ctx = open(model_args.output_file, "w", encoding="utf-8") if is_writer else contextlib.nullcontext()
    with output_ctx as f:
        for batch_source_text, batch_target_text in zip(batch_source_texts, batch_target_texts):
            outputs = predictor.predict(batch_source_text)
            if not is_writer:
                continue
            for output, source, target in zip(outputs, batch_source_text, batch_target_text):
                print("***********Source**********")
                print(source)
                print("***********Target**********")
                print(target)
                print("***********Output**********")
                print(output)
                out = {"src": source, "tgt": target, "output": output}
                f.write(json.dumps(out, ensure_ascii=False) + "\n")

    if predictor_args.benchmark:
        benchmark(predictor, predictor_args, model_args)
def benchmark(predictor, predictor_args, model_args):
# Just construct a simple benchmark input. We pad input to the src_length.
test_texts = "hello world, how are you?"
benchmark_texts = [test_texts + "<pad>" * predictor_args.src_length for _ in range(predictor_args.batch_size)]
batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size)
print("***********Start Benchmark**********")
warmup_time = 10
test_time = 100
print("***********Start Warmup**********")
for _ in range(warmup_time):
for bs, batch_source_text in enumerate(batch_benchmark_texts):
outputs = predictor.predict(batch_source_text)
print("***********Start Speed Test**********")
start = time.perf_counter()
output_tokens = 0
for _ in range(test_time):
for bs, batch_source_text in enumerate(batch_benchmark_texts):
outputs = predictor.predict(batch_source_text)
output_tokens += sum([len(output) for output in outputs])
end = time.perf_counter()
print("Avg Elapse time is: ", (end - start) / test_time)
print("Output tokens is: ", output_tokens)
print(
"Input length is: {}, Output length is: {}, bs is: {}, IPS: {:.3f} tokens/s, QPS: {:.3f} requests/s. ".format(
predictor_args.src_length,
predictor_args.max_length,
predictor_args.batch_size,
(output_tokens / (end - start)),