Skip to content

Commit

Permalink
Moved the vit config fields to their own config struct
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709034086
  • Loading branch information
theraysmith authored and copybara-github committed Jan 14, 2025
1 parent 9d40f01 commit 93650ae
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 84 deletions.
56 changes: 33 additions & 23 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,19 @@ static LayerConfig LayerConfigVit(size_t model_dim) {

// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_model_dim = 1152;
config.vit_config.model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.image_size = image_size;
config.patch_width = 14;
config.vit_config.image_size = image_size;
config.vit_config.patch_width = 14;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
config.vit_config.seq_len = num_patches * num_patches;
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = false;
}
const size_t num_patches = config.image_size / config.patch_width;
config.vit_seq_len = num_patches * num_patches;
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
config.vit_layer_configs = {27, vit_layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim);
config.vit_config.layer_configs = {27, vit_layer_config};
config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
}

static ModelConfig ConfigPaliGemma_224() {
Expand All @@ -283,11 +284,11 @@ static ModelConfig ConfigPaliGemma_448() {
return config;
}

ModelConfig VitConfig(const ModelConfig& config) {
ModelConfig GetVitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_model_dim;
vit_config.seq_len = config.vit_seq_len;
vit_config.layer_configs = config.vit_layer_configs;
vit_config.model_dim = config.vit_config.model_dim;
vit_config.seq_len = config.vit_config.seq_len;
vit_config.layer_configs = config.vit_config.layer_configs;
// The Vit part does not have a vocabulary, the image patches are embedded.
vit_config.vocab_size = 0;
return vit_config;
Expand Down Expand Up @@ -402,9 +403,28 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
return result;
}

// Compares every VitConfig field against `other`; returns true if they match.
// If `partial` is true, fields that are only set after the tensors are loaded
// from the checkpoint (num_scales) are not checked.
// If `debug` is true, mismatches are reported — presumably to stderr via the
// TEST_EQUAL/RETURN_IF_NOT_EQUAL macros; their exact behavior is defined
// elsewhere in this file (TODO confirm).
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
                          bool debug) const {
  bool result = true;
  TEST_EQUAL(model_dim, other.model_dim);
  TEST_EQUAL(seq_len, other.seq_len);
  // num_scales is only known once checkpoint tensors are loaded, so it is
  // skipped in partial mode.
  if (!partial) {
    TEST_EQUAL(num_scales, other.num_scales);
  }
  TEST_EQUAL(patch_width, other.patch_width);
  TEST_EQUAL(image_size, other.image_size);
  // Sizes must agree before the element-wise comparison below is valid.
  RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
  for (size_t i = 0; i < layer_configs.size(); ++i) {
    result &=
        layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
  }
  return result;
}

bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_family_version, other.model_family_version);
// We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are.
if (debug) {
Expand All @@ -415,13 +435,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
}
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(vit_model_dim, other.vit_model_dim);
TEST_EQUAL(vocab_size, other.vocab_size);
TEST_EQUAL(seq_len, other.seq_len);
TEST_EQUAL(vit_seq_len, other.vit_seq_len);
if (!partial) {
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
TEST_EQUAL(num_vit_scales, other.num_vit_scales);
}
TEST_EQUAL(att_cap, other.att_cap);
TEST_EQUAL(final_cap, other.final_cap);
Expand All @@ -439,11 +456,6 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
}
RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
partial, debug);
}
if (!partial) {
if (scale_names != other.scale_names) {
result = false;
Expand All @@ -453,9 +465,7 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
}
}
TEST_EQUAL(norm_num_groups, other.norm_num_groups);
TEST_EQUAL(model_family_version, other.model_family_version);
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
vit_config.TestEqual(other.vit_config, partial, debug);
return result;
}

Expand Down
45 changes: 31 additions & 14 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,33 @@ struct LayerConfig : public IFields {
PostQKType post_qk = PostQKType::Rope;
};

// Dimensions related to image processing (SigLIP SoViT ViT, used in
// PaliGemma). Grouped into their own struct so the ViT configuration can be
// carried inside ModelConfig as a single field.
struct VitConfig : public IFields {
  // Returns true if *this and other are equal.
  // If partial is true, then we don't check for items that are only set after
  // the tensors are loaded from the checkpoint.
  // If debug is true, then we output the mismatched fields to stderr.
  bool TestEqual(const VitConfig& other, bool partial, bool debug) const;

  const char* Name() const override { return "VitConfig"; }

