From 93650aeb048603f38c21e23c52d39f1091191dde Mon Sep 17 00:00:00 2001
From: Ray Smith
Date: Mon, 23 Dec 2024 05:21:33 -0800
Subject: [PATCH] Moved the vit config fields to their own config struct

PiperOrigin-RevId: 709034086
---
 gemma/configs.cc            | 56 ++++++++++++++++++++++---------------
 gemma/configs.h             | 45 +++++++++++++++++++----------
 gemma/configs_test.cc       | 11 ++++----
 gemma/gemma-inl.h           | 17 +++++------
 gemma/run.cc                |  7 +++--
 gemma/tensor_index.cc       | 45 ++++++++++++++---------------
 gemma/tensor_index_test.cc  |  2 +-
 gemma/weights.cc            |  2 +-
 gemma/weights.h             |  9 +++---
 paligemma/paligemma_test.cc |  7 +++--
 10 files changed, 117 insertions(+), 84 deletions(-)

diff --git a/gemma/configs.cc b/gemma/configs.cc
index 7cb4574..89f8f8d 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -253,18 +253,19 @@ static LayerConfig LayerConfigVit(size_t model_dim) {
 
 // Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
 static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
-  config.vit_model_dim = 1152;
+  config.vit_config.model_dim = 1152;
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
-  config.image_size = image_size;
-  config.patch_width = 14;
+  config.vit_config.image_size = image_size;
+  config.vit_config.patch_width = 14;
+  const size_t num_patches =
+      config.vit_config.image_size / config.vit_config.patch_width;
+  config.vit_config.seq_len = num_patches * num_patches;
   for (auto& layer_config : config.layer_configs) {
     layer_config.optimized_gating = false;
   }
-  const size_t num_patches = config.image_size / config.patch_width;
-  config.vit_seq_len = num_patches * num_patches;
-  LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
-  config.vit_layer_configs = {27, vit_layer_config};
-  config.num_vit_scales = 4 * config.vit_layer_configs.size();
+  LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim);
+  config.vit_config.layer_configs = {27, vit_layer_config};
+  config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
 }
 
 static ModelConfig ConfigPaliGemma_224() {
@@ -283,11 +284,11 @@ static ModelConfig ConfigPaliGemma_448() {
   return config;
 }
 
-ModelConfig VitConfig(const ModelConfig& config) {
+ModelConfig GetVitConfig(const ModelConfig& config) {
   ModelConfig vit_config = ConfigNoSSM();
-  vit_config.model_dim = config.vit_model_dim;
-  vit_config.seq_len = config.vit_seq_len;
-  vit_config.layer_configs = config.vit_layer_configs;
+  vit_config.model_dim = config.vit_config.model_dim;
+  vit_config.seq_len = config.vit_config.seq_len;
+  vit_config.layer_configs = config.vit_config.layer_configs;
   // The Vit part does not have a vocabulary, the image patches are embedded.
   vit_config.vocab_size = 0;
   return vit_config;
@@ -402,9 +403,28 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
   return result;
 }
 
+bool VitConfig::TestEqual(const VitConfig& other, bool partial,
+                          bool debug) const {
+  bool result = true;
+  TEST_EQUAL(model_dim, other.model_dim);
+  TEST_EQUAL(seq_len, other.seq_len);
+  if (!partial) {
+    TEST_EQUAL(num_scales, other.num_scales);
+  }
+  TEST_EQUAL(patch_width, other.patch_width);
+  TEST_EQUAL(image_size, other.image_size);
+  RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
+  for (size_t i = 0; i < layer_configs.size(); ++i) {
+    result &=
+        layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
+  }
+  return result;
+}
+
 bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
                             bool debug) const {
   bool result = true;
+  TEST_EQUAL(model_family_version, other.model_family_version);
   // We don't care about model_name, model, wrapping, or weight being different,
   // but will output in debug mode if they are.
   if (debug) {
@@ -415,13 +435,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
     WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
   }
   TEST_EQUAL(model_dim, other.model_dim);
-  TEST_EQUAL(vit_model_dim, other.vit_model_dim);
   TEST_EQUAL(vocab_size, other.vocab_size);
   TEST_EQUAL(seq_len, other.seq_len);
-  TEST_EQUAL(vit_seq_len, other.vit_seq_len);
   if (!partial) {
     TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
-    TEST_EQUAL(num_vit_scales, other.num_vit_scales);
   }
   TEST_EQUAL(att_cap, other.att_cap);
   TEST_EQUAL(final_cap, other.final_cap);
@@ -439,11 +456,6 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
   for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
     TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
   }
-  RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
-  for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
-    result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
-                                             partial, debug);
-  }
   if (!partial) {
     if (scale_names != other.scale_names) {
       result = false;
@@ -453,9 +465,7 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
     }
   }
   TEST_EQUAL(norm_num_groups, other.norm_num_groups);
