Commit 462d141

Add files via upload
1 parent 1566af4 commit 462d141

46 files changed (+85, -47 lines)

bert4keras3/Layers_add/Attentions.py (+5, -4)

@@ -173,7 +173,7 @@ def call(self, inputs, mask=None, **kwargs):
         o = self.o_dense(ops.reshape(o, [b,s,-1]))
         # return the result
 
-
+
         if use_cache:
             return o,cache
         if self.return_attention_scores:
@@ -282,17 +282,18 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
         a = a * ops.cast(1/np.sqrt(self.key_size), dtype=qw.dtype)
         if a_bias is not None and ops.ndim(a_bias) == 3:
             a_bias = align(a_bias, [0, -2, -1], ops.ndim(a))
-
-        A,mask = attention_normalize(a, v_mask, -1, self.normalization, a_bias)
+        A = attention_normalize(a, v_mask, -1, self.normalization, a_bias)
 
         if self.attention_dropout:
-            A,mask = self.dropout(A)
+            A = self.dropout(A)
+
         # finish the output
         if self.query_head!=self.heads:
             o = ops.einsum("bkgts,bskh->btkgh", A, vw)
             o = ops.reshape(o, (b, s, self.query_head, -1))
         else:
             o = ops.einsum('bhjk,bkhd->bjhd', A, vw)
+
         if p_bias == 'typical_relative':
             o = o + ops.einsum('bhjk,jkd->bjhd', A, position_bias)
 

bert4keras3/Layers_add/FFN.py (+11, -9)

@@ -13,6 +13,7 @@ def __init__(
         activation='relu',
         use_bias=True,
         kernel_initializer='glorot_uniform',
+
         **kwargs
     ):
         super(FeedForward, self).__init__(**kwargs)
@@ -98,34 +99,35 @@ class LLamaFeedForward(FeedForward):
     def build(self, input_shape):
         super(FeedForward, self).build(input_shape)
         output_dim = input_shape[-1]
-        self._feedforward_intermediate_dense = keras.layers.Dense(
+        self._feedforward_gate_dense = keras.layers.Dense(
             self.units,
             kernel_initializer=self.kernel_initializer,
             use_bias=self.use_bias,
-            name="feedforward_intermediate_dense",
+            name="feedforward_gate_dense",
         )
-        self._feedforward_gate_dense = keras.layers.Dense(
+        self._feedforward_intermediate_dense = keras.layers.Dense(
             self.units,
             kernel_initializer=self.kernel_initializer,
             use_bias=self.use_bias,
-            name="feedforward_gate_dense",
+            name="feedforward_intermediate_dense",
        )
+
 
         self._feedforward_output_dense = keras.layers.Dense(
             output_dim,
             kernel_initializer=self.kernel_initializer,
-            use_bias=False,
-            dtype=self.use_bias,
+            use_bias=self.use_bias,
             name="feedforward_output_dense",
         )
     @recompute_grad
     def call(self, x):
+
         activation = activations.get(self.activation[0])
         gate_output = self._feedforward_gate_dense(x)
-        gate_output = ops.cast(gate_output, "float32")
+        #gate_output = ops.cast(gate_output, "float32")
         gate_output = activation(gate_output)
-        gate_output = ops.cast(gate_output, x.dtype)
+        #gate_output = ops.cast(gate_output, x.dtype)
         x = self._feedforward_intermediate_dense(x)
         x = self._feedforward_output_dense(ops.multiply(x, gate_output))
-        return x
+        return x#
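Note on the LLamaFeedForward change above: it swaps which Dense layer is created as the gate versus the up-projection so the weight names match their roles, replaces the erroneous dtype=self.use_bias on the output projection with use_bias=self.use_bias, and comments out the float32 casts around the gate activation. The computation itself is a LLaMA-style gated feed-forward. Below is a minimal standalone sketch of that computation, assuming Keras 3 and keras.ops; the function and layer names are illustrative, not the ones used by the library.

import keras
from keras import ops

def gated_ffn_sketch(x, units, activation="silu", use_bias=False):
    # LLaMA-style gated FFN: W_o(act(W_g(x)) * W_i(x))
    output_dim = x.shape[-1]
    gate = keras.layers.Dense(units, use_bias=use_bias, name="gate_sketch")(x)          # W_g
    up = keras.layers.Dense(units, use_bias=use_bias, name="intermediate_sketch")(x)    # W_i
    act = keras.activations.get(activation)
    return keras.layers.Dense(output_dim, use_bias=use_bias,
                              name="output_sketch")(ops.multiply(up, act(gate)))        # W_o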

bert4keras3/Layers_add/LayerNorms.py (+1, -0)

@@ -204,6 +204,7 @@ def call(self, x):
         x = ops.cast(x, "float32")
         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
         x = x * ops.rsqrt(var + self.epsilon)
+
         return ops.cast(x, self.compute_dtype) * self.scale
 
     def get_config(self):
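
For reference, the call shown above is RMS normalization: the input is cast to float32, scaled by the reciprocal root-mean-square over the last axis, then cast back and multiplied by the learned scale. A tiny numerical sketch with keras.ops, values purely illustrative:

from keras import ops

x = ops.convert_to_tensor([[1.0, 2.0, 2.0]])
var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)   # mean of squares = 3.0
y = x * ops.rsqrt(var + 1e-6)                             # approx [[0.577, 1.155, 1.155]]
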
12 binary files changed (contents not shown).

bert4keras3/Models/LLamas.py (+59, -27)

@@ -7,15 +7,23 @@ def __init__(self, with_lm=True,
                  use_EinsumDense = True,
                  flatten_o_dense=False,
                  use_bias = False,
+                 input_scale =True,
+                 share_emebding=True,
+                 rope_mode='keras',
                  **kwargs):
         super(Gemma, self).__init__(**kwargs)
         self.with_lm = with_lm
         self.max_wavelength = max_wavelength
         self.scaling_factor = scaling_factor
+        self.rope_mode = rope_mode
+        self.share_emebding = share_emebding
         self.use_dense_bias = use_dense_bias
+        self.input_scale = input_scale
         self.flatten_o_dense = flatten_o_dense
         self.use_EinsumDense = use_EinsumDense
         self.use_bias = use_bias
+        self.layer_norm_type = RMSNormalization
+        self.ffn_type = GemmaFeedForward
     def apply_embeddings(self, inputs):
         inputs = inputs[:]
 
@@ -55,12 +63,13 @@ def apply_embeddings(self, inputs):
 
         def mul(x):
             return x * ops.cast(ops.sqrt(self.hidden_size), x.dtype)
-        x = self.apply(
-            inputs=x,
-            layer=Lambda,
-            function=mul,
-            name='Multiply'
-        )
+        if self.input_scale:
+            x = self.apply(
+                inputs=x,
+                layer=Lambda,
+                function=mul,
+                name='Multiply'
+            )
 
         x = self.apply(
             inputs=x,
@@ -92,7 +101,7 @@ def apply_main_layers(self, inputs, index):
 
         x = self.apply(
             inputs=x,
-            layer=RMSNormalization,
+            layer=self.layer_norm_type,
             epsilon=1e-6,
             name='%s-Norm' % attention_name
         )
@@ -137,19 +146,11 @@ def apply_main_layers(self, inputs, index):
 
         x = self.apply(
             inputs=x,
-            layer=RMSNormalization,
+            layer=self.layer_norm_type,
             epsilon=1e-6,
             name='%s-Norm' % feed_forward_name
         )
-        x = self.apply(
-            inputs=x,
-            layer=GemmaFeedForward,
-            units=self.intermediate_size,
-            activation=self.hidden_act,
-            use_bias=self.use_dense_bias,
-            kernel_initializer=self.initializer,
-            name=feed_forward_name
-        )
+        x = self.apply_ffn(x,feed_forward_name)
         x = self.apply(
             inputs=x,
             layer=Dropout,
@@ -170,7 +171,7 @@ def apply_final_layers(self, inputs):
 
         x = self.apply(
             inputs=x,
-            layer=RMSNormalization,
+            layer=self.layer_norm_type,
             epsilon=1e-6,
             name='Output-Norm'
         )
@@ -183,19 +184,42 @@ def apply_final_layers(self, inputs):
 
         if self.with_lm:
             lm_activation = 'softmax' if self.with_lm is True else self.with_lm
-            x = self.apply(
+            if self.share_emebding:
+                x = self.apply(
+                    inputs=x,
+                    layer=Embedding,
+                    arguments={'mode': 'dense'},
+                    name='Embedding-Token'
+                )
+                x = self.apply(
+                    inputs=x,
+                    layer=Activation,
+                    activation=lm_activation,
+                    name='Output-LM-Activation'
+                )
+            else:
+                x = self.apply(
                 inputs=x,
-                layer=Embedding,
-                arguments={'mode': 'dense'},
-                name='Embedding-Token'
-            )
-            x = self.apply(
-                inputs=x,
-                layer=Activation,
+                    layer=Dense,
+                    units=self.vocab_size,
                 activation=lm_activation,
-                name='Output-LM-Activation'
+                    use_bias=False,
+                    kernel_initializer=self.initializer,
+                    name='Decoder-Output-LM'
             )
+
 
+        return x
+    def apply_ffn(self,x,feed_forward_name):
+        x = self.apply(
+            inputs=x,
+            layer=self.ffn_type,
+            units=self.intermediate_size,
+            activation=self.hidden_act,
+            use_bias=self.use_dense_bias,
+            kernel_initializer=self.initializer,
+            name=feed_forward_name
+        )
         return x
     def apply_main_cache_layers(self, inputs, index,self_cache_update_index,
                                 cross_cache_update_index=None,
@@ -243,3 +267,11 @@ def apply_main_cache_layers(self, inputs, index,self_cache_update_index,
 
         return [x,caches]
 
+class Llama(Gemma):
+    def __init__(self, input_scale =False,use_EinsumDense=False,
+                 share_emebding=False,**kwargs):
+        super(Llama, self).__init__(input_scale=input_scale,
+                                    use_EinsumDense=use_EinsumDense,
+                                    share_emebding=share_emebding,**kwargs)
+        self.layer_norm_type = LlamaLayerNorm
+        self.ffn_type = LLamaFeedForward
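
The net effect of the LLamas.py changes: Gemma now routes its normalization and feed-forward layers through the layer_norm_type, ffn_type and apply_ffn hooks, and gains input_scale / share_emebding / rope_mode switches, so a related decoder-only architecture only has to override those attributes, as the new Llama subclass does. A sketch of how a further variant could plug in, assuming the same imports as bert4keras3/Models/LLamas.py; the class name MyVariant is hypothetical and not part of this commit:

# Hypothetical variant reusing the Gemma skeleton; only the hooks change.
class MyVariant(Gemma):
    def __init__(self, **kwargs):
        super(MyVariant, self).__init__(input_scale=False,
                                        share_emebding=False, **kwargs)
        self.layer_norm_type = LlamaLayerNorm   # any layer from LayerNorms.py
        self.ffn_type = LLamaFeedForward        # any layer from FFN.py
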
13 binary files changed (contents not shown).

bert4keras3/__init__.py (+1, -1)

@@ -1,6 +1,6 @@
 #! -*- coding: utf-8 -*-
 
-__version__ = '1.1.2'
+__version__ = '1.3'
 
 from bert4keras3 import backend,layers,models,snippets,tokenizers
 from bert4keras3.backend import ops
14 binary files changed (contents not shown).

bert4keras3/backend.py (+5, -5)

@@ -365,23 +365,23 @@ def attention_normalize(a, mask=None, axis=-1, method='softmax', bias=None):
         att_mask = mask
         for i in range(ops.ndim(a)-ops.ndim(mask)):
             att_mask = ops.expand_dims(att_mask,0)
-        return ops.cast(keras.layers.Softmax(dtype="float32",axis=axis)(a,mask=att_mask),ori_dtype),mask
+        return ops.cast(keras.layers.Softmax(dtype="float32",axis=axis)(a,mask=att_mask),ori_dtype)
     a, mask = sequence_masking(a, mask, -np.inf, axis, bias, True)
 
     if method == 'softmax' :
-        return ops.softmax(a,axis=axis),mask
+        return ops.softmax(a,axis=axis)
     else:
         if mask is None:
             l = ops.cast(ops.shape(a)[-1], keras.mixed_precision.dtype_policy().name)
         else:
             mask = ops.cast(mask, keras.mixed_precision.dtype_policy().name)
             l = ops.sum(mask, axis=axis, keepdims=True)
         if method == 'squared_relu':
-            return ops.relu(a)**2 / l,mask
+            return ops.relu(a)**2 / l
         elif method == 'softmax_plus':
             l = ops.maximum(l, 16) # scaling actually hurts on very short sequences
-            return ops.softmax(a * ops.log(l) / np.log(512), axis=axis),mask
-        return a,mask
+            return ops.softmax(a * ops.log(l) / np.log(512), axis=axis)
+        return a
 
 
 def sinusoidal_embeddings(pos, dim, base=10000):
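
With this change attention_normalize returns only the normalized attention matrix instead of an (A, mask) tuple, which is what the updated call sites in Attentions.py above expect. A minimal sketch of the new calling convention, assuming the signature shown in the hunk header; the shapes are purely illustrative:

from keras import ops
from bert4keras3.backend import attention_normalize

a = ops.ones((2, 4, 8, 8))                               # raw attention logits: (batch, heads, q_len, k_len)
A = attention_normalize(a, None, -1, 'softmax', None)    # a single tensor now; no mask is returned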

bert4keras3/models.py (+3, -1)

@@ -93,7 +93,9 @@ def build_transformer_model(
         'mt5.1.1_encoder': T5_Encoder,
         'mt5.1.1_decoder': T5_Decoder,
         'gemma':Gemma,
-
+        'llama':Llama,
+        'qwen':Llama,
+        'yi':Llama,
         'misakat5':MisakaT5,
     }
 
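With 'llama', 'qwen' and 'yi' now mapped to the new Llama class, a LLaMA-family model can be built by name. A hedged sketch, assuming bert4keras3 keeps the bert4keras-style build_transformer_model(config_path=..., model=...) interface; the config path is a placeholder:

from bert4keras3.models import build_transformer_model

# 'qwen' and 'yi' resolve to the same Llama class as 'llama'.
model = build_transformer_model(
    config_path='llama/config.json',   # hypothetical path; keys must match what Gemma/Llama expect
    model='llama',
    with_lm=True,
)
model.summary()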