Skip to content

Commit f163a1c

Browse files
theraysmith and copybara-github
authored and committed
Moved the vit config fields to their own config struct
PiperOrigin-RevId: 709034086
1 parent 9d40f01 commit f163a1c

File tree

10 files changed

+117
-84
lines changed

10 files changed

+117
-84
lines changed

gemma/configs.cc

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -253,18 +253,19 @@ static LayerConfig LayerConfigVit(size_t model_dim) {
253253

254254
// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
255255
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
256-
config.vit_model_dim = 1152;
256+
config.vit_config.model_dim = 1152;
257257
config.vocab_size = 256000 + 1024 + 128; // = 257152
258-
config.image_size = image_size;
259-
config.patch_width = 14;
258+
config.vit_config.image_size = image_size;
259+
config.vit_config.patch_width = 14;
260+
const size_t num_patches =
261+
config.vit_config.image_size / config.vit_config.patch_width;
262+
config.vit_config.seq_len = num_patches * num_patches;
260263
for (auto& layer_config : config.layer_configs) {
261264
layer_config.optimized_gating = false;
262265
}
263-
const size_t num_patches = config.image_size / config.patch_width;
264-
config.vit_seq_len = num_patches * num_patches;
265-
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
266-
config.vit_layer_configs = {27, vit_layer_config};
267-
config.num_vit_scales = 4 * config.vit_layer_configs.size();
266+
LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim);
267+
config.vit_config.layer_configs = {27, vit_layer_config};
268+
config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
268269
}
269270

270271
static ModelConfig ConfigPaliGemma_224() {
@@ -283,11 +284,11 @@ static ModelConfig ConfigPaliGemma_448() {
283284
return config;
284285
}
285286

286-
ModelConfig VitConfig(const ModelConfig& config) {
287+
ModelConfig GetVitConfig(const ModelConfig& config) {
287288
ModelConfig vit_config = ConfigNoSSM();
288-
vit_config.model_dim = config.vit_model_dim;
289-
vit_config.seq_len = config.vit_seq_len;
290-
vit_config.layer_configs = config.vit_layer_configs;
289+
vit_config.model_dim = config.vit_config.model_dim;
290+
vit_config.seq_len = config.vit_config.seq_len;
291+
vit_config.layer_configs = config.vit_config.layer_configs;
291292
// The Vit part does not have a vocabulary, the image patches are embedded.
292293
vit_config.vocab_size = 0;
293294
return vit_config;
@@ -402,9 +403,28 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
402403
return result;
403404
}
404405

406+
// Compares the ViT settings of *this against `other` and returns true when
// all checked fields match.
// NOTE(review): TEST_EQUAL and RETURN_IF_NOT_EQUAL are macros defined outside
// this view; presumably TEST_EQUAL clears `result` (and reports when `debug`
// is set) on a mismatch, and RETURN_IF_NOT_EQUAL returns early - confirm.
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
407+
bool debug) const {
408+
bool result = true;
409+
TEST_EQUAL(model_dim, other.model_dim);
410+
TEST_EQUAL(seq_len, other.seq_len);
411+
// num_scales is skipped for partial comparisons.
if (!partial) {
412+
TEST_EQUAL(num_scales, other.num_scales);
413+
}
414+
TEST_EQUAL(patch_width, other.patch_width);
415+
TEST_EQUAL(image_size, other.image_size);
416+
// The layer counts are checked first because the loop below indexes
// other.layer_configs with this object's indices.
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
417+
for (size_t i = 0; i < layer_configs.size(); ++i) {
418+
// Accumulate per-layer results so one mismatching layer fails the whole
// comparison while the remaining layers are still compared.
result &=
419+
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
420+
}
421+
return result;
422+
}
423+
405424
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
406425
bool debug) const {
407426
bool result = true;
427+
TEST_EQUAL(model_family_version, other.model_family_version);
408428
// We don't care about model_name, model, wrapping, or weight being different,
409429
// but will output in debug mode if they are.
410430
if (debug) {
@@ -415,13 +435,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
415435
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
416436
}
417437
TEST_EQUAL(model_dim, other.model_dim);
418-
TEST_EQUAL(vit_model_dim, other.vit_model_dim);
419438
TEST_EQUAL(vocab_size, other.vocab_size);
420439
TEST_EQUAL(seq_len, other.seq_len);
421-
TEST_EQUAL(vit_seq_len, other.vit_seq_len);
422440
if (!partial) {
423441
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
424-
TEST_EQUAL(num_vit_scales, other.num_vit_scales);
425442
}
426443
TEST_EQUAL(att_cap, other.att_cap);
427444
TEST_EQUAL(final_cap, other.final_cap);
@@ -439,11 +456,6 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
439456
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
440457
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
441458
}
442-
RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
443-
for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
444-
result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
445-
partial, debug);
446-
}
447459
if (!partial) {
448460
if (scale_names != other.scale_names) {
449461
result = false;
@@ -453,9 +465,7 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
453465
}
454466
}
455467
TEST_EQUAL(norm_num_groups, other.norm_num_groups);
456-
TEST_EQUAL(model_family_version, other.model_family_version);
457-
TEST_EQUAL(patch_width, other.patch_width);
458-
TEST_EQUAL(image_size, other.image_size);
468+
vit_config.TestEqual(other.vit_config, partial, debug);
459469
return result;
460470
}
461471

