#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
+#include "gemma/tensor_index.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -56,73 +57,48 @@ enum class ForEachType {
template <class Weight>
struct LayerWeightsPtrs {
  // Large data is constructed separately.
-  explicit LayerWeightsPtrs(const LayerConfig& config)
-      : attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
-                          config.qkv_dim),
-        qkv_einsum_w("qkv_ein",
-                     (config.heads + 2 * config.kv_heads) * config.qkv_dim,
-                     config.model_dim),
-        qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim,
-                      config.model_dim),
-        qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim,
-                      config.model_dim),
-        attention_output_biases(
-            "attn_ob", 1,
-            config.softmax_attn_output_biases ? config.model_dim : 0),
-        griffin(
-            {.linear_x_w = {"gr_lin_x_w", config.griffin_dim,
-                            config.griffin_dim},
-             .linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
-             .linear_y_w = {"gr_lin_y_w", config.griffin_dim,
-                            config.griffin_dim},
-             .linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
-             .linear_out_w = {"gr_lin_out_w", config.griffin_dim,
-                              config.griffin_dim},
-             .linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
-             .conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
-             .conv_biases = {"gr_conv_b", 1, config.griffin_dim},
-             .gate_w = {"gr_gate_w", 2 * config.griffin_dim,
-                        config.griffin_dim / config.heads},
-             .gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
-             .a = {"gr_a", 1, config.griffin_dim}}),
+  explicit LayerWeightsPtrs(const LayerConfig& config,
+                            const TensorIndex& tensor_index)
+      : attn_vec_einsum_w("att_ein", tensor_index),
+        qkv_einsum_w("qkv_ein", tensor_index),
+        qkv_einsum_w1("qkv1_w", tensor_index),
+        qkv_einsum_w2("qkv2_w", tensor_index),
+        attention_output_biases("attn_ob", tensor_index),
+        griffin({.linear_x_w = {"gr_lin_x_w", tensor_index},
+                 .linear_x_biases = {"gr_lin_x_b", tensor_index},
+                 .linear_y_w = {"gr_lin_y_w", tensor_index},
+                 .linear_y_biases = {"gr_lin_y_b", tensor_index},
+                 .linear_out_w = {"gr_lin_out_w", tensor_index},
+                 .linear_out_biases = {"gr_lin_out_b", tensor_index},
+                 .conv_w = {"gr_conv_w", tensor_index},
+                 .conv_biases = {"gr_conv_b", tensor_index},
+                 .gate_w = {"gr_gate_w", tensor_index},
+                 .gate_biases = {"gr_gate_b", tensor_index},
+                 .a = {"gr_a", tensor_index}}),
        // MultiHeadDotProductAttention.
-        vit({.attn_out_w = {"attn_out_w", config.model_dim,
-                            config.heads * config.qkv_dim},
-             .attn_out_b = {"attn_out_b", 1, config.model_dim},
-             .qkv_einsum_w = {"qkv_ein_w",
-                              (config.heads + 2 * config.kv_heads) *
-                                  config.qkv_dim,
-                              config.model_dim},
-             .qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
-                              config.qkv_dim},
-             .linear_0_w = {"linear_0_w", config.ff_hidden_dim,
-                            config.model_dim},
-             .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
-             .linear_1_w = {"linear_1_w", config.model_dim,
-                            config.ff_hidden_dim},
-             .linear_1_b = {"linear_1_b", 1, config.model_dim},
-             .layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
-             .layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
-             .layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
-             .layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
-        gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim,
-                        config.model_dim),
-        gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
-        gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
-        linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
-        pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
-        pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
-        post_attention_norm_scale(
-            "post_att_ns", 1,
-            config.post_norm == PostNormType::Scale ? config.model_dim : 0),
-        post_ffw_norm_scale(
-            "post_ff_ns", 1,
-            config.post_norm == PostNormType::Scale ? config.model_dim : 0),
-        ffw_gating_biases("ffw_gat_b", 1,
-                          config.ff_biases ? 2 * config.ff_hidden_dim : 0),
-        ffw_output_biases("ffw_out_b", 1,
-                          config.ff_biases ? config.model_dim : 0),
-        att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
+        vit({.attn_out_w = {"attn_out_w", tensor_index},
+             .attn_out_b = {"attn_out_b", tensor_index},
+             .qkv_einsum_w = {"qkv_ein_w", tensor_index},
+             .qkv_einsum_b = {"qkv_ein_b", tensor_index},
+             .linear_0_w = {"linear_0_w", tensor_index},
+             .linear_0_b = {"linear_0_b", tensor_index},
+             .linear_1_w = {"linear_1_w", tensor_index},
+             .linear_1_b = {"linear_1_b", tensor_index},
+             .layer_norm_0_bias = {"ln_0_bias", tensor_index},
+             .layer_norm_0_scale = {"ln_0_scale", tensor_index},
+             .layer_norm_1_bias = {"ln_1_bias", tensor_index},
+             .layer_norm_1_scale = {"ln_1_scale", tensor_index}}),
+        gating_einsum_w("gating_ein", tensor_index),
+        gating_einsum_w1("gating1_w", tensor_index),
+        gating_einsum_w2("gating2_w", tensor_index),
+        linear_w("linear_w", tensor_index),
+        pre_attention_norm_scale("pre_att_ns", tensor_index),
+        pre_ffw_norm_scale("pre_ff_ns", tensor_index),
+        post_attention_norm_scale("post_att_ns", tensor_index),
+        post_ffw_norm_scale("post_ff_ns", tensor_index),
+        ffw_gating_biases("ffw_gat_b", tensor_index),
+        ffw_output_biases("ffw_out_b", tensor_index),
+        att_weights("att_w", tensor_index),
        layer_config(config) {}
  ~LayerWeightsPtrs() = default;
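Note on this hunk: the per-tensor shape arithmetic that used to sit in the initializer list now lives behind a name-based lookup, so each member is constructed from just its tensor name plus a TensorIndex. Below is a minimal sketch of that pattern only; ToyShape, ToyTensorIndex and ToyMatPtr are invented stand-ins, not the real gemma.cpp types, and the shapes are registered by hand instead of being derived from LayerConfig.

// Minimal sketch (assumed names, not the real gemma.cpp API): a tensor
// wrapper resolves its own shape by name from an index object.
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct ToyShape {
  std::size_t rows = 0;
  std::size_t cols = 0;
};

class ToyTensorIndex {
 public:
  // In the real code the shapes would come from the layer/model config;
  // here they are registered explicitly.
  void Add(const std::string& name, std::size_t rows, std::size_t cols) {
    shapes_[name] = ToyShape{rows, cols};
  }
  ToyShape Find(const std::string& name) const {
    auto it = shapes_.find(name);
    if (it == shapes_.end()) throw std::out_of_range("unknown tensor: " + name);
    return it->second;
  }

 private:
  std::unordered_map<std::string, ToyShape> shapes_;
};

// Stand-in for a weight-metadata holder: name in, shape out.
struct ToyMatPtr {
  ToyMatPtr(const std::string& tensor_name, const ToyTensorIndex& index)
      : name(tensor_name), shape(index.Find(tensor_name)) {}
  std::string name;
  ToyShape shape;
};

// Usage: ToyTensorIndex idx; idx.Add("att_ein", heads * model_dim, qkv_dim);
//        ToyMatPtr att("att_ein", idx);  // shape resolved from the index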
@@ -342,28 +318,38 @@ struct LayerWeightsPtrs {
template <class Weight>
struct ModelWeightsPtrs {
-  ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
-      : embedder_input_embedding("c_embedding", config.vocab_size,
-                                 config.model_dim),
-        final_norm_scale("c_final_norm", 1, config.model_dim),
-        vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
-        vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
-        vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
-                                 config.patch_width * config.patch_width * 3),
-        vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
-                              config.vit_model_dim),
-        vit_img_head_bias("img_head_bias", 1, config.model_dim),
-        vit_img_head_kernel("img_head_kernel", config.model_dim,
-                            config.vit_model_dim),
+  explicit ModelWeightsPtrs(const ModelConfig& config)
+      : ModelWeightsPtrs(
+            config,
+            TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1,
+                        /*reshape_att=*/false)) {}
+  ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index)
+      : embedder_input_embedding("c_embedding", tensor_index),
+        final_norm_scale("c_final_norm", tensor_index),
+        vit_encoder_norm_bias("enc_norm_bias", tensor_index),
+        vit_encoder_norm_scale("enc_norm_scale", tensor_index),
+        vit_img_embedding_bias("img_emb_bias", tensor_index),
+        vit_img_embedding_kernel("img_emb_kernel", tensor_index),
+        vit_img_pos_embedding("img_pos_emb", tensor_index),
+        vit_img_head_bias("img_head_bias", tensor_index),
+        vit_img_head_kernel("img_head_kernel", tensor_index),
        scale_names(config.scale_names),
        weights_config(config) {
    c_layers.reserve(config.layer_configs.size());
-    for (const auto& layer_config : config.layer_configs) {
-      c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
+    for (int index = 0; index < static_cast<int>(config.layer_configs.size());
+         ++index) {
+      const auto& layer_config = config.layer_configs[index];
+      TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1,
+                               /*reshape_att=*/false);
+      c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
    }
-    for (const auto& layer_config : config.vit_layer_configs) {
-      vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
+    for (int index = 0;
+         index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
+      const auto& layer_config = config.vit_layer_configs[index];
+      TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
+                               /*reshape_att=*/false);
+      vit_layers.push_back(
+          LayerWeightsPtrs<Weight>(layer_config, tensor_index));
    }
  }
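Note on this hunk: the public single-argument constructor now delegates to a two-argument overload that takes a model-wide TensorIndex, and each LLM/ViT layer builds its own index from its position in the config. A self-contained sketch of that delegation-plus-per-layer-index shape follows; the Mini* types are invented for illustration and are not the real ModelConfig/TensorIndex.

// Sketch only: assumed toy types showing constructor delegation and a
// per-layer index built from the layer position (-1 = "model-wide",
// mirroring the llm_layer_idx/vit_layer_idx convention in the diff).
#include <cstddef>
#include <vector>

struct MiniConfig {
  std::size_t model_dim = 16;
  std::vector<std::size_t> layer_ff_dims = {32, 32};  // one entry per layer
};

struct MiniIndex {
  MiniIndex(const MiniConfig& config, int layer_idx)
      : rows(layer_idx < 0 ? config.model_dim
                           : config.layer_ff_dims[layer_idx]),
        cols(config.model_dim) {}
  std::size_t rows, cols;
};

struct MiniLayer {
  MiniLayer(const MiniConfig&, const MiniIndex& index) : shape(index) {}
  MiniIndex shape;
};

struct MiniModel {
  // Single-arg constructor builds a model-wide index and delegates.
  explicit MiniModel(const MiniConfig& config)
      : MiniModel(config, MiniIndex(config, /*layer_idx=*/-1)) {}
  MiniModel(const MiniConfig& config, const MiniIndex& /*model_index*/) {
    layers.reserve(config.layer_ff_dims.size());
    for (int i = 0; i < static_cast<int>(config.layer_ff_dims.size()); ++i) {
      // Each layer gets an index keyed to its own position.
      layers.push_back(MiniLayer(config, MiniIndex(config, i)));
    }
  }
  std::vector<MiniLayer> layers;
};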