Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved the vit config fields to their own config struct #471

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,19 @@ static LayerConfig LayerConfigVit(size_t model_dim) {

// Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config.
static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
config.vit_model_dim = 1152;
config.vit_config.model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.image_size = image_size;
config.patch_width = 14;
config.vit_config.image_size = image_size;
config.vit_config.patch_width = 14;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
config.vit_config.seq_len = num_patches * num_patches;
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = false;
}
const size_t num_patches = config.image_size / config.patch_width;
config.vit_seq_len = num_patches * num_patches;
LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim);
config.vit_layer_configs = {27, vit_layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim);
config.vit_config.layer_configs = {27, vit_layer_config};
config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size();
}

static ModelConfig ConfigPaliGemma_224() {
Expand All @@ -283,11 +284,11 @@ static ModelConfig ConfigPaliGemma_448() {
return config;
}

ModelConfig VitConfig(const ModelConfig& config) {
ModelConfig GetVitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_model_dim;
vit_config.seq_len = config.vit_seq_len;
vit_config.layer_configs = config.vit_layer_configs;
vit_config.model_dim = config.vit_config.model_dim;
vit_config.seq_len = config.vit_config.seq_len;
vit_config.layer_configs = config.vit_config.layer_configs;
// The Vit part does not have a vocabulary, the image patches are embedded.
vit_config.vocab_size = 0;
return vit_config;
Expand Down Expand Up @@ -402,9 +403,28 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
return result;
}

// Compares this ViT config against `other`, field by field.
// When `partial` is true, num_scales is skipped because it is only known
// after tensors are loaded from the checkpoint (see the declaration in
// configs.h). The TEST_EQUAL / RETURN_IF_NOT_EQUAL macros presumably record
// mismatches into `result` (and report them when `debug` is set) — behavior
// depends on their expansion elsewhere in this file; TODO confirm.
bool VitConfig::TestEqual(const VitConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(seq_len, other.seq_len);
if (!partial) {
// Only meaningful once the checkpoint has been loaded.
TEST_EQUAL(num_scales, other.num_scales);
}
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
// Bail out early on a size mismatch so the element-wise loop below is safe.
RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size());
for (size_t i = 0; i < layer_configs.size(); ++i) {
// Accumulate per-layer comparison results; a single mismatch fails all.
result &=
layer_configs[i].TestEqual(other.layer_configs[i], partial, debug);
}
return result;
}

bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const {
bool result = true;
TEST_EQUAL(model_family_version, other.model_family_version);
// We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are.
if (debug) {
Expand All @@ -415,13 +435,10 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
}
TEST_EQUAL(model_dim, other.model_dim);
TEST_EQUAL(vit_model_dim, other.vit_model_dim);
TEST_EQUAL(vocab_size, other.vocab_size);
TEST_EQUAL(seq_len, other.seq_len);
TEST_EQUAL(vit_seq_len, other.vit_seq_len);
if (!partial) {
TEST_EQUAL(num_tensor_scales, other.num_tensor_scales);
TEST_EQUAL(num_vit_scales, other.num_vit_scales);
}
TEST_EQUAL(att_cap, other.att_cap);
TEST_EQUAL(final_cap, other.final_cap);
Expand All @@ -439,11 +456,6 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]);
}
RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size());
for (size_t i = 0; i < vit_layer_configs.size(); ++i) {
result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i],
partial, debug);
}
if (!partial) {
if (scale_names != other.scale_names) {
result = false;
Expand All @@ -453,9 +465,7 @@ bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
}
}
TEST_EQUAL(norm_num_groups, other.norm_num_groups);
TEST_EQUAL(model_family_version, other.model_family_version);
TEST_EQUAL(patch_width, other.patch_width);
TEST_EQUAL(image_size, other.image_size);
result &= vit_config.TestEqual(other.vit_config, partial, debug);
return result;
}

Expand Down
47 changes: 33 additions & 14 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,33 @@ struct LayerConfig : public IFields {
PostQKType post_qk = PostQKType::Rope;
};

// Dimensions related to image processing.
// Groups the ViT (vision transformer) fields that were previously flat
// members of ModelConfig; used by the SigLIP ViT image encoder in PaliGemma.
struct VitConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
// the tensors are loaded from the checkpoint.
// If debug is true, then we output the mismatched fields to stderr.
bool TestEqual(const VitConfig& other, bool partial, bool debug) const;

const char* Name() const override { return "VitConfig"; }

// Visitation order defines the on-disk field order for IFields
// serialization — do not reorder existing calls.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(model_dim);
visitor(seq_len);
visitor(num_scales);
visitor(patch_width);
visitor(image_size);
visitor(layer_configs);
}