gemma/configs.h

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,33 @@ struct LayerConfig : public IFields {
220220
PostQKType post_qk = PostQKType::Rope;
221221
};
222222

223+
// Dimensions related to image processing.
224+
struct VitConfig : public IFields {
225+
// Returns true if *this and other are equal.
226+
// If partial is true, then we don't check for items that are only set after
227+
// the tensors are loaded from the checkpoint.
228+
// If debug is true, then we output the mismatched fields to stderr.
229+
bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
230+
231+
const char* Name() const override { return "VitConfig"; }
232+
233+
// Passes every data member to `visitor`, in declaration order.
// NOTE(review): presumably this order defines the serialized layout -
// confirm before reordering or inserting fields.
void VisitFields(IFieldsVisitor& visitor) override {
234+
visitor(model_dim);
235+
visitor(seq_len);
236+
visitor(num_scales);
237+
visitor(patch_width);
238+
visitor(image_size);
239+
visitor(layer_configs);
240+
}
241+
242+
// Model dimension of the ViT; AddVitConfig (configs.cc) sets it to 1152.
uint32_t model_dim = 0;
243+
// Number of image patch tokens; AddVitConfig computes it as
// (image_size / patch_width) squared.
uint32_t seq_len = 0;
244+
// AddVitConfig sets this to 4 * layer_configs.size(); TestEqual compares it
// only when `partial` is false.
uint32_t num_scales = 0;
245+
// Side length of a square image patch; EmbedImagePatches uses
// patch_width * patch_width * 3 values per patch.
uint32_t patch_width = 14;
246+
// Side length the input image is resized to (run.cc calls
// image.Resize(image_size, image_size)).
uint32_t image_size = 224;
247+
// Per-layer configuration of the ViT; empty when the model has no ViT
// (GenerateImageTokensT aborts in that case).
std::vector<LayerConfig> layer_configs;
248+
};
249+
223250
struct ModelConfig : public IFields {
224251
// Returns true if *this and other are equal.
225252
// If partial is true, then we don't check for items that are only set after
@@ -277,40 +304,30 @@ struct ModelConfig : public IFields {
277304
visitor(layer_configs);
278305
visitor(attention_window_sizes);
279306
visitor(norm_num_groups);
280-
visitor(vit_model_dim);
281-
visitor(vit_seq_len);
282-
visitor(num_vit_scales);
283-
visitor(vit_layer_configs);
284-
visitor(patch_width);
285-
visitor(image_size);
307+
visitor(vit_config);
286308
}
287309

310+
uint32_t model_family_version = 1;
288311
std::string model_name;
289312
Model model = Model::UNKNOWN;
290313
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
291314
Type weight = Type::kUnknown;
292315
uint32_t num_layers = 0;
293316
uint32_t model_dim = 0;
294-
uint32_t vit_model_dim = 0;
295317
uint32_t vocab_size = 0;
296318
uint32_t seq_len = 0;
297-
uint32_t vit_seq_len = 0;
298319
uint32_t num_tensor_scales = 0;
299-
uint32_t num_vit_scales = 0;
300320
float att_cap = 0.0f;
301321
float final_cap = 0.0f;
302322
bool absolute_pe = false;
303323
bool use_local_attention = false; // griffin only
304324
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
305325
std::vector<LayerConfig> layer_configs;
306326
std::vector<uint32_t> attention_window_sizes;
307-
std::vector<LayerConfig> vit_layer_configs;
308327
std::unordered_set<std::string> scale_names;
309328
uint32_t norm_num_groups = 1;
310-
uint32_t model_family_version = 1;
311329
// Dimensions related to image processing.
312-
uint32_t patch_width = 14;
313-
uint32_t image_size = 224;
330+
VitConfig vit_config;
314331
};
315332