-  TEST_EQUAL(model_family_version, other.model_family_version);
-  TEST_EQUAL(patch_width, other.patch_width);
-  TEST_EQUAL(image_size, other.image_size);
+  vit_config.TestEqual(other.vit_config, partial, debug);
   return result;
 }
 
diff --git a/gemma/configs.h b/gemma/configs.h
index aad5d32..b353a51 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -220,6 +220,33 @@ struct LayerConfig : public IFields {
   PostQKType post_qk = PostQKType::Rope;
 };
 
+// Dimensions related to image processing.
+struct VitConfig : public IFields {
+  // Returns true if *this and other are equal.
+  // If partial is true, then we don't check for items that are only set after
+  // the tensors are loaded from the checkpoint.
+  // If debug is true, then we output the mismatched fields to stderr.
+  bool TestEqual(const VitConfig& other, bool partial, bool debug) const;
+
+  const char* Name() const override { return "VitConfig"; }
+
+  void VisitFields(IFieldsVisitor& visitor) override {
+    visitor(model_dim);
+    visitor(seq_len);
+    visitor(num_scales);
+    visitor(patch_width);
+    visitor(image_size);
+    visitor(layer_configs);
+  }
+
+  uint32_t model_dim = 0;
+  uint32_t seq_len = 0;
+  uint32_t num_scales = 0;
+  uint32_t patch_width = 14;
+  uint32_t image_size = 224;
+  std::vector<LayerConfig> layer_configs;
+};
+
 struct ModelConfig : public IFields {
   // Returns true if *this and other are equal.
   // If partial is true, then we don't check for items that are only set after
@@ -277,26 +304,19 @@ struct ModelConfig : public IFields {
     visitor(layer_configs);
     visitor(attention_window_sizes);
     visitor(norm_num_groups);
-    visitor(vit_model_dim);
-    visitor(vit_seq_len);
-    visitor(num_vit_scales);
-    visitor(vit_layer_configs);
-    visitor(patch_width);
-    visitor(image_size);
+    visitor(vit_config);
   }
 
+  uint32_t model_family_version = 1;
   std::string model_name;
   Model model = Model::UNKNOWN;
   PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
   Type weight = Type::kUnknown;
   uint32_t num_layers = 0;
   uint32_t model_dim = 0;
-  uint32_t vit_model_dim = 0;
   uint32_t vocab_size = 0;
   uint32_t seq_len = 0;
-  uint32_t vit_seq_len = 0;
   uint32_t num_tensor_scales = 0;
-  uint32_t num_vit_scales = 0;
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;
@@ -304,13 +324,10 @@
   QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
   std::vector<LayerConfig> layer_configs;
   std::vector<uint32_t> attention_window_sizes;
-  std::vector<LayerConfig> vit_layer_configs;
   std::unordered_set<std::string> scale_names;
   uint32_t norm_num_groups = 1;
-  uint32_t model_family_version = 1;
   // Dimensions related to image processing.
-  uint32_t patch_width = 14;
-  uint32_t image_size = 224;
+  VitConfig vit_config;
 };
 
 // Returns the config for the given model.
