-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmodeling.py
executable file
·1144 lines (1002 loc) · 48 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2021 Graphcore Ltd.
#
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file has been modified by Graphcore Ltd.
# It has been modified to run the application on IPU hardware.
# The original file was provided at
# https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/bert/modeling.py
import logging
import paddle
import paddle.nn as nn
import paddle.tensor as tensor
import paddle.nn.functional as F
from paddle.nn import TransformerEncoder, Linear, Layer, Embedding, LayerNorm, Tanh
from paddlenlp.transformers import PretrainedModel, register_base_model
# Configure the root logger at import time so that the IPU-placement
# messages logged by the layers in this module are printed at INFO level.
logging.basicConfig(
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    datefmt='%Y-%m-%d %H:%M:%S %a',
    level=logging.INFO,
)

# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'BertModel',
    'BertPretrainedModel',
    'BertForPretraining',
    'BertPretrainingCriterion',
    'BertPretrainingHeads',
    'BertForSequenceClassification',
    'BertForTokenClassification',
    'BertForQuestionAnswering',
    'BertPretrainingAccuracy',
]
def set_serialize_factor(serialize_factor):
    """Tag the most recently appended op of the default main program.

    Looks up the last op in the current block of
    ``paddle.static.default_main_program()`` and sets two attributes on it:
    ``serialize_factor`` (to the given value) and ``serialize_mode``
    (to ``'input_channels'``). It must therefore be called right after the
    op that should carry these attributes has been added to the program.

    Args:
        serialize_factor: Value stored in the op's ``serialize_factor``
            attribute (presumably a positive int — confirm with callers).
    """
    current = paddle.static.default_main_program().current_block()
    last_op = current.ops[-1]
    for attr_name, attr_value in (('serialize_factor', serialize_factor),
                                  ('serialize_mode', 'input_channels')):
        last_op._set_attr(attr_name, attr_value)
