
Commit f966ef5

theraysmith authored and copybara-github committed
Removed duplicated tensor sizes from weights.h by changing the constructor used for MatPtrT
PiperOrigin-RevId: 705045279
1 parent aed1739 · commit f966ef5
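The effect at each call site: tensor extents that were previously spelled out as constructor arguments are now looked up by name via a TensorIndex built from the config. Both fragments below are member-initializer excerpts taken from the weights.h diff further down (shown for comparison only, not standalone code):

// Before: dimensions repeated here and in the tensor definition.
gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim, config.model_dim),

// After: dimensions resolved by name from a TensorIndex.
gating_einsum_w("gating_ein", tensor_index),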


7 files changed: +99 -108 lines changed


BUILD.bazel

Lines changed: 0 additions & 1 deletion
@@ -269,7 +269,6 @@ cc_test(
         "@googletest//:gtest_main",
         "//compression:compress",
         "@highway//:hwy",
-        "@highway//:thread_pool",
     ],
 )

backprop/backward_scalar_test.cc

Lines changed: 5 additions & 3 deletions
@@ -411,12 +411,14 @@ TEST(BackPropTest, LayerVJP) {
   using T = double;
   using TC = std::complex<T>;
   ModelConfig config = TestConfig();
+  TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
+                           /*reshape_att=*/false);
   const size_t kOutputSize = config.seq_len * config.model_dim;
-  LayerWeightsPtrs<T> weights(config.layer_configs[0]);
-  LayerWeightsPtrs<T> grad(config.layer_configs[0]);
+  LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
+  LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
   ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
   ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
-  LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
+  LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
   ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
   MatStorageT<T> y("y", kOutputSize, 1);
   MatStorageT<T> dy("dy", kOutputSize, 1);

backprop/test_util.h

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ template <typename T>
 class WeightsWrapper {
  public:
   explicit WeightsWrapper(const ModelConfig& config)
-      : pool_(0), weights_(config, pool_) {
+      : pool_(0), weights_(config) {
     weights_.Allocate(data_, pool_);
   }

compression/compress.h

Lines changed: 18 additions & 12 deletions
@@ -219,21 +219,27 @@ class MatPtrT : public MatPtr {
       : MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
   MatPtrT(const std::string& name, const TensorInfo* tensor)
       : MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
-    HWY_ASSERT(tensor != nullptr);
-    cols_ = tensor->shape.back();
-    rows_ = 1;
-    if (tensor->cols_take_extra_dims) {
-      // The columns eat the extra dimensions.
-      rows_ = tensor->shape[0];
-      for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
-        cols_ *= tensor->shape[i];
-      }
+    if (tensor == nullptr) {
+      cols_ = 0;
+      rows_ = 0;
     } else {
-      // The rows eat the extra dimensions.
-      for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
-        rows_ *= tensor->shape[i];
+      cols_ = tensor->shape.back();
+      rows_ = 1;
+      if (tensor->cols_take_extra_dims) {
+        // The columns eat the extra dimensions.
+        rows_ = tensor->shape[0];
+        for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
+          cols_ *= tensor->shape[i];
+        }
+      } else {
+        // The rows eat the extra dimensions.
+        for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
+          rows_ *= tensor->shape[i];
+        }
       }
     }
+    stride_ = cols_;
+    num_elements_ = rows_ * cols_;
   }

   // Copying allowed as the metadata is small.

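The new constructor above also tolerates a missing TensorInfo (producing a 0x0 matrix) and records stride_ and num_elements_. The collapse from an N-dimensional shape to (rows, cols) can be checked in isolation; the snippet below is a self-contained toy that mirrors only the arithmetic of the new branch (ToyTensorInfo and Collapse are illustrative stand-ins, not gemma.cpp types):

#include <cstddef>
#include <cstdio>
#include <vector>

// Toy stand-in for TensorInfo: a shape plus the cols_take_extra_dims flag.
struct ToyTensorInfo {
  std::vector<size_t> shape;
  bool cols_take_extra_dims = false;
};

// Mirrors the rows/cols collapse in the new MatPtrT constructor above.
void Collapse(const ToyTensorInfo* tensor, size_t& rows, size_t& cols) {
  if (tensor == nullptr) {
    rows = 0;
    cols = 0;
    return;
  }
  cols = tensor->shape.back();
  rows = 1;
  if (tensor->cols_take_extra_dims) {
    // The columns eat the extra dimensions.
    rows = tensor->shape[0];
    for (size_t i = 1; i < tensor->shape.size() - 1; ++i) cols *= tensor->shape[i];
  } else {
    // The rows eat the extra dimensions.
    for (size_t i = 0; i < tensor->shape.size() - 1; ++i) rows *= tensor->shape[i];
  }
}

int main() {
  // Hypothetical 3-D shape {8, 16, 256}.
  ToyTensorInfo t{{8, 16, 256}, /*cols_take_extra_dims=*/false};
  size_t rows, cols;
  Collapse(&t, rows, cols);
  std::printf("rows take extra dims: %zux%zu\n", rows, cols);  // 128x256
  t.cols_take_extra_dims = true;
  Collapse(&t, rows, cols);
  std::printf("cols take extra dims: %zux%zu\n", rows, cols);  // 8x4096
  return 0;
}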
gemma/tensor_index_test.cc

Lines changed: 1 addition & 3 deletions
@@ -13,7 +13,6 @@
 #include "gemma/weights.h"
 #include "util/basics.h"
 #include "hwy/aligned_allocator.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"

 namespace gcpp {
 namespace {
@@ -22,7 +21,6 @@ namespace {
 // and that the TensorIndex returns the correct shape and name for the tensor,
 // for all models.
 TEST(TensorIndexTest, FindName) {
-  hwy::ThreadPool pool(4);
   for (Model model : kAllModels) {
     fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
     ModelConfig config = ConfigFromModel(model);
@@ -44,7 +42,7 @@ TEST(TensorIndexTest, FindName) {
         /*split_and_reshape=*/false);
   }
   // For each tensor in any model, exactly one TensorIndex should find it.
-  ModelWeightsPtrs<SfpStream> weights(config, pool);
+  ModelWeightsPtrs<SfpStream> weights(config);
   ModelWeightsPtrs<SfpStream>::ForEachTensor(
       {&weights}, ForEachType::kInitNoToc,
       [&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {

gemma/weights.cc

Lines changed: 4 additions & 4 deletions
@@ -186,18 +186,18 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
                                         hwy::ThreadPool& pool) {
   switch (weight_type) {
     case Type::kF32:
-      float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_, pool);
+      float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_);
       break;
     case Type::kBF16:
-      bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_, pool);
+      bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_);
       break;
     case Type::kSFP:
       sfp_weights_ =
-          std::make_unique<ModelWeightsPtrs<SfpStream>>(config_, pool);
+          std::make_unique<ModelWeightsPtrs<SfpStream>>(config_);
       break;
     case Type::kNUQ:
       nuq_weights_ =
-          std::make_unique<ModelWeightsPtrs<NuqStream>>(config_, pool);
+          std::make_unique<ModelWeightsPtrs<NuqStream>>(config_);
       break;
     default:
      HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));

gemma/weights.h

Lines changed: 70 additions & 84 deletions
@@ -30,6 +30,7 @@
 #include "compression/shared.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
+#include "gemma/tensor_index.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -56,73 +57,48 @@ enum class ForEachType {
 template <class Weight>
 struct LayerWeightsPtrs {
   // Large data is constructed separately.
-  explicit LayerWeightsPtrs(const LayerConfig& config)
-      : attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
-                          config.qkv_dim),
-        qkv_einsum_w("qkv_ein",
-                     (config.heads + 2 * config.kv_heads) * config.qkv_dim,
-                     config.model_dim),
-        qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim,
-                      config.model_dim),
-        qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim,
-                      config.model_dim),
-        attention_output_biases(
-            "attn_ob", 1,
-            config.softmax_attn_output_biases ? config.model_dim : 0),
-        griffin(
-            {.linear_x_w = {"gr_lin_x_w", config.griffin_dim,
-                            config.griffin_dim},
-             .linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
-             .linear_y_w = {"gr_lin_y_w", config.griffin_dim,
-                            config.griffin_dim},
-             .linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
-             .linear_out_w = {"gr_lin_out_w", config.griffin_dim,
-                              config.griffin_dim},
-             .linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
-             .conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
-             .conv_biases = {"gr_conv_b", 1, config.griffin_dim},
-             .gate_w = {"gr_gate_w", 2 * config.griffin_dim,
-                        config.griffin_dim / config.heads},
-             .gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
-             .a = {"gr_a", 1, config.griffin_dim}}),
+  explicit LayerWeightsPtrs(const LayerConfig& config,
+                            const TensorIndex& tensor_index)
+      : attn_vec_einsum_w("att_ein", tensor_index),
+        qkv_einsum_w("qkv_ein", tensor_index),
+        qkv_einsum_w1("qkv1_w", tensor_index),
+        qkv_einsum_w2("qkv2_w", tensor_index),
+        attention_output_biases("attn_ob", tensor_index),
+        griffin({.linear_x_w = {"gr_lin_x_w", tensor_index},
+                 .linear_x_biases = {"gr_lin_x_b", tensor_index},
+                 .linear_y_w = {"gr_lin_y_w", tensor_index},
+                 .linear_y_biases = {"gr_lin_y_b", tensor_index},
+                 .linear_out_w = {"gr_lin_out_w", tensor_index},
+                 .linear_out_biases = {"gr_lin_out_b", tensor_index},
+                 .conv_w = {"gr_conv_w", tensor_index},
+                 .conv_biases = {"gr_conv_b", tensor_index},
+                 .gate_w = {"gr_gate_w", tensor_index},
+                 .gate_biases = {"gr_gate_b", tensor_index},
+                 .a = {"gr_a", tensor_index}}),
         // MultiHeadDotProductAttention.
-        vit({.attn_out_w = {"attn_out_w", config.model_dim,
-                            config.heads * config.qkv_dim},
-             .attn_out_b = {"attn_out_b", 1, config.model_dim},
-             .qkv_einsum_w = {"qkv_ein_w",
-                              (config.heads + 2 * config.kv_heads) *
-                                  config.qkv_dim,
-                              config.model_dim},
-             .qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
-                              config.qkv_dim},
-             .linear_0_w = {"linear_0_w", config.ff_hidden_dim,
-                            config.model_dim},
-             .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
-             .linear_1_w = {"linear_1_w", config.model_dim,
-                            config.ff_hidden_dim},
-             .linear_1_b = {"linear_1_b", 1, config.model_dim},
-             .layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
-             .layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
-             .layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
-             .layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
-        gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim,
-                        config.model_dim),
-        gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
-        gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
-        linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
-        pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
-        pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
-        post_attention_norm_scale(
-            "post_att_ns", 1,
-            config.post_norm == PostNormType::Scale ? config.model_dim : 0),
-        post_ffw_norm_scale(
-            "post_ff_ns", 1,
-            config.post_norm == PostNormType::Scale ? config.model_dim : 0),
-        ffw_gating_biases("ffw_gat_b", 1,
-                          config.ff_biases ? 2 * config.ff_hidden_dim : 0),
-        ffw_output_biases("ffw_out_b", 1,
-                          config.ff_biases ? config.model_dim : 0),
-        att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
+        vit({.attn_out_w = {"attn_out_w", tensor_index},
+             .attn_out_b = {"attn_out_b", tensor_index},
+             .qkv_einsum_w = {"qkv_ein_w", tensor_index},
+             .qkv_einsum_b = {"qkv_ein_b", tensor_index},
+             .linear_0_w = {"linear_0_w", tensor_index},
+             .linear_0_b = {"linear_0_b", tensor_index},
+             .linear_1_w = {"linear_1_w", tensor_index},
+             .linear_1_b = {"linear_1_b", tensor_index},
+             .layer_norm_0_bias = {"ln_0_bias", tensor_index},
+             .layer_norm_0_scale = {"ln_0_scale", tensor_index},
+             .layer_norm_1_bias = {"ln_1_bias", tensor_index},
+             .layer_norm_1_scale = {"ln_1_scale", tensor_index}}),
+        gating_einsum_w("gating_ein", tensor_index),
+        gating_einsum_w1("gating1_w", tensor_index),
+        gating_einsum_w2("gating2_w", tensor_index),
+        linear_w("linear_w", tensor_index),
+        pre_attention_norm_scale("pre_att_ns", tensor_index),
+        pre_ffw_norm_scale("pre_ff_ns", tensor_index),
+        post_attention_norm_scale("post_att_ns", tensor_index),
+        post_ffw_norm_scale("post_ff_ns", tensor_index),
+        ffw_gating_biases("ffw_gat_b", tensor_index),
+        ffw_output_biases("ffw_out_b", tensor_index),
+        att_weights("att_w", tensor_index),
         layer_config(config) {}
   ~LayerWeightsPtrs() = default;

@@ -342,28 +318,38 @@ struct LayerWeightsPtrs {

 template <class Weight>
 struct ModelWeightsPtrs {
-  ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
-      : embedder_input_embedding("c_embedding", config.vocab_size,
-                                 config.model_dim),
-        final_norm_scale("c_final_norm", 1, config.model_dim),
-        vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
-        vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
-        vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
-                                 config.patch_width * config.patch_width * 3),
-        vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
-                              config.vit_model_dim),
-        vit_img_head_bias("img_head_bias", 1, config.model_dim),
-        vit_img_head_kernel("img_head_kernel", config.model_dim,
-                            config.vit_model_dim),
+  explicit ModelWeightsPtrs(const ModelConfig& config)
+      : ModelWeightsPtrs(
+            config,
+            TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1,
+                        /*reshape_att=*/false)) {}
+  ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index)
+      : embedder_input_embedding("c_embedding", tensor_index),
+        final_norm_scale("c_final_norm", tensor_index),
+        vit_encoder_norm_bias("enc_norm_bias", tensor_index),
+        vit_encoder_norm_scale("enc_norm_scale", tensor_index),
+        vit_img_embedding_bias("img_emb_bias", tensor_index),
+        vit_img_embedding_kernel("img_emb_kernel", tensor_index),
+        vit_img_pos_embedding("img_pos_emb", tensor_index),
+        vit_img_head_bias("img_head_bias", tensor_index),
+        vit_img_head_kernel("img_head_kernel", tensor_index),
         scale_names(config.scale_names),
         weights_config(config) {
     c_layers.reserve(config.layer_configs.size());
-    for (const auto& layer_config : config.layer_configs) {
-      c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
+    for (int index = 0; index < static_cast<int>(config.layer_configs.size());
+         ++index) {
+      const auto& layer_config = config.layer_configs[index];
+      TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1,
+                               /*reshape_att=*/false);
+      c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
     }
-    for (const auto& layer_config : config.vit_layer_configs) {
-      vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
+    for (int index = 0;
+         index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
+      const auto& layer_config = config.vit_layer_configs[index];
+      TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
+                               /*reshape_att=*/false);
+      vit_layers.push_back(
+          LayerWeightsPtrs<Weight>(layer_config, tensor_index));
     }
   }

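Taken together, the constructors above let a ModelWeightsPtrs be built from a ModelConfig alone; as the test_util.h and weights.cc diffs suggest, the thread pool is only needed later, when the actual storage is allocated. A minimal call-site sketch, assuming the headers and names that appear in this diff (BuildWeightPtrs is just an illustrative wrapper):

#include "gemma/configs.h"
#include "gemma/weights.h"

// Builds the name/shape pointer table; shapes come from the model-level
// TensorIndex created by the delegating constructor, not from hard-coded dims.
void BuildWeightPtrs(const gcpp::ModelConfig& config) {
  gcpp::ModelWeightsPtrs<float> weights(config);
  // Per-layer tables get their own TensorIndex (one per LLM/ViT layer index),
  // as in the loops shown in the weights.h hunk above. Allocation of tensor
  // storage still happens elsewhere, with a thread pool.
}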