@@ -320,7 +337,7 @@ ModelConfig ConfigFromModel(Model model);
 Model ModelFromConfig(const ModelConfig& config);
 
 // Returns the sub-config for the ViT model of the PaliGemma model.
-ModelConfig VitConfig(const ModelConfig& config);
+ModelConfig GetVitConfig(const ModelConfig& config);
 
 }  // namespace gcpp
 
diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc
index 456d5fa..3efd2cb 100644
--- a/gemma/configs_test.cc
+++ b/gemma/configs_test.cc
@@ -367,12 +367,13 @@ template <typename TConfig>
 void AssertMatch(const ModelConfig& config) {
   ASSERT_EQ(TConfig::kModelDim, config.model_dim);
   if constexpr (TConfig::VitConfig::kModelDim != 0) {
-    ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
-    ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
-    ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
-    for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
+    ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
+    ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
+    ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
+              config.vit_config.num_scales);
+    for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
       ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
-                config.vit_layer_configs[i].type);
+                config.vit_config.layer_configs[i].type);
     }
   }
   ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 81a9469..e1deea1 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1042,9 +1042,9 @@ template <typename T>
 HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                     const ModelWeightsPtrs<T>& weights,
                                     Activations& activations) {
-  const size_t model_dim = weights.weights_config.vit_model_dim;
-  const size_t patch_width = weights.weights_config.patch_width;
-  const size_t seq_len = weights.weights_config.vit_seq_len;
+  const size_t model_dim = weights.weights_config.vit_config.model_dim;
+  const size_t patch_width = weights.weights_config.vit_config.patch_width;
+  const size_t seq_len = weights.weights_config.vit_config.seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
               patch_size * model_dim);
@@ -1087,14 +1087,15 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                              const Image& image, ImageTokens& image_tokens,
                              Activations& activations) {
   PROFILER_ZONE("Gen.PrefillVit");
-  const size_t num_tokens = weights.weights_config.vit_seq_len;
-  const size_t vit_model_dim = weights.weights_config.vit_model_dim;
+  const size_t num_tokens = weights.weights_config.vit_config.seq_len;
+  const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
   HWY_ASSERT(num_tokens == activations.x.BatchSize());
   // Embed the image patches.
   EmbedImagePatches(image, weights, activations);
   // Go through all layers.
   for (size_t layer = 0;
-       layer < weights.weights_config.vit_layer_configs.size(); ++layer) {
+       layer < weights.weights_config.vit_config.layer_configs.size();
+       ++layer) {
     const auto* layer_weights = weights.GetVitLayer(layer);
     VitTransformerLayer(num_tokens, layer, layer_weights, activations);
   }
@@ -1413,11 +1414,11 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
                           const RuntimeConfig& runtime_config,
                           const Image& image, ImageTokens& image_tokens,
                           NestedPools& pools) {
-  if (model.Config().vit_layer_configs.empty()) {
+  if (model.Config().vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
   }
   RuntimeConfig prefill_runtime_config = runtime_config;
-  ModelConfig vit_config = VitConfig(model.Config());
+  ModelConfig vit_config = GetVitConfig(model.Config());
   prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
   Activations prefill_activations(vit_config);
   prefill_activations.Allocate(vit_config.seq_len, pools);
diff --git a/gemma/run.cc b/gemma/run.cc
index b7082e4..712bdcb 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -94,11 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   Image image;
   ImageTokens image_tokens;
   if (have_image) {
-    image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
-                                         model.GetModelConfig().model_dim));
+    image_tokens =
+        ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
+                              model.GetModelConfig().model_dim));
     HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