class BertEmbeddings(Layer):
    """
    Include embeddings from word, position and token_type embeddings.

    The three embedding lookups are summed, then passed through layer
    normalization and dropout.

    Args:
        vocab_size (int):
            Number of rows of the word-embedding table.
        hidden_size (int, optional):
            Width of every embedding table. Defaults to `768`.
        hidden_dropout_prob (float, optional):
            Dropout probability applied to the normalized sum. Defaults to `0.1`.
        max_position_embeddings (int, optional):
            Number of rows of the position-embedding table. Defaults to `512`.
        type_vocab_size (int, optional):
            Number of rows of the token-type-embedding table. Defaults to `16`.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size=768,
                 hidden_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings,
                                                hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        """
        Embed a batch of token ids.

        Args:
            input_ids (Tensor): Token indices into the word-embedding table.
            token_type_ids (Tensor, optional): Segment indices. Defaults to
                an int32 tensor of zeros shaped like `input_ids`.
            position_ids (Tensor, optional): Position indices. Defaults to
                an int32 tensor counting 0, 1, 2, ... along the last axis of
                `input_ids`; gradients are not propagated through it.

        Returns:
            Tensor: layer-normalized, dropped-out sum of the word, position
            and token-type embeddings.
        """
        if position_ids is None:
            # Build 0..seq_len-1 along the last axis. The cumulative sum is
            # done in float32 (as the IPU port requires), and the "minus one"
            # is also done in float32 so both operands share a dtype. The
            # result is cast to int32 once, instead of subtracting a float32
            # tensor from an int32 one and relying on Paddle's implicit cast.
            ones = paddle.ones_like(input_ids, dtype="float32")
            seq_length = paddle.cumsum(ones, axis=-1)
            position_ids = paddle.cast(seq_length - ones, dtype="int32")
            position_ids.stop_gradient = True
        if token_type_ids is None:
            token_type_ids = paddle.zeros_like(input_ids, dtype="int32")

        input_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = input_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class BertPooler(Layer):
    """
    Pool the result of BertEncoder.

    The hidden state of the first token is projected by a dense layer and,
    when ``pool_act == "tanh"``, passed through tanh.

    Args:
        hidden_size (int):
            Input and output width of the dense projection.
        pool_act (str, optional):
            ``"tanh"`` enables the tanh activation; any other value returns
            the projection without an activation. Defaults to `"tanh"`.
        num_ipus (int, optional):
            Number of IPUs; used to choose the IPU shard / pipeline stage of
            the pooling ops. Defaults to `1`.
    """

    def __init__(self, hidden_size, pool_act="tanh", num_ipus=1):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.pool_act = pool_act
        self.num_ipus = num_ipus

    def forward(self, hidden_states):
        last_ipu = self.num_ipus - 1
        # Slice + projection: last IPU, last regular pipeline stage.
        with paddle.fluid.ipu_shard(ipu_index=last_ipu, ipu_stage=last_ipu):
            # "Pool" by keeping only the hidden state of the first token.
            pooled = self.dense(hidden_states[:, 0])
        # Activation: placed on IPU 0 at stage `num_ipus`.
        with paddle.fluid.ipu_shard(ipu_index=0, ipu_stage=self.num_ipus):
            if self.pool_act == "tanh":
                pooled = self.activation(pooled)
        return pooled
class BertPretrainedModel(PretrainedModel):
    """
    An abstract class for pretrained BERT models. It provides BERT related
    `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`,
    `pretrained_init_configuration`, `base_model_prefix` for downloading and
    loading pretrained models. See `PretrainedModel` for more details.

    All pretrained weight files are fetched over HTTPS.
    """

    model_config_file = "model_config.json"

    # Hyper-parameters shared by every 12-layer ("base") checkpoint below.
    _base_config = {
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02,
        "pad_token_id": 0,
    }
    # 24-layer ("large") checkpoints differ from base only in these sizes.
    # Updating existing keys keeps their original position in the dict.
    _large_config = dict(
        _base_config,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        intermediate_size=4096)

    # Each entry is a fresh dict whose first key is "vocab_size", followed by
    # the shared keys in the order listed above.
    pretrained_init_configuration = {
        "bert-base-uncased": dict(vocab_size=30522, **_base_config),
        "bert-large-uncased": dict(vocab_size=30522, **_large_config),
        "bert-base-multilingual-uncased": dict(vocab_size=105879, **_base_config),
        "bert-base-cased": dict(vocab_size=28996, **_base_config),
        "bert-base-chinese": dict(vocab_size=21128, **_base_config),
        "bert-base-multilingual-cased": dict(vocab_size=119547, **_base_config),
        "bert-large-cased": dict(vocab_size=28996, **_large_config),
        "bert-wwm-chinese": dict(vocab_size=21128, **_base_config),
        "bert-wwm-ext-chinese": dict(vocab_size=21128, **_base_config),
        "macbert-base-chinese": dict(vocab_size=21128, **_base_config),
        "macbert-large-chinese": dict(vocab_size=21128, **_large_config),
        "simbert-base-chinese": dict(vocab_size=13685, **_base_config),
    }
    # The helpers were only needed to build the table; keep them off the class.
    del _base_config, _large_config

    resource_files_names = {"model_state": "model_state.pdparams"}
    # Every URL uses https so downloaded weights cannot be tampered with in
    # transit (several entries previously used plain http on the same host).
    pretrained_resource_files_map = {
        "model_state": {
            "bert-base-uncased":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-uncased.pdparams",
            "bert-large-uncased":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert-large-uncased.pdparams",
            "bert-base-multilingual-uncased":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-multilingual-uncased.pdparams",
            "bert-base-cased":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-cased.pdparams",
            "bert-base-chinese":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-chinese.pdparams",
            "bert-base-multilingual-cased":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-base-multilingual-cased.pdparams",
            "bert-large-cased":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-large-cased.pdparams",
            "bert-wwm-chinese":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-wwm-chinese.pdparams",
            "bert-wwm-ext-chinese":
            "https://paddlenlp.bj.bcebos.com/models/transformers/bert/bert-wwm-ext-chinese.pdparams",
            "macbert-base-chinese":
            "https://paddlenlp.bj.bcebos.com/models/transformers/macbert/macbert-base-chinese.pdparams",
            "macbert-large-chinese":
            "https://paddlenlp.bj.bcebos.com/models/transformers/macbert/macbert-large-chinese.pdparams",
            "simbert-base-chinese":
            "https://paddlenlp.bj.bcebos.com/models/transformers/simbert/simbert-base-chinese-v1.pdparams",
        }
    }
    base_model_prefix = "bert"

    def init_weights(self, layer):
        """Initialization hook applied to every sub-layer via ``self.apply``.

        * ``nn.Linear`` / ``nn.Embedding``: when the weight is a
          ``paddle.Tensor``, it is overwritten with samples from
          N(0, initializer_range). ``initializer_range`` is read from
          ``self`` if present, otherwise from ``self.bert.config``.
        * ``nn.LayerNorm``: its epsilon is set to ``1e-3``.
        """
        if isinstance(layer, (nn.Linear, nn.Embedding)):
            # In the dygraph mode, use the `set_value` to reset the parameter directly,
            # and reset the `state_dict` to update parameter in static mode.
            if isinstance(layer.weight, paddle.Tensor):
                std = (self.initializer_range
                       if hasattr(self, "initializer_range") else
                       self.bert.config["initializer_range"])
                layer.weight.set_value(
                    paddle.tensor.normal(
                        mean=0.0, std=std, shape=layer.weight.shape))
        elif isinstance(layer, nn.LayerNorm):
            # NOTE(review): 1e-3 is a fairly large epsilon; presumably chosen
            # deliberately for this IPU port's numerics — confirm.
            layer._epsilon = 1e-3
@register_base_model
class BertModel(BertPretrainedModel):
    """
    The bare BERT Model transformer outputting raw hidden-states without any specific head on top.

    This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
    Refer to the superclass documentation for the generic methods.

    This model is also a Paddle `paddle.nn.Layer <https://www.paddlepaddle.org.cn/documentation
    /docs/en/api/paddle/fluid/dygraph/layers/Layer_en.html>`__ subclass. Use it as a regular Paddle Layer
    and refer to the Paddle documentation for all matter related to general usage and behavior.

    Placement on IPUs (default path, ``output_hidden_states=False``): the
    embeddings run on IPU 0 / stage 0, encoder layer ``i`` runs on IPU and
    stage ``encoder_start_ipu + i // layer_per_ipu``, and the pooler places
    its own ops (see :class:`BertPooler`).

    Args:
        vocab_size (int):
            Vocabulary size of `inputs_ids` in `BertModel`. Also is the vocab size of token embedding matrix.
            Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `BertModel`.
        hidden_size (int, optional):
            Dimensionality of the embedding layer, encoder layer and pooler layer. Defaults to `768`.
        num_hidden_layers (int, optional):
            Number of hidden layers in the Transformer encoder. Defaults to `12`.
        num_attention_heads (int, optional):
            Number of attention heads for each attention layer in the Transformer encoder.
            Defaults to `12`.
        intermediate_size (int, optional):
            Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors
            to ff layers are firstly projected from `hidden_size` to `intermediate_size`,
            and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`.
            Defaults to `3072`.
        hidden_act (str, optional):
            The non-linear activation function in the feed-forward layer.
            ``"gelu"``, ``"relu"`` and any other paddle supported activation functions
            are supported. Defaults to `"gelu"`.
        hidden_dropout_prob (float, optional):
            The dropout probability for all fully connected layers in the embeddings and encoder.
            Defaults to `0.1`.
        attention_probs_dropout_prob (float, optional):
            The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target.
            Defaults to `0.1`.
        max_position_embeddings (int, optional):
            The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input
            sequence. Defaults to `512`.
        type_vocab_size (int, optional):
            The vocabulary size of `token_type_ids`.
            Defaults to `16`.
        initializer_range (float, optional):
            The standard deviation of the normal initializer.
            Defaults to 0.02.

            .. note::
                A normal_initializer initializes weight matrices as normal distributions.
                See :meth:`BertPretrainedModel.init_weights()` for how weights are initialized in `BertModel`.

        pad_token_id (int, optional):
            The index of padding token in the token vocabulary.
            Defaults to `0`.
        pool_act (str, optional):
            The non-linear activation function in the pooling layer. Only
            ``"tanh"`` applies an activation; any other value leaves the
            pooled projection un-activated. Defaults to `"tanh"`.
        num_ipus (int, optional):
            Total number of IPUs; forwarded to :class:`BertPooler`, which
            uses it to place the pooling ops. Defaults to `4`.
        layer_per_ipu (int, optional):
            Number of consecutive encoder layers placed on each IPU.
            Defaults to `4`.
        encoder_start_ipu (int, optional):
            IPU index (and pipeline stage) of the first encoder layer.
            Defaults to `1`.

            NOTE(review): nothing here checks that
            ``encoder_start_ipu + (num_hidden_layers - 1) // layer_per_ipu``
            is below ``num_ipus``; confirm the configuration fits.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=16,
                 initializer_range=0.02,
                 pad_token_id=0,
                 pool_act="tanh",
                 num_ipus=4,
                 layer_per_ipu=4,
                 encoder_start_ipu=1):
        super(BertModel, self).__init__()
        self.pad_token_id = pad_token_id
        self.num_ipus = num_ipus
        self.layer_per_ipu = layer_per_ipu
        self.encoder_start_ipu = encoder_start_ipu
        self.num_hidden_layers = num_hidden_layers
        self.initializer_range = initializer_range
        self.embeddings = BertEmbeddings(
            vocab_size, hidden_size, hidden_dropout_prob,
            max_position_embeddings, type_vocab_size)
        encoder_layer = nn.TransformerEncoderLayer(
            hidden_size,
            num_attention_heads,
            intermediate_size,
            dropout=hidden_dropout_prob,
            activation=hidden_act,
            approximate=True,
            attn_dropout=attention_probs_dropout_prob,
            act_dropout=0)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
        self.pooler = BertPooler(hidden_size, pool_act, num_ipus)
        self.apply(self.init_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None,
                output_hidden_states=False,
                need_pooler=True):
        r'''
        The BertModel forward method, overrides the `__call__()` special method.

        Args:
            input_ids (Tensor):
                Indices of input sequence tokens in the vocabulary. They are
                numerical representations of tokens that build the input sequence.
                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
            token_type_ids (Tensor, optional):
                Segment token indices to indicate different portions of the inputs.
                Selected in the range ``[0, type_vocab_size - 1]``.
                If `type_vocab_size` is 2, which means the inputs have two portions.
                Indices can either be 0 or 1:

                - 0 corresponds to a *sentence A* token,
                - 1 corresponds to a *sentence B* token.

                Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
                Defaults to `None`, in which case :class:`BertEmbeddings` uses all-zero segment ids.
            position_ids(Tensor, optional):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
                max_position_embeddings - 1]``.
                Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`.
            attention_mask (Tensor, optional):
                Mask used in multi-head attention to avoid performing attention on to some unwanted positions,
                usually the paddings or the subsequent positions.
                Its data type can be int, float and bool.
                When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
                When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
                When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
                It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
                Defaults to `None`, in which case an additive float mask is built
                with `-1e4` at positions where `input_ids == pad_token_id` and `0` elsewhere.
            output_hidden_states (bool, optional):
                Whether to return the output of each hidden layers.
                In this mode the encoder layers are run without explicit IPU
                sharding and the pooler is always applied.
                Defaults to `False`.
            need_pooler (bool, optional):
                Only consulted when `output_hidden_states` is False. If False,
                the pooler is skipped and `sequence_output` alone (a single
                Tensor, not a tuple) is returned. Defaults to `True`.

        Returns:
            tuple or Tensor: Returns tuple (`sequence_output`, `pooled_output`) or
            (`encoder_outputs`, `pooled_output`); returns only `sequence_output`
            when `need_pooler` is False and `output_hidden_states` is False.

            With the fields:

            - `sequence_output` (Tensor):
                Sequence of hidden-states at the last layer of the model.
                It's data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].

            - `pooled_output` (Tensor):
                The output of first token (`[CLS]`) in sequence.
                We "pool" the model by simply taking the hidden state corresponding to the first token.
                Its data type should be float32 and its shape is [batch_size, hidden_size].

            - `encoder_outputs` (List(Tensor)):
                A list of Tensor containing hidden-states of the model at each hidden layer in the Transformer encoder.
                The length of the list is `num_hidden_layers`.
                Each Tensor has a data type of float32 and its shape is [batch_size, sequence_length, hidden_size].

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers import BertModel, BertTokenizer

                tokenizer = BertTokenizer.from_pretrained('bert-wwm-chinese')
                model = BertModel.from_pretrained('bert-wwm-chinese')

                inputs = tokenizer("欢迎使用百度飞桨!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                output = model(**inputs)
        '''
        logging.info("Emb Layer - ipu_index:%d, ipu_stage:%d", 0, 0)
        with paddle.fluid.ipu_shard(ipu_index=0, ipu_stage=0):
            with paddle.static.name_scope("Embedding"):
                if attention_mask is None:
                    # Additive mask: -1e4 on padding tokens, 0 elsewhere.
                    # Unsqueezing axes 1 and 2 lets it broadcast against the
                    # attention scores.
                    attention_mask = paddle.unsqueeze(
                        (input_ids == self.pad_token_id
                         ).astype(self.pooler.dense.weight.dtype) * -1e4,
                        axis=[1, 2])
                embedding_output = self.embeddings(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    token_type_ids=token_type_ids)

        if output_hidden_states:
            output = embedding_output
            encoder_outputs = []
            for mod in self.encoder.layers:
                output = mod(output, src_mask=attention_mask)
                encoder_outputs.append(output)
            if self.encoder.norm is not None:
                encoder_outputs[-1] = self.encoder.norm(encoder_outputs[-1])
            pooled_output = self.pooler(encoder_outputs[-1])
        else:
            sequence_output = embedding_output
            for i, encoder_layer in enumerate(self.encoder.layers):
                # Consecutive groups of `layer_per_ipu` layers share an IPU;
                # the pipeline stage equals the IPU index.
                ipu_index = self.encoder_start_ipu + i // self.layer_per_ipu
                logging.info("Enc-Layer - ipu_index:%d, ipu_stage:%d",
                             ipu_index, ipu_index)
                with paddle.fluid.ipu_shard(
                        ipu_index=ipu_index, ipu_stage=ipu_index):
                    with paddle.static.name_scope("Encoder_" + str(i)):
                        sequence_output = encoder_layer(sequence_output,
                                                        attention_mask)
            if need_pooler:
                logging.info("Poo Layer - ipu_index:%d, ipu_stage:%d",
                             0, self.num_ipus)
                # No ipu_shard here: BertPooler assigns its own placement.
                with paddle.static.name_scope("Pooler"):
                    pooled_output = self.pooler(sequence_output)
            else:
                return sequence_output

        if output_hidden_states:
            return encoder_outputs, pooled_output
        else:
            return sequence_output, pooled_output
class BertForQuestionAnswering(BertPretrainedModel):
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and
    `span end logits`).

    Args:
        bert (:class:`BertModel`):
            An instance of BertModel.
        num_ipus (int):
            Number of IPUs. The span-classification head is placed on IPU
            ``num_ipus - 1`` at pipeline stage ``num_ipus - 1``.
        dropout (float, optional):
            Probability of the `self.dropout` layer. If None, use the same
            value as `hidden_dropout_prob` of `BertModel` instance `bert`.
            Defaults to `None`. Note that :meth:`forward` does not apply
            this dropout to the BERT output.
    """

    def __init__(self, bert, num_ipus, dropout=None):
        super(BertForQuestionAnswering, self).__init__()
        self.bert = bert  # allow bert to be config
        self.dropout = nn.Dropout(dropout if dropout is not None else
                                  self.bert.config["hidden_dropout_prob"])
        self.classifier = nn.Linear(self.bert.config["hidden_size"], 2)
        self.apply(self.init_weights)
        self.num_ipus = num_ipus

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                input_mask=None):
        r"""
        The BertForQuestionAnswering forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`BertModel`.
            token_type_ids (Tensor, optional):
                See :class:`BertModel`.
            position_ids (Tensor, optional):
                See :class:`BertModel`.
            input_mask (Tensor, optional):
                Passed to :class:`BertModel` as `attention_mask`; see there.

        Returns:
            tuple: Returns tuple (`start_logits`, `end_logits`).

            With the fields:

            - `start_logits` (Tensor):
                A tensor of the input token classification logits, indicates the start position of the labelled span.
                Its data type should be float32 and its shape is [batch_size, sequence_length].

            - `end_logits` (Tensor):
                A tensor of the input token classification logits, indicates the end position of the labelled span.
                Its data type should be float32 and its shape is [batch_size, sequence_length].

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers.bert.modeling import BertForQuestionAnswering
                from paddlenlp.transformers.bert.tokenizer import BertTokenizer

                tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
                model = BertForQuestionAnswering.from_pretrained('bert-base-cased')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                outputs = model(**inputs)

                start_logits = outputs[0]
                end_logits =outputs[1]
        """
        # need_pooler=False makes BertModel return only the sequence output.
        sequence_output = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=input_mask,
            need_pooler=False)

        ipu_index = self.num_ipus - 1
        ipu_stage = self.num_ipus - 1
        logging.info("Head Layer - ipu_index:%d, ipu_stage:%d",
                     ipu_index, ipu_stage)
        with paddle.fluid.ipu_shard(ipu_index=ipu_index, ipu_stage=ipu_stage):
            with paddle.static.name_scope("SQURD"):
                logits = self.classifier(sequence_output)
                # paddle.unstack is not supported here, so the two logit
                # channels on the last axis are sliced out individually.
                start_logits = paddle.slice(
                    input=logits, axes=[2], starts=[0], ends=[1])
                end_logits = paddle.slice(
                    input=logits, axes=[2], starts=[1], ends=[2])
                start_logits = paddle.squeeze(start_logits, axis=-1)
                end_logits = paddle.squeeze(end_logits, axis=-1)
        return start_logits, end_logits
class BertForSequenceClassification(BertPretrainedModel):
    """
    Bert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g.
    for GLUE tasks.

    Args:
        bert (:class:`BertModel`):
            An instance of BertModel.
        num_classes (int, optional):
            Number of output classes of the linear head. Defaults to `2`.
        dropout (float, optional):
            Dropout probability applied to the pooled BERT output. When
            None, `hidden_dropout_prob` from `bert.config` is used.
            Defaults to None.
    """

    def __init__(self, bert, num_classes=2, dropout=None):
        super(BertForSequenceClassification, self).__init__()
        self.num_classes = num_classes
        self.bert = bert  # allow bert to be config
        drop_prob = dropout
        if drop_prob is None:
            drop_prob = self.bert.config["hidden_dropout_prob"]
        self.dropout = nn.Dropout(drop_prob)
        width = self.bert.config["hidden_size"]
        self.classifier = nn.Linear(width, num_classes)
        self.apply(self.init_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        r"""
        The BertForSequenceClassification forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`BertModel`.
            token_type_ids (Tensor, optional):
                See :class:`BertModel`.
            position_ids(Tensor, optional):
                See :class:`BertModel`.
            attention_mask (Tensor, optional):
                See :class:`BertModel`.

        Returns:
            Tensor: classification logits computed from the pooled `[CLS]`
            representation after dropout.
            Shape as `[batch_size, num_classes]` and dtype as float32.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers.bert.modeling import BertForSequenceClassification
                from paddlenlp.transformers.bert.tokenizer import BertTokenizer

                tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
                model = BertForSequenceClassification.from_pretrained('bert-base-cased')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                outputs = model(**inputs)

                logits = outputs[0]
        """
        _sequence, pooled = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask)
        return self.classifier(self.dropout(pooled))
class BertForTokenClassification(BertPretrainedModel):
    """
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.

    Args:
        bert (:class:`BertModel`):
            An instance of BertModel.
        num_classes (int, optional):
            The number of classes. Defaults to `2`.
        dropout (float, optional):
            The dropout probability for output of BERT.
            If None, use the same value as `hidden_dropout_prob` of `BertModel`
            instance `bert`. Defaults to None.
    """

    def __init__(self, bert, num_classes=2, dropout=None):
        super(BertForTokenClassification, self).__init__()
        self.num_classes = num_classes
        self.bert = bert  # allow bert to be config
        self.dropout = nn.Dropout(dropout if dropout is not None else
                                  self.bert.config["hidden_dropout_prob"])
        self.classifier = nn.Linear(self.bert.config["hidden_size"],
                                    num_classes)
        self.apply(self.init_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        r"""
        The BertForTokenClassification forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`BertModel`.
            token_type_ids (Tensor, optional):
                See :class:`BertModel`.
            position_ids(Tensor, optional):
                See :class:`BertModel`.
            attention_mask (Tensor, optional):
                See :class:`BertModel`.

        Returns:
            Tensor: Returns tensor `logits`, a tensor of the input token classification logits,
            computed from the dropped-out sequence output of BERT.
            Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`.

        Example:
            .. code-block::

                import paddle
                from paddlenlp.transformers.bert.modeling import BertForTokenClassification
                from paddlenlp.transformers.bert.tokenizer import BertTokenizer

                tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
                model = BertForTokenClassification.from_pretrained('bert-base-cased')

                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
                outputs = model(**inputs)

                logits = outputs[0]
        """
        sequence_output, _ = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask)

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits
class BertLMPredictionHead(Layer):
    """
    Language-modelling prediction head.

    Applies a dense transform, an activation, layer normalization, and then a
    projection onto the vocabulary using `decoder_weight` (transposed).

    Args:
        hidden_size (int): Width of the incoming hidden states.
        vocab_size (int): Number of vocabulary rows of `decoder_weight`.
        activation (str): Name of a function in ``paddle.nn.functional``.
        approximate (bool): Passed as the second argument when the
            activation is ``"gelu"``.
        embedding_weights (Tensor, optional): Reused as `decoder_weight`
            when given; otherwise a new [vocab_size, hidden_size] parameter
            is created.

    Note:
        `decoder_bias` is created but is not added to the output in
        :meth:`forward`.
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 activation,
                 approximate,
                 embedding_weights=None):
        super(BertLMPredictionHead, self).__init__()
        self.transform = nn.Linear(hidden_size, hidden_size)
        self.act_name = activation
        self.approximate = approximate
        self.activation = getattr(nn.functional, activation)
        self.layer_norm = nn.LayerNorm(hidden_size)
        if embedding_weights is None:
            self.decoder_weight = self.create_parameter(
                shape=[vocab_size, hidden_size],
                dtype=self.transform.weight.dtype,
                is_bias=False)
        else:
            self.decoder_weight = embedding_weights
        self.decoder_bias = self.create_parameter(
            shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True)

    def forward(self, hidden_states, masked_positions=None):
        if masked_positions is not None:
            # Flatten all leading axes, then keep only the selected rows.
            flat = paddle.reshape(hidden_states,
                                  [-1, hidden_states.shape[-1]])
            hidden_states = paddle.tensor.gather(flat, masked_positions)
        hidden_states = self.transform(hidden_states)
        if self.act_name == "gelu":
            act_args = (hidden_states, self.approximate)
        else:
            act_args = (hidden_states, )
        hidden_states = self.layer_norm(self.activation(*act_args))
        # Bias deliberately omitted (it was commented out in the original).
        return paddle.tensor.matmul(
            hidden_states, self.decoder_weight, transpose_y=True)
class BertPretrainingHeads(Layer):
    """
    Perform language modeling task and next sentence classification task.

    Args:
        hidden_size (int):
            See :class:`BertModel`.
        vocab_size (int):
            See :class:`BertModel`.
        activation (str):
            Name of the ``paddle.nn.functional`` activation used in the language
            modeling head.
        embedding_weights (Tensor, optional):
            Decoding weights used to map hidden_states to logits of the masked token prediction.
            Its data type should be float32 and its shape is [vocab_size, hidden_size].
            Defaults to `None`, in which case the prediction head creates a new,
            untied decoder weight. Pass the word-embedding weight explicitly to tie them.
        approximate (bool, optional):
            Forwarded to the prediction head. It is used as the ``approximate``
            argument of the activation only when ``activation`` is ``"gelu"``.
            Defaults to `True`, which matches the previously hard-coded behaviour.
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 activation,
                 embedding_weights=None,
                 approximate=True):
        super(BertPretrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(
            hidden_size, vocab_size, activation, approximate, embedding_weights)
        # Binary next-sentence-prediction classifier on the pooled output.
        self.seq_relationship = nn.Linear(hidden_size, 2)

    def forward(self, sequence_output, pooled_output, masked_positions=None):
        """
        Args:
            sequence_output(Tensor):
                Sequence of hidden-states at the last layer of the model.
                It's data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].
            pooled_output(Tensor):
                The output of first token (`[CLS]`) in sequence.
                We "pool" the model by simply taking the hidden state corresponding to the first token.
                Its data type should be float32 and its shape is [batch_size, hidden_size].
            masked_positions(Tensor, optional):
                Indices of the tokens to score in masked token prediction. The
                prediction head flattens `sequence_output` to
                [batch_size * sequence_length, hidden_size] and gathers rows at
                these indices. They are therefore presumably flattened positions
                (batch_index * sequence_length + token_index); TODO confirm
                against the data pipeline.
                Defaults to `None`, which means the hidden-states of all tokens are scored.

        Returns:
            tuple: Returns tuple (``prediction_scores``, ``seq_relationship_score``).

            With the fields:

            - `prediction_scores` (Tensor):
                The scores of masked token prediction. Its data type should be float32.
                If `masked_positions` is None, its shape is [batch_size, sequence_length, vocab_size].
                Otherwise it is 2-D: one row per gathered index, i.e.
                [num_gathered_positions, vocab_size] (assuming `masked_positions` is 1-D).

            - `seq_relationship_score` (Tensor):
                The scores of next sentence prediction.
                Its data type should be float32 and its shape is [batch_size, 2].
        """
        prediction_scores = self.predictions(sequence_output, masked_positions)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
class BertForPretraining(BertPretrainedModel):
    """
    Bert Model with pretraining tasks on top.

    The masked-token decoder weight of the pretraining heads is tied to
    ``bert.embeddings.word_embeddings.weight``.

    Args:
        bert (:class:`BertModel`):
            An instance of :class:`BertModel`.
    """

    def __init__(self, bert):
        super(BertForPretraining, self).__init__()
        self.bert = bert
        self.cls = BertPretrainingHeads(
            self.bert.config["hidden_size"],
            self.bert.config["vocab_size"],
            self.bert.config["hidden_act"],
            embedding_weights=self.bert.embeddings.word_embeddings.weight)
        self.apply(self.init_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None,
                masked_positions=None):
        r"""
        Args:
            input_ids (Tensor):
                See :class:`BertModel`.
            token_type_ids (Tensor, optional):
                See :class:`BertModel`.
            position_ids (Tensor, optional):
                See :class:`BertModel`.
            attention_mask (Tensor, optional):
                See :class:`BertModel`.
            masked_positions(Tensor, optional):
                See :class:`BertPretrainingHeads`.

        Returns:
            tuple: Returns tuple (``prediction_scores``, ``seq_relationship_score``).

            With the fields:

            - `prediction_scores` (Tensor):
                The scores of masked token prediction. Its data type should be float32.
                If `masked_positions` is None, its shape is [batch_size, sequence_length, vocab_size].
                Otherwise it is 2-D: one row per gathered index, i.e.
                [num_gathered_positions, vocab_size] (see :class:`BertPretrainingHeads`).

            - `seq_relationship_score` (Tensor):
                The scores of next sentence prediction.
                Its data type should be float32 and its shape is [batch_size, 2].
        """
        outputs = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask)
        sequence_output, pooled_output = outputs[:2]

        # Pass the arguments lazily so the message is only formatted when
        # INFO logging is enabled.
        logging.info("CLS Layer - ipu_index:%d, ipu_stage:%d", 0,
                     self.bert.num_ipus)
        # Place the pretraining heads on IPU 0 in pipeline stage
        # `self.bert.num_ipus` (presumably the stage after the encoder stages;
        # confirm against BertModel's sharding).
        with paddle.fluid.ipu_shard(ipu_index=0, ipu_stage=self.bert.num_ipus):
            with paddle.static.name_scope("CLS_MLM"):
                prediction_scores, seq_relationship_score = self.cls(
                    sequence_output, pooled_output, masked_positions)
        return prediction_scores, seq_relationship_score
class BertPretrainingAccuracy(paddle.nn.Layer):
def __init__(self,
mlm_label,