Removed duplicated tensor sizes from weights.h by changing the constructor used for MatPtrT

PiperOrigin-RevId: 705045279
theraysmith authored and copybara-github committed Dec 11, 2024
1 parent aed1739 commit f966ef5
Showing 7 changed files with 99 additions and 108 deletions.
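
In short, tensor shapes were previously written out twice: in the TensorIndex metadata and again as explicit row/column arguments in the weights.h constructors. After this commit, weights.h constructs each MatPtrT from a tensor name plus a TensorIndex, and the shape is looked up by name. A condensed before/after sketch of the pattern, using a hypothetical one-member struct (the real LayerWeightsPtrs below has many members, all changed the same way; the explicit-dimension constructor in the "before" variant is the pre-commit form):

#include "compression/compress.h"  // MatPtrT
#include "gemma/configs.h"         // LayerConfig
#include "gemma/tensor_index.h"    // TensorIndex

namespace gcpp {

// Pre-commit pattern: dimensions repeated at every call site.
struct LayerWeightsBefore {
  explicit LayerWeightsBefore(const LayerConfig& config)
      : attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
                          config.qkv_dim) {}
  MatPtrT<float> attn_vec_einsum_w;
};

// Post-commit pattern: MatPtrT finds "att_ein" in the TensorIndex and takes
// its rows/cols from the TensorInfo stored there.
struct LayerWeightsAfter {
  LayerWeightsAfter(const LayerConfig& /*config*/,
                    const TensorIndex& tensor_index)
      : attn_vec_einsum_w("att_ein", tensor_index) {}
  MatPtrT<float> attn_vec_einsum_w;
};

}  // namespace gcpp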
1 change: 0 additions & 1 deletion BUILD.bazel
@@ -269,7 +269,6 @@ cc_test(
"@googletest//:gtest_main",
"//compression:compress",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

8 changes: 5 additions & 3 deletions backprop/backward_scalar_test.cc
@@ -411,12 +411,14 @@ TEST(BackPropTest, LayerVJP) {
using T = double;
using TC = std::complex<T>;
ModelConfig config = TestConfig();
TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
/*reshape_att=*/false);
const size_t kOutputSize = config.seq_len * config.model_dim;
LayerWeightsPtrs<T> weights(config.layer_configs[0]);
LayerWeightsPtrs<T> grad(config.layer_configs[0]);
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
MatStorageT<T> y("y", kOutputSize, 1);
MatStorageT<T> dy("dy", kOutputSize, 1);
2 changes: 1 addition & 1 deletion backprop/test_util.h
@@ -93,7 +93,7 @@ template <typename T>
class WeightsWrapper {
public:
explicit WeightsWrapper(const ModelConfig& config)
: pool_(0), weights_(config, pool_) {
: pool_(0), weights_(config) {
weights_.Allocate(data_, pool_);
}

30 changes: 18 additions & 12 deletions compression/compress.h
@@ -219,21 +219,27 @@ class MatPtrT : public MatPtr {
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
MatPtrT(const std::string& name, const TensorInfo* tensor)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
HWY_ASSERT(tensor != nullptr);
cols_ = tensor->shape.back();
rows_ = 1;
if (tensor->cols_take_extra_dims) {
// The columns eat the extra dimensions.
rows_ = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols_ *= tensor->shape[i];
}
if (tensor == nullptr) {
cols_ = 0;
rows_ = 0;
} else {
// The rows eat the extra dimensions.
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows_ *= tensor->shape[i];
cols_ = tensor->shape.back();
rows_ = 1;
if (tensor->cols_take_extra_dims) {
// The columns eat the extra dimensions.
rows_ = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols_ *= tensor->shape[i];
}
} else {
// The rows eat the extra dimensions.
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows_ *= tensor->shape[i];
}
}
}
stride_ = cols_;
num_elements_ = rows_ * cols_;
}

// Copying allowed as the metadata is small.
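
For reference, the constructor above flattens a multi-dimensional TensorInfo shape into a 2-D (rows, cols) pair: cols is the trailing dimension, and the extra dimensions fold into rows by default, or into cols when cols_take_extra_dims is set. A self-contained sketch of that folding rule, using a hypothetical RowsCols helper over a plain std::vector (the example shape is illustrative, not taken from any model config):

#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Mirrors the constructor logic above: cols is the last dimension; the
// remaining dimensions multiply into rows, unless cols_take_extra_dims is
// set, in which case rows is shape[0] and the middle dimensions multiply
// into cols.
std::pair<std::size_t, std::size_t> RowsCols(
    const std::vector<std::size_t>& shape, bool cols_take_extra_dims) {
  if (shape.empty()) return {0, 0};
  std::size_t cols = shape.back();
  std::size_t rows = 1;
  if (cols_take_extra_dims) {
    rows = shape[0];
    for (std::size_t i = 1; i + 1 < shape.size(); ++i) cols *= shape[i];
  } else {
    for (std::size_t i = 0; i + 1 < shape.size(); ++i) rows *= shape[i];
  }
  return {rows, cols};
}

int main() {
  // Illustrative 3-D shape {8, 256, 2048}.
  auto a = RowsCols({8, 256, 2048}, /*cols_take_extra_dims=*/false);
  std::printf("%zu x %zu\n", a.first, a.second);  // 2048 x 2048
  auto b = RowsCols({8, 256, 2048}, /*cols_take_extra_dims=*/true);
  std::printf("%zu x %zu\n", b.first, b.second);  // 8 x 524288
  return 0;
}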
4 changes: 1 addition & 3 deletions gemma/tensor_index_test.cc
@@ -13,7 +13,6 @@
#include "gemma/weights.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {
namespace {
@@ -22,7 +21,6 @@ namespace {
// and that the TensorIndex returns the correct shape and name for the tensor,
// for all models.
TEST(TensorIndexTest, FindName) {
hwy::ThreadPool pool(4);
for (Model model : kAllModels) {
fprintf(stderr, "Testing model %d\n", static_cast<int>(model));
ModelConfig config = ConfigFromModel(model);
@@ -44,7 +42,7 @@ TEST(TensorIndexTest, FindName) {
/*split_and_reshape=*/false);
}
// For each tensor in any model, exactly one TensorIndex should find it.
ModelWeightsPtrs<SfpStream> weights(config, pool);
ModelWeightsPtrs<SfpStream> weights(config);
ModelWeightsPtrs<SfpStream>::ForEachTensor(
{&weights}, ForEachType::kInitNoToc,
[&tensor_indexes](const char* name, hwy::Span<MatPtr*> tensors) {
8 changes: 4 additions & 4 deletions gemma/weights.cc
@@ -186,18 +186,18 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
hwy::ThreadPool& pool) {
switch (weight_type) {
case Type::kF32:
float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_, pool);
float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_);
break;
case Type::kBF16:
bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_, pool);
bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_);
break;
case Type::kSFP:
sfp_weights_ =
std::make_unique<ModelWeightsPtrs<SfpStream>>(config_, pool);
std::make_unique<ModelWeightsPtrs<SfpStream>>(config_);
break;
case Type::kNUQ:
nuq_weights_ =
std::make_unique<ModelWeightsPtrs<NuqStream>>(config_, pool);
std::make_unique<ModelWeightsPtrs<NuqStream>>(config_);
break;
default:
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
154 changes: 70 additions & 84 deletions gemma/weights.h
@@ -30,6 +30,7 @@
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -56,73 +57,48 @@ enum class ForEachType {
template <class Weight>
struct LayerWeightsPtrs {
// Large data is constructed separately.
explicit LayerWeightsPtrs(const LayerConfig& config)
: attn_vec_einsum_w("att_ein", config.heads * config.model_dim,
config.qkv_dim),
qkv_einsum_w("qkv_ein",
(config.heads + 2 * config.kv_heads) * config.qkv_dim,
config.model_dim),
qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim,
config.model_dim),
qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim,
config.model_dim),
attention_output_biases(
"attn_ob", 1,
config.softmax_attn_output_biases ? config.model_dim : 0),
griffin(
{.linear_x_w = {"gr_lin_x_w", config.griffin_dim,
config.griffin_dim},
.linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
.linear_y_w = {"gr_lin_y_w", config.griffin_dim,
config.griffin_dim},
.linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
.linear_out_w = {"gr_lin_out_w", config.griffin_dim,
config.griffin_dim},
.linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
.conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
.conv_biases = {"gr_conv_b", 1, config.griffin_dim},
.gate_w = {"gr_gate_w", 2 * config.griffin_dim,
config.griffin_dim / config.heads},
.gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
.a = {"gr_a", 1, config.griffin_dim}}),
explicit LayerWeightsPtrs(const LayerConfig& config,
const TensorIndex& tensor_index)
: attn_vec_einsum_w("att_ein", tensor_index),
qkv_einsum_w("qkv_ein", tensor_index),
qkv_einsum_w1("qkv1_w", tensor_index),
qkv_einsum_w2("qkv2_w", tensor_index),
attention_output_biases("attn_ob", tensor_index),
griffin({.linear_x_w = {"gr_lin_x_w", tensor_index},
.linear_x_biases = {"gr_lin_x_b", tensor_index},
.linear_y_w = {"gr_lin_y_w", tensor_index},
.linear_y_biases = {"gr_lin_y_b", tensor_index},
.linear_out_w = {"gr_lin_out_w", tensor_index},
.linear_out_biases = {"gr_lin_out_b", tensor_index},
.conv_w = {"gr_conv_w", tensor_index},
.conv_biases = {"gr_conv_b", tensor_index},
.gate_w = {"gr_gate_w", tensor_index},
.gate_biases = {"gr_gate_b", tensor_index},
.a = {"gr_a", tensor_index}}),
// MultiHeadDotProductAttention.
vit({.attn_out_w = {"attn_out_w", config.model_dim,
config.heads * config.qkv_dim},
.attn_out_b = {"attn_out_b", 1, config.model_dim},
.qkv_einsum_w = {"qkv_ein_w",
(config.heads + 2 * config.kv_heads) *
config.qkv_dim,
config.model_dim},
.qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
config.qkv_dim},
.linear_0_w = {"linear_0_w", config.ff_hidden_dim,
config.model_dim},
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
.linear_1_w = {"linear_1_w", config.model_dim,
config.ff_hidden_dim},
.linear_1_b = {"linear_1_b", 1, config.model_dim},
.layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
.layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
.layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
.layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim,
config.model_dim),
gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
post_attention_norm_scale(
"post_att_ns", 1,
config.post_norm == PostNormType::Scale ? config.model_dim : 0),
post_ffw_norm_scale(
"post_ff_ns", 1,
config.post_norm == PostNormType::Scale ? config.model_dim : 0),
ffw_gating_biases("ffw_gat_b", 1,
config.ff_biases ? 2 * config.ff_hidden_dim : 0),
ffw_output_biases("ffw_out_b", 1,
config.ff_biases ? config.model_dim : 0),
att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
vit({.attn_out_w = {"attn_out_w", tensor_index},
.attn_out_b = {"attn_out_b", tensor_index},
.qkv_einsum_w = {"qkv_ein_w", tensor_index},
.qkv_einsum_b = {"qkv_ein_b", tensor_index},
.linear_0_w = {"linear_0_w", tensor_index},
.linear_0_b = {"linear_0_b", tensor_index},
.linear_1_w = {"linear_1_w", tensor_index},
.linear_1_b = {"linear_1_b", tensor_index},
.layer_norm_0_bias = {"ln_0_bias", tensor_index},
.layer_norm_0_scale = {"ln_0_scale", tensor_index},
.layer_norm_1_bias = {"ln_1_bias", tensor_index},
.layer_norm_1_scale = {"ln_1_scale", tensor_index}}),
gating_einsum_w("gating_ein", tensor_index),
gating_einsum_w1("gating1_w", tensor_index),
gating_einsum_w2("gating2_w", tensor_index),
linear_w("linear_w", tensor_index),
pre_attention_norm_scale("pre_att_ns", tensor_index),
pre_ffw_norm_scale("pre_ff_ns", tensor_index),
post_attention_norm_scale("post_att_ns", tensor_index),
post_ffw_norm_scale("post_ff_ns", tensor_index),
ffw_gating_biases("ffw_gat_b", tensor_index),
ffw_output_biases("ffw_out_b", tensor_index),
att_weights("att_w", tensor_index),
layer_config(config) {}
~LayerWeightsPtrs() = default;

@@ -342,28 +318,38 @@ struct LayerWeightsPtrs {

template <class Weight>
struct ModelWeightsPtrs {
ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
: embedder_input_embedding("c_embedding", config.vocab_size,
config.model_dim),
final_norm_scale("c_final_norm", 1, config.model_dim),
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim,
config.patch_width * config.patch_width * 3),
vit_img_pos_embedding("img_pos_emb", config.vit_seq_len,
config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.model_dim,
config.vit_model_dim),
explicit ModelWeightsPtrs(const ModelConfig& config)
: ModelWeightsPtrs(
config,
TensorIndex(config, /*llm_layer_idx=*/-1, /*vit_layer_idx=*/-1,
/*reshape_att=*/false)) {}
ModelWeightsPtrs(const ModelConfig& config, const TensorIndex& tensor_index)
: embedder_input_embedding("c_embedding", tensor_index),
final_norm_scale("c_final_norm", tensor_index),
vit_encoder_norm_bias("enc_norm_bias", tensor_index),
vit_encoder_norm_scale("enc_norm_scale", tensor_index),
vit_img_embedding_bias("img_emb_bias", tensor_index),
vit_img_embedding_kernel("img_emb_kernel", tensor_index),
vit_img_pos_embedding("img_pos_emb", tensor_index),
vit_img_head_bias("img_head_bias", tensor_index),
vit_img_head_kernel("img_head_kernel", tensor_index),
scale_names(config.scale_names),
weights_config(config) {
c_layers.reserve(config.layer_configs.size());
for (const auto& layer_config : config.layer_configs) {
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
for (int index = 0; index < static_cast<int>(config.layer_configs.size());
++index) {
const auto& layer_config = config.layer_configs[index];
TensorIndex tensor_index(config, index, /*vit_layer_idx=*/-1,
/*reshape_att=*/false);
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
}
for (const auto& layer_config : config.vit_layer_configs) {
vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
for (int index = 0;
index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
const auto& layer_config = config.vit_layer_configs[index];
TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
/*reshape_att=*/false);
vit_layers.push_back(
LayerWeightsPtrs<Weight>(layer_config, tensor_index));
}
}

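Taken together, callers now build weight pointers from a ModelConfig alone: ModelWeightsPtrs constructs the model-level TensorIndex itself and a per-layer TensorIndex for each layer. A hedged usage sketch, assuming kAllModels and ConfigFromModel are declared by the headers below as they are used in gemma/tensor_index_test.cc (the BuildAllWeightPtrs function itself is hypothetical):

#include "gemma/configs.h"       // ModelConfig; assumed to declare Model, kAllModels, ConfigFromModel
#include "gemma/tensor_index.h"  // TensorIndex
#include "gemma/weights.h"       // ModelWeightsPtrs, LayerWeightsPtrs

namespace gcpp {

// Hypothetical helper: constructs weight pointers for every known model.
void BuildAllWeightPtrs() {
  for (Model model : kAllModels) {
    ModelConfig config = ConfigFromModel(model);

    // Model-level tensors: the new single-argument constructor delegates to
    // a TensorIndex built with llm_layer_idx = -1 and vit_layer_idx = -1.
    ModelWeightsPtrs<float> weights(config);

    // A single LLM layer can still be built directly by supplying its
    // per-layer TensorIndex, as backprop/backward_scalar_test.cc now does.
    if (!config.layer_configs.empty()) {
      TensorIndex tensor_index(config, /*llm_layer_idx=*/0,
                               /*img_layer_idx=*/-1, /*reshape_att=*/false);
      LayerWeightsPtrs<float> layer(config.layer_configs[0], tensor_index);
    }
  }
}

}  // namespace gcpp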