// Embedding dimension of the ViT (e.g. 1152 for the PaliGemma SoViT).
uint32_t model_dim = 0;
// Number of image tokens = (image_size / patch_width)^2.
uint32_t seq_len = 0;
// Number of tensor scales: 4 * layer_configs.size(); set when configured.
uint32_t num_scales = 0;
// Side length in pixels of one square image patch.
uint32_t patch_width = 14;
// Input image side length in pixels (images are square).
uint32_t image_size = 224;
// One LayerConfig per ViT transformer layer (27 for PaliGemma).
std::vector<LayerConfig> layer_configs;
};

struct ModelConfig : public IFields {
// Returns true if *this and other are equal.
// If partial is true, then we don't check for items that are only set after
Expand Down Expand Up @@ -277,40 +304,32 @@ struct ModelConfig : public IFields {
visitor(layer_configs);
visitor(attention_window_sizes);
visitor(norm_num_groups);
visitor(vit_model_dim);
visitor(vit_seq_len);
visitor(num_vit_scales);
visitor(vit_layer_configs);
visitor(patch_width);
visitor(image_size);
visitor(vit_config);
}

// Major version of the model family. It is used as a fallback to distinguish
// between model types when there is no explicit information in the config.
uint32_t model_family_version = 1;
std::string model_name;
Model model = Model::UNKNOWN;
PromptWrapping wrapping = PromptWrapping::GEMMA_PT;
Type weight = Type::kUnknown;
uint32_t num_layers = 0;
uint32_t model_dim = 0;
uint32_t vit_model_dim = 0;
uint32_t vocab_size = 0;
uint32_t seq_len = 0;
uint32_t vit_seq_len = 0;
uint32_t num_tensor_scales = 0;
uint32_t num_vit_scales = 0;
float att_cap = 0.0f;
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false; // griffin only
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<uint32_t> attention_window_sizes;
std::vector<LayerConfig> vit_layer_configs;
std::unordered_set<std::string> scale_names;
uint32_t norm_num_groups = 1;
uint32_t model_family_version = 1;
// Dimensions related to image processing.
uint32_t patch_width = 14;
uint32_t image_size = 224;
VitConfig vit_config;
};

// Returns the config for the given model.
Expand All @@ -320,7 +339,7 @@ ModelConfig ConfigFromModel(Model model);
Model ModelFromConfig(const ModelConfig& config);

// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config);
ModelConfig GetVitConfig(const ModelConfig& config);

} // namespace gcpp

Expand Down
11 changes: 6 additions & 5 deletions gemma/configs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,13 @@ template <class TConfig>
void AssertMatch(const ModelConfig& config) {
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
if constexpr (TConfig::VitConfig::kModelDim != 0) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_config.model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_config.seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales,
config.vit_config.num_scales);
for (size_t i = 0; i < config.vit_config.layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
config.vit_layer_configs[i].type);
config.vit_config.layer_configs[i].type);
}
}
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
Expand Down
17 changes: 9 additions & 8 deletions gemma/gemma-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1042,9 +1042,9 @@ template <typename T>
HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelWeightsPtrs<T>& weights,
Activations& activations) {
const size_t model_dim = weights.weights_config.vit_model_dim;
const size_t patch_width = weights.weights_config.patch_width;
const size_t seq_len = weights.weights_config.vit_seq_len;
const size_t model_dim = weights.weights_config.vit_config.model_dim;
const size_t patch_width = weights.weights_config.vit_config.patch_width;
const size_t seq_len = weights.weights_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
patch_size * model_dim);
Expand Down Expand Up @@ -1087,14 +1087,15 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
const Image& image, ImageTokens& image_tokens,
Activations& activations) {
PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = weights.weights_config.vit_seq_len;
const size_t vit_model_dim = weights.weights_config.vit_model_dim;
const size_t num_tokens = weights.weights_config.vit_config.seq_len;
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.BatchSize());
// Embed the image patches.
EmbedImagePatches(image, weights, activations);
// Go through all layers.
for (size_t layer = 0;
layer < weights.weights_config.vit_layer_configs.size(); ++layer) {
layer < weights.weights_config.vit_config.layer_configs.size();
++layer) {
const auto* layer_weights = weights.GetVitLayer(layer);
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
}
Expand Down Expand Up @@ -1413,11 +1414,11 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
NestedPools& pools) {
if (model.Config().vit_layer_configs.empty()) {
if (model.Config().vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = VitConfig(model.Config());
ModelConfig vit_config = GetVitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, pools);
Expand Down
7 changes: 4 additions & 3 deletions gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
Image image;
ImageTokens image_tokens;
if (have_image) {
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim));
image_tokens =
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().image_size;
const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
Expand Down
Loading
Loading