316333
// Returns the config for the given model.
@@ -320,7 +337,7 @@ ModelConfig ConfigFromModel(Model model);
320337
Model ModelFromConfig(const ModelConfig& config);
321338

322339
// Returns the sub-config for the ViT model of the PaliGemma model.
323-
ModelConfig VitConfig(const ModelConfig& config);
340+
ModelConfig GetVitConfig(const ModelConfig& config);
324341

325342
} // namespace gcpp
326343

gemma/configs_test.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,13 @@ template <class TConfig>
367367
void AssertMatch(const ModelConfig& config) {
368368
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
369369
if constexpr (TConfig::VitConfig::kModelDim != 0) {
370-
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
371-
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
372-
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
373-
for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
370+
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
371+
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
372+
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
373+
config.vit_config.num_scales);
374+
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
374375
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
375-
config.vit_layer_configs[i].type);
376+
config.vit_config.layer_configs[i].type);
376377
}
377378
}
378379
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);

gemma/gemma-inl.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,9 +1042,9 @@ template <typename T>
10421042
HWY_NOINLINE void EmbedImagePatches(const Image& image,
10431043
const ModelWeightsPtrs<T>& weights,
10441044
Activations& activations) {
1045-
const size_t model_dim = weights.weights_config.vit_model_dim;
1046-
const size_t patch_width = weights.weights_config.patch_width;
1047-
const size_t seq_len = weights.weights_config.vit_seq_len;
1045+
const size_t model_dim = weights.weights_config.vit_config.model_dim;
1046+
const size_t patch_width = weights.weights_config.vit_config.patch_width;
1047+
const size_t seq_len = weights.weights_config.vit_config.seq_len;
10481048
const size_t patch_size = patch_width * patch_width * 3;
10491049
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
10501050
patch_size * model_dim);
@@ -1087,14 +1087,15 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
10871087
const Image& image, ImageTokens& image_tokens,
10881088
Activations& activations) {
10891089
PROFILER_ZONE("Gen.PrefillVit");
1090-
const size_t num_tokens = weights.weights_config.vit_seq_len;
1091-
const size_t vit_model_dim = weights.weights_config.vit_model_dim;
1090+
const size_t num_tokens = weights.weights_config.vit_config.seq_len;
1091+
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
10921092
HWY_ASSERT(num_tokens == activations.x.BatchSize());
10931093
// Embed the image patches.
10941094
EmbedImagePatches(image, weights, activations);
10951095
// Go through all layers.
10961096
for (size_t layer = 0;
1097-
layer < weights.weights_config.vit_layer_configs.size(); ++layer) {
1097+
layer < weights.weights_config.vit_config.layer_configs.size();
1098+
++layer) {
10981099
const auto* layer_weights = weights.GetVitLayer(layer);
10991100
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
11001101
}
@@ -1413,11 +1414,11 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
14131414
const RuntimeConfig& runtime_config,
14141415
const Image& image, ImageTokens& image_tokens,
14151416
NestedPools& pools) {
1416-
if (model.Config().vit_layer_configs.empty()) {
1417+
if (model.Config().vit_config.layer_configs.empty()) {
14171418
HWY_ABORT("Model does not support generating image tokens.");
14181419
}
14191420
RuntimeConfig prefill_runtime_config = runtime_config;
1420-
ModelConfig vit_config = VitConfig(model.Config());
1421+
ModelConfig vit_config = GetVitConfig(model.Config());
14211422
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
14221423
Activations prefill_activations(vit_config);
14231424
prefill_activations.Allocate(vit_config.seq_len, pools);

gemma/run.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
9494
Image image;
9595
ImageTokens image_tokens;
9696
if (have_image) {
97-
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
98-
model.GetModelConfig().model_dim));
97+
image_tokens =
98+
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
99+
model.GetModelConfig().model_dim));
99100
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
100101
HWY_ASSERT(image.ReadPPM(args.image_file.path));
101-
const size_t image_size = model.GetModelConfig().image_size;
102+
const size_t image_size = model.GetModelConfig().vit_config.image_size;
102103
image.Resize(image_size, image_size);
103104
RuntimeConfig runtime_config = {
104105
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};

0 commit comments

Comments
 (0)