From 39cb5ac2ce7533d3daee89071c857e79d69114c0 Mon Sep 17 00:00:00 2001
From: Ray Smith
Date: Wed, 11 Dec 2024 03:33:30 -0800
Subject: [PATCH] Removed duplicated tensor sizes from weights.h by changing
 the constructor used for MatPtrT

PiperOrigin-RevId: 705045279
---
 BUILD.bazel                      |   1 -
 backprop/backward_scalar_test.cc |   8 +-
 backprop/test_util.h             |   2 +-
 compression/compress.h           |  30 +++---
 compression/compress_weights.cc  |   4 +-
 gemma/tensor_index_test.cc       |   4 +-
 gemma/weights.cc                 |   8 +-
 gemma/weights.h                  | 154 ++++++++++++++-----------------
 8 files changed, 101 insertions(+), 110 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 47d5852..d75b62b 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -269,7 +269,6 @@ cc_test(
         "@googletest//:gtest_main",
         "//compression:compress",
         "@highway//:hwy",
-        "@highway//:thread_pool",
     ],
 )
 
diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc
index b5e39db..d99a067 100644
--- a/backprop/backward_scalar_test.cc
+++ b/backprop/backward_scalar_test.cc
@@ -411,12 +411,14 @@ TEST(BackPropTest, LayerVJP) {
   using T = double;
   using TC = std::complex<T>;
   ModelConfig config = TestConfig();
+  TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
+                           /*reshape_att=*/false);
   const size_t kOutputSize = config.seq_len * config.model_dim;
-  LayerWeightsPtrs<T> weights(config.layer_configs[0]);
-  LayerWeightsPtrs<T> grad(config.layer_configs[0]);
+  LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
+  LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
   ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
   ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
-  LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
+  LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
   ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
   MatStorageT<T> y("y", kOutputSize, 1);
   MatStorageT<T> dy("dy", kOutputSize, 1);
diff --git a/backprop/test_util.h b/backprop/test_util.h
index 86f99b1..a83e3d5 100644
--- a/backprop/test_util.h
+++ b/backprop/test_util.h
@@ -93,7 +93,7 @@ template <typename T>
 class WeightsWrapper {
  public:
   explicit WeightsWrapper(const ModelConfig& config)
-      : pool_(0), weights_(config, pool_) {
+      : pool_(0), weights_(config) {
     weights_.Allocate(data_, pool_);
   }
 
diff --git a/compression/compress.h b/compression/compress.h
index b717ac8..8d4635b 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -219,21 +219,27 @@ class MatPtrT : public MatPtr {
       : MatPtrT(name, tensor_index.FindName(name)) {}
   MatPtrT(const std::string& name, const TensorInfo* tensor)
       : MatPtr(name, TypeEnum(), sizeof(MatT), 0, 0) {
-    HWY_ASSERT(tensor != nullptr);
-    cols_ = tensor->shape.back();
-    rows_ = 1;
-    if (tensor->cols_take_extra_dims) {
-      // The columns eat the extra dimensions.
-      rows_ = tensor->shape[0];
-      for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
-        cols_ *= tensor->shape[i];
-      }
+    if (tensor == nullptr) {
+      cols_ = 0;
+      rows_ = 0;
     } else {
-      // The rows eat the extra dimensions.
-      for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
-        rows_ *= tensor->shape[i];
+      cols_ = tensor->shape.back();
+      rows_ = 1;
+      if (tensor->cols_take_extra_dims) {
+        // The columns eat the extra dimensions.
+        rows_ = tensor->shape[0];
+        for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
+          cols_ *= tensor->shape[i];
+        }
+      } else {
+        // The rows eat the extra dimensions.
+        for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
+          rows_ *= tensor->shape[i];
+        }
       }
     }
+    stride_ = cols_;
+    num_elements_ = rows_ * cols_;
   }
 
   // Copying allowed as the metadata is small.
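The shape-folding rule above is the crux of the patch, so here is a minimal, self-contained sketch of the same logic outside the class (the Info struct and FoldShape helper are illustrative stand-ins, not part of the patch). A null TensorInfo now yields an empty 0x0 tensor instead of asserting; otherwise the trailing dimension becomes the columns and the leading dimensions fold into the rows, unless cols_take_extra_dims is set, in which case the leading dimension becomes the rows and the middle dimensions fold into the columns:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Illustrative stand-ins for the patch's TensorInfo and MatPtr fields.
    struct Info {
      std::vector<size_t> shape;
      bool cols_take_extra_dims = false;
    };

    struct Dims { size_t rows, cols; };

    // Mirrors the new MatPtrT(name, tensor) logic: nullptr -> empty matrix.
    Dims FoldShape(const Info* tensor) {
      if (tensor == nullptr) return {0, 0};
      Dims d{1, tensor->shape.back()};
      if (tensor->cols_take_extra_dims) {
        // The columns eat the extra dimensions.
        d.rows = tensor->shape[0];
        for (size_t i = 1; i + 1 < tensor->shape.size(); ++i) {
          d.cols *= tensor->shape[i];
        }
      } else {
        // The rows eat the extra dimensions.
        for (size_t i = 0; i + 1 < tensor->shape.size(); ++i) {
          d.rows *= tensor->shape[i];
        }
      }
      return d;
    }

    int main() {
      const Info att{{8, 64, 256}, false};  // e.g. {heads, qkv_dim, model_dim}
      const Info attc{{8, 64, 256}, true};  // same shape, cols take extras
      const Dims a = FoldShape(&att), b = FoldShape(&attc), c = FoldShape(nullptr);
      printf("%zux%zu %zux%zu %zux%zu\n", a.rows, a.cols, b.rows, b.cols,
             c.rows, c.cols);  // Prints: 512x256 8x16384 0x0
      return 0;
    }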
diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc
index 7d40333..ff607e7 100644
--- a/compression/compress_weights.cc
+++ b/compression/compress_weights.cc
@@ -165,9 +165,9 @@ void CompressWeights(const Path& weights_path,
           compressed_weights_path.path.c_str());
   ModelConfig config = ConfigFromModel(model_type);
   std::vector<MatStorage> model_storage;
-  ModelWeightsPtrs<Weight> c_weights(config, pool);
+  ModelWeightsPtrs<Weight> c_weights(config);
   c_weights.Allocate(model_storage, pool);
-  ModelWeightsPtrs<float> uc_weights(config, pool);
+  ModelWeightsPtrs<float> uc_weights(config);
   uc_weights.Allocate(model_storage, pool);
   // Get uncompressed weights, compress, and store.
   FILE* fptr = fopen(weights_path.path.c_str(), "rb");
diff --git a/gemma/tensor_index_test.cc b/gemma/tensor_index_test.cc
index 7fd1268..8928ad4 100644
--- a/gemma/tensor_index_test.cc
+++ b/gemma/tensor_index_test.cc
@@ -13,7 +13,6 @@
 #include "gemma/weights.h"
 #include "util/basics.h"
 #include "hwy/aligned_allocator.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 namespace {
@@ -22,7 +21,6 @@ namespace {
 // and that the TensorIndex returns the correct shape and name for the tensor,
 // for all models.
 TEST(TensorIndexTest, FindName) {
-  hwy::ThreadPool pool(4);
   for (Model model : kAllModels) {
     fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
     ModelConfig config = ConfigFromModel(model);
@@ -44,7 +42,7 @@ TEST(TensorIndexTest, FindName) {
           /*split_and_reshape=*/false);
     }
     // For each tensor in any model, exactly one TensorIndex should find it.
-    ModelWeightsPtrs<float> weights(config, pool);
+    ModelWeightsPtrs<float> weights(config);
     ModelWeightsPtrs<float>::ForEachTensor(
         {&weights}, ForEachType::kInitNoToc,
        [&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
diff --git a/gemma/weights.cc b/gemma/weights.cc
index e0f4d8c..6fd8480 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -186,18 +186,18 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
                                         hwy::ThreadPool& pool) {
   switch (weight_type) {
     case Type::kF32:
-      float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_, pool);
+      float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_);
       break;
     case Type::kBF16:
-      bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_, pool);
+      bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_);
       break;
     case Type::kSFP:
       sfp_weights_ =
-          std::make_unique<ModelWeightsPtrs<SfpStream>>(config_, pool);
+          std::make_unique<ModelWeightsPtrs<SfpStream>>(config_);
       break;
     case Type::kNUQ:
       nuq_weights_ =
-          std::make_unique<ModelWeightsPtrs<NuqStream>>(config_, pool);
+          std::make_unique<ModelWeightsPtrs<NuqStream>>(config_);
       break;
     default:
       HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
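Across compress_weights.cc, tensor_index_test.cc, and weights.cc the call-site change is uniform: the ModelWeightsPtrs constructor no longer takes a ThreadPool, because it now only records names and TensorIndex-derived shapes; the pool is needed only when Allocate() later creates the large backing buffers. A hedged caller-side sketch (gcpp/hwy types as used in this patch; the wrapper function is hypothetical and the MatStorage element type is a reconstruction of the stripped template argument above):

    #include <vector>

    #include "gemma/configs.h"  // ModelConfig
    #include "gemma/weights.h"  // ModelWeightsPtrs, MatStorage
    #include "hwy/contrib/thread_pool/thread_pool.h"

    // Before this patch: ModelWeightsPtrs<float> weights(config, pool);
    // After: shapes come from a TensorIndex built inside the constructor.
    void BuildWeights(const gcpp::ModelConfig& config, hwy::ThreadPool& pool) {
      std::vector<gcpp::MatStorage> model_storage;    // owns the large buffers
      gcpp::ModelWeightsPtrs<float> weights(config);  // metadata only, no pool
      weights.Allocate(model_storage, pool);          // large data, in parallel
    }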
diff --git a/gemma/weights.h b/gemma/weights.h
index ecd917b..2db9811 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -30,6 +30,7 @@
 #include "compression/shared.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
+#include "gemma/tensor_index.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -56,73 +57,48 @@ enum class ForEachType {
 template <class Weight>
 struct LayerWeightsPtrs {
   // Large data is constructed separately.
-  explicit LayerWeightsPtrs(const LayerConfig& config)
-      : attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
-                          config.qkv_dim),
-        qkv_einsum_w("qkv_ein",
-                     (config.heads + 2 * config.kv_heads) * config.qkv_dim,
-                     config.model_dim),
-        qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim,
-                      config.model_dim),
-        qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim,
-                      config.model_dim),
-        attention_output_biases(
-            "attn_ob", 1,
-            config.softmax_attn_output_biases ? config.model_dim : 0),
-        griffin(
-            {.linear_x_w = {"gr_lin_x_w", config.griffin_dim,
-                            config.griffin_dim},
-             .linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
-             .linear_y_w = {"gr_lin_y_w", config.griffin_dim,
-                            config.griffin_dim},
-             .linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
-             .linear_out_w = {"gr_lin_out_w", config.griffin_dim,
-                              config.griffin_dim},
-             .linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
-             .conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
-             .conv_biases = {"gr_conv_b", 1, config.griffin_dim},
-             .gate_w = {"gr_gate_w", 2 * config.griffin_dim,
-                        config.griffin_dim / config.heads},
-             .gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
-             .a = {"gr_a", 1, config.griffin_dim}}),
+  explicit LayerWeightsPtrs(const LayerConfig& config,
+                            const TensorIndex& tensor_index)
+      : attn_vec_einsum_w("att_ein", tensor_index),
+        qkv_einsum_w("qkv_ein", tensor_index),
+        qkv_einsum_w1("qkv1_w", tensor_index),
+        qkv_einsum_w2("qkv2_w", tensor_index),
+        attention_output_biases("attn_ob", tensor_index),
+        griffin({.linear_x_w = {"gr_lin_x_w", tensor_index},
+                 .linear_x_biases = {"gr_lin_x_b", tensor_index},
+                 .linear_y_w = {"gr_lin_y_w", tensor_index},
+                 .linear_y_biases = {"gr_lin_y_b", tensor_index},
+                 .linear_out_w = {"gr_lin_out_w", tensor_index},
+                 .linear_out_biases = {"gr_lin_out_b", tensor_index},
+                 .conv_w = {"gr_conv_w", tensor_index},
+                 .conv_biases = {"gr_conv_b", tensor_index},
+                 .gate_w = {"gr_gate_w", tensor_index},
+                 .gate_biases = {"gr_gate_b", tensor_index},
+                 .a = {"gr_a", tensor_index}}),
         // MultiHeadDotProductAttention.
-        vit({.attn_out_w = {"attn_out_w", config.model_dim,
-                            config.heads * config.qkv_dim},
-             .attn_out_b = {"attn_out_b", 1, config.model_dim},
-             .qkv_einsum_w = {"qkv_ein_w",
-                              (config.heads + 2 * config.kv_heads) *
-                                  config.qkv_dim,
-                              config.model_dim},
-             .qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
-                              config.qkv_dim},
-             .linear_0_w = {"linear_0_w", config.ff_hidden_dim,
-                            config.model_dim},
-             .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
-             .linear_1_w = {"linear_1_w", config.model_dim,
-                            config.ff_hidden_dim},
-             .linear_1_b = {"linear_1_b", 1, config.model_dim},
-             .layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
-             .layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
-             .layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
-             .layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
-        gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim,
-                        config.model_dim),
-        gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
-        gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
-        linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
-        pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
-        pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
-        post_attention_norm_scale(
-            "post_att_ns", 1,
-            config.post_norm == PostNormType::Scale ? config.model_dim : 0),
-        post_ffw_norm_scale(
-            "post_ff_ns", 1,
-            config.post_norm == PostNormType::Scale ? config.model_dim : 0),
-        ffw_gating_biases("ffw_gat_b", 1,
-                          config.ff_biases ? 2 * config.ff_hidden_dim : 0),
-        ffw_output_biases("ffw_out_b", 1,
-                          config.ff_biases ? config.model_dim : 0),
-        att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
+        vit({.attn_out_w = {"attn_out_w", tensor_index},
+             .attn_out_b = {"attn_out_b", tensor_index},
+             .qkv_einsum_w = {"qkv_ein_w", tensor_index},
+             .qkv_einsum_b = {"qkv_ein_b", tensor_index},
+             .linear_0_w = {"linear_0_w", tensor_index},
+             .linear_0_b = {"linear_0_b", tensor_index},
+             .linear_1_w = {"linear_1_w", tensor_index},
+             .linear_1_b = {"linear_1_b", tensor_index},
+             .layer_norm_0_bias = {"ln_0_bias", tensor_index},
+             .layer_norm_0_scale = {"ln_0_scale", tensor_index},
+             .layer_norm_1_bias = {"ln_1_bias", tensor_index},
+             .layer_norm_1_scale = {"ln_1_scale", tensor_index}}),
+        gating_einsum_w("gating_ein", tensor_index),
+        gating_einsum_w1("gating1_w", tensor_index),
+        gating_einsum_w2("gating2_w", tensor_index),
+        linear_w("linear_w", tensor_index),
+        pre_attention_norm_scale("pre_att_ns", tensor_index),
+        pre_ffw_norm_scale("pre_ff_ns", tensor_index),
+        post_attention_norm_scale("post_att_ns", tensor_index),
+        post_ffw_norm_scale("post_ff_ns", tensor_index),
+        ffw_gating_biases("ffw_gat_b", tensor_index),
+        ffw_output_biases("ffw_out_b", tensor_index),
+        att_weights("att_w", tensor_index),
         layer_config(config) {}
 
   ~LayerWeightsPtrs() = default;
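The per-tensor dimension expressions deleted above now live behind the name-keyed TensorIndex lookup, which is also where the nullptr branch added to MatPtrT earlier in this patch comes in: tensors that used to get an explicit zero dimension (post_att_ns when post_norm != PostNormType::Scale, or the ffw biases without ff_biases) can simply be absent from the index and come out 0x0, assuming FindName returns nullptr for absent names. A test-style sketch of the intended equivalence for one tensor (hypothetical test, not part of the patch; assumes GoogleTest, Rows()/Cols() accessors on MatPtr, and that the TensorIndex registers att_ein with the shape the removed initializer spelled out):

    #include "compression/compress.h"  // MatPtrT
    #include "gemma/common.h"          // ConfigFromModel
    #include "gemma/configs.h"         // ModelConfig, Model, LayerConfig
    #include "gemma/tensor_index.h"    // TensorIndex
    #include "gtest/gtest.h"

    TEST(WeightsShapes, AttEinMatchesRemovedInitializer) {
      // GEMMA_2B chosen arbitrarily; any model with layer 0 would do.
      const gcpp::ModelConfig config =
          gcpp::ConfigFromModel(gcpp::Model::GEMMA_2B);
      const gcpp::TensorIndex tensor_index(config, /*llm_layer_idx=*/0,
                                           /*img_layer_idx=*/-1,
                                           /*reshape_att=*/false);
      const gcpp::LayerConfig& layer = config.layer_configs[0];
      const gcpp::MatPtrT<float> att("att_ein", tensor_index);
      // The dims previously hard-coded in weights.h:
      // attn_vec_einsum_w("att_ein", heads * model_dim, qkv_dim).
      EXPECT_EQ(att.Rows(), layer.heads * layer.model_dim);
      EXPECT_EQ(att.Cols(), layer.qkv_dim);
    }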
@@ -342,28 +318,38 @@ template <class Weight>
 struct ModelWeightsPtrs {
-  ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
-      : embedder_input_embedding("c_embedding", config.vocab_size,
-                                 config.model_dim),
-        final_norm_scale("c_final_norm", 1, config.model_dim),
-        vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
-        vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
-        vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
-                                 config.patch_width * config.patch_width * 3),
-        vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
-                              config.vit_model_dim),
-        vit_img_head_bias("img_head_bias", 1, config.model_dim),
-        vit_img_head_kernel("img_head_kernel", config.model_dim,
-                            config.vit_model_dim),
+  explicit ModelWeightsPtrs(const ModelConfig& config)
+      : ModelWeightsPtrs(
+            config,
+            TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1,
+                        /*reshape_att=*/false)) {}
+  ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index)
+      : embedder_input_embedding("c_embedding", tensor_index),
+        final_norm_scale("c_final_norm", tensor_index),
+        vit_encoder_norm_bias("enc_norm_bias", tensor_index),
+        vit_encoder_norm_scale("enc_norm_scale", tensor_index),
+        vit_img_embedding_bias("img_emb_bias", tensor_index),
+        vit_img_embedding_kernel("img_emb_kernel", tensor_index),
+        vit_img_pos_embedding("img_pos_emb", tensor_index),
+        vit_img_head_bias("img_head_bias", tensor_index),
+        vit_img_head_kernel("img_head_kernel", tensor_index),
         scale_names(config.scale_names),
         weights_config(config) {
     c_layers.reserve(config.layer_configs.size());
-    for (const auto& layer_config : config.layer_configs) {
-      c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
+    for (int index = 0; index < static_cast<int>(config.layer_configs.size());
+         ++index) {
+      const auto& layer_config = config.layer_configs[index];
+      TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1,
+                               /*reshape_att=*/false);
+      c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
     }
-    for (const auto& layer_config : config.vit_layer_configs) {
-      vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
+    for (int index = 0;
+         index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
+      const auto& layer_config = config.vit_layer_configs[index];
+      TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
+                               /*reshape_att=*/false);
+      vit_layers.push_back(
+          LayerWeightsPtrs<Weight>(layer_config, tensor_index));
     }
   }
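The downstream effect: code that constructs LayerWeightsPtrs directly must now supply a per-layer TensorIndex, as the backward_scalar_test.cc change at the top of this patch shows; per-layer indices matter because tensor shapes can differ between layers, and between LLM and ViT layers. A minimal usage sketch under the same assumptions (repo headers as in this patch; the wrapper function is hypothetical and float is chosen arbitrarily for Weight):

    #include "gemma/configs.h"       // ModelConfig, LayerConfig
    #include "gemma/tensor_index.h"  // TensorIndex
    #include "gemma/weights.h"       // LayerWeightsPtrs

    // Build name/shape metadata for LLM layer 0 of an arbitrary config.
    void MakeLayer0Metadata(const gcpp::ModelConfig& config) {
      const gcpp::TensorIndex tensor_index(config, /*llm_layer_idx=*/0,
                                           /*img_layer_idx=*/-1,
                                           /*reshape_att=*/false);
      gcpp::LayerWeightsPtrs<float> layer(config.layer_configs[0],
                                          tensor_index);
      (void)layer;  // Names and shapes only; large data is allocated separately.
    }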