  // Visits every field for (de)serialization; the visitation order defines
  // the on-disk layout, so do not reorder these calls.
  void VisitFields(IFieldsVisitor& visitor) override {
    visitor(model_dim);
    visitor(seq_len);
    visitor(num_scales);
    visitor(patch_width);
    visitor(image_size);
    visitor(layer_configs);
  }

  uint32_t model_dim = 0;  // ViT embedding dimension.
  uint32_t seq_len = 0;    // Number of patch tokens (num_patches squared).
  uint32_t num_scales = 0;  // Only set after tensors are loaded (see
                            // TestEqual's `partial` handling).
  uint32_t patch_width = 14;  // Side length of a square patch, in pixels.
  uint32_t image_size = 224;  // Input image side length, in pixels.
  std::vector<LayerConfig> layer_configs;  // One entry per ViT layer.
};

struct ModelConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
Expand Down Expand Up @@ -277,40 +304,30 @@ struct ModelConfig : public IFields {
visitor(layer_configs);
visitor(attention_window_sizes);
visitor(norm_num_groups);
visitor(vit_model_dim);
visitor(vit_seq_len);
visitor(num_vit_scales);
visitor(vit_layer_configs);
visitor(patch_width);
visitor(image_size);
visitor(vit_config);
}

uint32_t model_family_version = 1;
std::string model_name;
Model model = Model::UNKNOWN;
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
Type weight = Type::kUnknown;
uint32_t num_layers = 0;
uint32_t model_dim = 0;
uint32_t vit_model_dim = 0;
uint32_t vocab_size = 0;
uint32_t seq_len = 0;
uint32_t vit_seq_len = 0;
uint32_t num_tensor_scales = 0;
uint32_t num_vit_scales = 0;
float att_cap = 0.0f;
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false; // griffin only
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<uint32_t> attention_window_sizes;
std::vector<LayerConfig> vit_layer_configs;
std::unordered_set<std::string> scale_names;
uint32_t norm_num_groups = 1;
uint32_t model_family_version = 1;
// Dimensions related to image processing.
uint32_t patch_width = 14;
uint32_t image_size = 224;
VitConfig vit_config;
};

// Returns the config for the given model.
Expand All @@ -320,7 +337,7 @@ ModelConfig ConfigFromModel(Model model);
Model ModelFromConfig(const ModelConfig& config);

// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config);
ModelConfig GetVitConfig(const ModelConfig& config);

} // namespace gcpp

Expand Down
11 changes: 6 additions & 5 deletions gemma/configs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,13 @@ template <class TConfig>
void AssertMatch(const ModelConfig& config) {
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
if constexpr (TConfig::VitConfig::kModelDim != 0) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
config.vit_config.num_scales);
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
config.vit_layer_configs[i].type);
config.vit_config.layer_configs[i].type);
}
}
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
Expand Down
17 changes: 9 additions & 8 deletions gemma/gemma-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1042,9 +1042,9 @@ template <typename T>
HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelWeightsPtrs<T>& weights,
Activations& activations) {
const size_t model_dim = weights.weights_config.vit_model_dim;
const size_t patch_width = weights.weights_config.patch_width;
const size_t seq_len = weights.weights_config.vit_seq_len;
const size_t model_dim = weights.weights_config.vit_config.model_dim;
const size_t patch_width = weights.weights_config.vit_config.patch_width;
const size_t seq_len = weights.weights_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
patch_size * model_dim);
Expand Down Expand Up @@ -1087,14 +1087,15 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
const Image& image, ImageTokens& image_tokens,
Activations& activations) {
PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = weights.weights_config.vit_seq_len;
const size_t vit_model_dim = weights.weights_config.vit_model_dim;
const size_t num_tokens = weights.weights_config.vit_config.seq_len;
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.BatchSize());
// Embed the image patches.
EmbedImagePatches(image, weights, activations);
// Go through all layers.
for (size_t layer = 0;
layer < weights.weights_config.vit_layer_configs.size(); ++layer) {
layer < weights.weights_config.vit_config.layer_configs.size();
++layer) {
const auto* layer_weights = weights.GetVitLayer(layer);
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
}
Expand Down Expand Up @@ -1413,11 +1414,11 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
NestedPools& pools) {
if (model.Config().vit_layer_configs.empty()) {
if (model.Config().vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = VitConfig(model.Config());
ModelConfig vit_config = GetVitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, pools);
Expand Down
7 changes: 4 additions & 3 deletions gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
Image image;
ImageTokens image_tokens;
if (have_image) {
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim));
image_tokens =
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().image_size;
const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
Expand Down
Loading

0 comments on commit 93650ae

Please sign in to comment.