-    const size_t image_size = model.GetModelConfig().image_size;
+    const size_t image_size = model.GetModelConfig().vit_config.image_size;
     image.Resize(image_size, image_size);
     RuntimeConfig runtime_config = {
         .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc
index 50c7af0..354a1b4 100644
--- a/gemma/tensor_index.cc
+++ b/gemma/tensor_index.cc
@@ -36,29 +36,29 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
           .name = "enc_norm_bias",
           .source_names = {"img/Transformer/encoder_norm/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "enc_norm_scale",
           .source_names = {"img/Transformer/encoder_norm/scale"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "img_emb_bias",
           .source_names = {"img/embedding/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
       TensorInfo{
           .name = "img_emb_kernel",
           .source_names = {"img/embedding/kernel"},
           .axes = {3, 0, 1, 2},
-          .shape = {config.vit_model_dim, config.patch_width,
-                    config.patch_width, 3},
+          .shape = {config.vit_config.model_dim, config.vit_config.patch_width,
+                    config.vit_config.patch_width, 3},
           .min_size = Type::kBF16,
           .cols_take_extra_dims = true,
       },
@@ -73,14 +73,15 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
           .name = "img_head_kernel",
           .source_names = {"img/head/kernel"},
           .axes = {1, 0},
-          .shape = {config.model_dim, config.vit_model_dim},
+          .shape = {config.model_dim, config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "img_pos_emb",
           .source_names = {"img/pos_embedding"},
           .axes = {0, 1},
-          .shape = {/*1,*/ config.vit_seq_len, config.vit_model_dim},
+          .shape = {/*1,*/ config.vit_config.seq_len,
+                    config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
   };
@@ -95,7 +96,7 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
           .name = "attn_out_w",
           .source_names = {"MultiHeadDotProductAttention_0/out/kernel"},
           .axes = {2, 0, 1},
-          .shape = {config.vit_model_dim, layer_config.heads,
+          .shape = {config.vit_config.model_dim, layer_config.heads,
                     layer_config.qkv_dim},
           .min_size = Type::kBF16,
           .cols_take_extra_dims = true,
@@ -104,7 +105,7 @@
           .name = "attn_out_b",
           .source_names = {"MultiHeadDotProductAttention_0/out/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
       TensorInfo{
@@ -112,7 +113,7 @@
           .source_names = {"MultiHeadDotProductAttention_0/query/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"},
           .concat_axis = 1,
           .min_size = Type::kBF16,
@@ -122,7 +123,7 @@
           .source_names = {"MultiHeadDotProductAttention_0/key/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .concat_names = {""},
           .min_size = Type::kBF16,
       },
@@ -131,7 +132,7 @@
           .source_names = {"MultiHeadDotProductAttention_0/value/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .concat_names = {""},
           .min_size = Type::kBF16,
       },
@@ -140,7 +141,7 @@
           .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"},
           .axes = {1, 2, 0},
           .shape = {layer_config.heads, 3 * layer_config.qkv_dim,
-                    config.vit_model_dim},
+                    config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
@@ -180,7 +181,7 @@
           .name = "linear_0_w",
           .source_names = {"MlpBlock_0/Dense_0/kernel"},
           .axes = {1, 0},
-          .shape = {layer_config.ff_hidden_dim, config.vit_model_dim},
+          .shape = {layer_config.ff_hidden_dim, config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
@@ -194,42 +195,42 @@
           .name = "linear_1_w",
           .source_names = {"MlpBlock_0/Dense_1/kernel"},
           .axes = {1, 0},
-          .shape = {config.vit_model_dim, layer_config.ff_hidden_dim},
+          .shape = {config.vit_config.model_dim, layer_config.ff_hidden_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "linear_1_b",
           .source_names = {"MlpBlock_0/Dense_1/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kF32,
       },
       TensorInfo{
           .name = "ln_0_bias",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "ln_0_scale",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "ln_1_bias",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
       TensorInfo{
           .name = "ln_1_scale",
           .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
           .axes = {0},
-          .shape = {config.vit_model_dim},
+          .shape = {config.vit_config.model_dim},
           .min_size = Type::kBF16,
       },
   };
@@ -526,8 +527,8 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
   if (llm_layer_idx < 0 && img_layer_idx < 0) {
     tensors_ = ModelTensors(config);
   } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
-             img_layer_idx < config.vit_layer_configs.size()) {
-    const auto& layer_config = config.vit_layer_configs[img_layer_idx];
+             img_layer_idx < config.vit_config.layer_configs.size()) {
+    const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
     tensors_ = ImageLayerTensors(config, layer_config);
   } else if (0 <= llm_layer_idx &&
              llm_layer_idx < config.layer_configs.size()) {
diff --git a/gemma/tensor_index_test.cc b/gemma/tensor_index_test.cc
index 43eaa4e..50ff0b6 100644
--- a/gemma/tensor_index_test.cc
+++ b/gemma/tensor_index_test.cc
@@ -35,7 +35,7 @@ TEST(TensorIndexTest, FindName) {
                                   /*split_and_reshape=*/false);
     }
     for (size_t img_layer_idx = 0;
-         img_layer_idx < config.vit_layer_configs.size();
+         img_layer_idx < config.vit_config.layer_configs.size();
          ++img_layer_idx) {
       tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1,
                                   static_cast<int>(img_layer_idx),
                                   /*split_and_reshape=*/false);
     }
diff --git a/gemma/weights.cc b/gemma/weights.cc
index 568edb8..426de6d 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -86,7 +86,7 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
     config_ = ConfigFromModel(model_type);
     config_.weight = weight_type;
     config_.wrapping = wrapping;
-    scales.resize(config_.num_tensor_scales + config_.num_vit_scales);
+    scales.resize(config_.num_tensor_scales + config_.vit_config.num_scales);
   }
   CreateForType(config_.weight, pool);
   CallForModelWeightT(fet, loader);
diff --git a/gemma/weights.h b/gemma/weights.h
index 33da842..8cb4bce 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -344,8 +344,9 @@ struct ModelWeightsPtrs {
       c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config, tensor_index));
     }
     for (int index = 0;
-         index < static_cast<int>(config.vit_layer_configs.size()); ++index) {
-      const auto& layer_config = config.vit_layer_configs[index];
+         index < static_cast<int>(config.vit_config.layer_configs.size());
+         ++index) {
+      const auto& layer_config = config.vit_config.layer_configs[index];
       TensorIndex tensor_index(config, /*llm_layer_idx=*/-1, index,
                                /*reshape_att=*/false);
       vit_layers.push_back(
@@ -479,7 +480,7 @@ struct ModelWeightsPtrs {
     int sep_index = -1;
     GEMMA_CALL_FUNC(embedder_input_embedding);
     GEMMA_CALL_FUNC(final_norm_scale);
-    if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
+    if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
       // Vit parts.
       GEMMA_CALL_FUNC(vit_encoder_norm_bias);
       GEMMA_CALL_FUNC(vit_encoder_norm_scale);
@@ -498,7 +499,7 @@ struct ModelWeightsPtrs {
     }
 
     // Vit layers. Not supported for compress_weights.
-    if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
+    if (ptrs[0]->weights_config.vit_config.layer_configs.size() > 0) {
       for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size();
            ++layer_idx) {
         auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type;
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index f5d5304..b28849e 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -50,12 +50,13 @@ class PaliGemmaTest : public ::testing::Test {
 void PaliGemmaTest::InitVit(const std::string& path) {
   ASSERT_NE(s_env->GetModel(), nullptr);
   Gemma& model = *(s_env->GetModel());
-  image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
-                                        model.GetModelConfig().model_dim));
+  image_tokens_ =
+      ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
+                            model.GetModelConfig().model_dim));
   Image image;
   HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
-  const size_t image_size = model.GetModelConfig().image_size;
+  const size_t image_size = model.GetModelConfig().vit_config.image_size;
   image.Resize(image_size, image_size);
   RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
   model.GenerateImageTokens(runtime_config, image, image_tokens_);