@@ -65,22 +65,6 @@ def _qkv_pre_load_convert(module: "GQA", state_dict, prefix: str, *args, **kwarg
        )


-def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict:  # pylint: disable=W0613
-    wq_name, wk_name, wv_name, fused_name = (
-        f"{prefix}wq.weight",
-        f"{prefix}wk.weight",
-        f"{prefix}wv.weight",
-        f"{prefix}wqkv.weight",
-    )
-
-    if module.enable_qkv_fusion:
-        state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight(
-            state_dict.pop(fused_name), *args, **kwargs
-        )
-
-    return state_dict
-
-
 class MHA(nn.Module):
     """
     Multi-head self-attention and cross-attention.
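
The hunk above drops the save-time hook that unpacked the fused `wqkv.weight` back into separate `wq`/`wk`/`wv` entries via `split_fused_wqkv_weight`. That helper is defined elsewhere in the repository; purely as an illustration, a splitter with the behaviour implied by the removed call could look like the sketch below, where the row-wise `[wq; wk; wv]` layout, the function name, and the `q_dim`/`kv_dim` keyword names are assumptions taken from the `partial(...)` bindings in the next hunk.

# Rough sketch (not the repository's helper) of a split along the lines of
# split_fused_wqkv_weight, assuming the fused weight is the row-wise
# concatenation [wq; wk; wv] of shape (q_dim + 2 * kv_dim, embed_dim).
# The q_dim/kv_dim keyword names mirror the partial(...) bindings in the
# diff; everything else here is hypothetical.
import torch


def split_fused_wqkv_weight_sketch(fused: torch.Tensor, *args, q_dim=None, kv_dim=None, **kwargs):
    # Slice the concatenated projection weight back into its three parts.
    wq = fused[:q_dim]
    wk = fused[q_dim : q_dim + kv_dim]
    wv = fused[q_dim + kv_dim : q_dim + 2 * kv_dim]
    return wq, wk, wv
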
@@ -462,15 +446,15 @@ def __init__(
         if enable_qkv_fusion:
             assert bias is False, "Fuesd wqkv only support bias is False."
             self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs)
-            self._register_load_state_dict_pre_hook(
-                partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
-            )
-            self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim))
         else:
             self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
             self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
             self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)

+        self._register_load_state_dict_pre_hook(
+            partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
+        )
+
         self.inner_attn = SelfAttention(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
         )
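
With this change the load-time hook is registered once, after the `if`/`else`, so it runs whether or not `enable_qkv_fusion` is set, and the save-time hook is gone; presumably the checkpoint now keeps whichever layout the module uses and any reconciliation happens only while loading. `_register_load_state_dict_pre_hook(..., with_module=True)` is the private `torch.nn.Module` API that prepends the module itself to the hook's usual `(state_dict, prefix, ...)` arguments. The sketch below shows a pre-load converter in that style; the branch conditions, key checks, and slicing are illustrative assumptions, not the repository's `_qkv_pre_load_convert`.

# Minimal sketch of a load-time converter in the style of _qkv_pre_load_convert.
# Everything below the signature (key checks, fuse/split branches, slicing) is an
# assumption made for illustration, not the repository's implementation.
import torch


def qkv_pre_load_convert_sketch(module, state_dict, prefix, *args, q_dim=None, kv_dim=None, **kwargs):
    wq_key, wk_key, wv_key, fused_key = (
        f"{prefix}wq.weight",
        f"{prefix}wk.weight",
        f"{prefix}wv.weight",
        f"{prefix}wqkv.weight",
    )
    if module.enable_qkv_fusion and wq_key in state_dict:
        # Checkpoint stores separate projections but the module expects a fused wqkv:
        # concatenate them row-wise into one (q_dim + 2 * kv_dim, embed_dim) weight.
        state_dict[fused_key] = torch.cat(
            [state_dict.pop(wq_key), state_dict.pop(wk_key), state_dict.pop(wv_key)], dim=0
        )
    elif not module.enable_qkv_fusion and fused_key in state_dict:
        # Checkpoint stores a fused weight but the module expects separate projections.
        fused = state_dict.pop(fused_key)
        state_dict[wq_key] = fused[:q_dim]
        state_dict[wk_key] = fused[q_dim : q_dim + kv_dim]
        state_dict[wv_key] = fused[q_dim + kv_dim :]

Registering such a converter with `self._register_load_state_dict_pre_hook(partial(qkv_pre_load_convert_sketch, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True)` would match the call pattern added in the hunk above.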