Skip to content

Commit 03c2a42

Browse files
pcullitoncopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 700761457
1 parent 6a34e9c commit 03c2a42

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

gemma/common.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ constexpr const char* kModelFlags[] = {
3939
"9b-pt", "9b-it", // Gemma2 9B
4040
"27b-pt", "27b-it", // Gemma2 27B
4141
"paligemma-224", // PaliGemma 224
42+
"paligemma2-3b-224", // PaliGemma2 3B 224
43+
"paligemma2-10b-224", // PaliGemma2 10B 224
4244
};
4345
constexpr Model kModelTypes[] = {
4446
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
@@ -49,6 +51,8 @@ constexpr Model kModelTypes[] = {
4951
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
5052
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
5153
Model::PALIGEMMA_224, // PaliGemma 224
54+
Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224
55+
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
5256
};
5357
constexpr ModelTraining kModelTraining[] = {
5458
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
@@ -59,6 +63,8 @@ constexpr ModelTraining kModelTraining[] = {
5963
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
6064
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
6165
ModelTraining::PALIGEMMA, // PaliGemma 224
66+
ModelTraining::PALIGEMMA, // PaliGemma2 3B 224
67+
ModelTraining::PALIGEMMA, // PaliGemma2 10B 224
6268
};
6369

6470
constexpr size_t kNumModelFlags =

gemma/configs.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,22 @@ ModelConfig VitConfig(const ModelConfig& config) {
246246
return vit_config;
247247
}
248248

249+
static ModelConfig ConfigPaliGemma2_3B_224() {
250+
ModelConfig config = ConfigGemma2_2B();
251+
config.model_name = "PaliGemma2_3B_224";
252+
config.model = Model::PALIGEMMA2_3B_224;
253+
AddVitConfig(config);
254+
return config;
255+
}
256+
257+
static ModelConfig ConfigPaliGemma2_10B_224() {
258+
ModelConfig config = ConfigGemma2_9B();
259+
config.model_name = "PaliGemma2_10B_224";
260+
config.model = Model::PALIGEMMA2_10B_224;
261+
AddVitConfig(config);
262+
return config;
263+
}
264+
249265
ModelConfig ConfigFromModel(Model model) {
250266
switch (model) {
251267
case Model::GEMMA_2B:
@@ -264,6 +280,10 @@ ModelConfig ConfigFromModel(Model model) {
264280
return ConfigGemmaTiny();
265281
case Model::PALIGEMMA_224:
266282
return ConfigPaliGemma_224();
283+
case Model::PALIGEMMA2_3B_224:
284+
return ConfigPaliGemma2_3B_224();
285+
case Model::PALIGEMMA2_10B_224:
286+
return ConfigPaliGemma2_10B_224();
267287
default:
268288
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
269289
}

gemma/configs.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,15 @@ enum class Model {
114114
GEMMA_TINY,
115115
GEMMA2_2B,
116116
PALIGEMMA_224,
117+
PALIGEMMA2_3B_224,
118+
PALIGEMMA2_10B_224,
117119
};
118120

119121
// Allows the Model enum to be iterated over.
120122
static constexpr Model kAllModels[] = {
121123
Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
122124
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
123-
Model::PALIGEMMA_224,
125+
Model::PALIGEMMA_224, Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_10B_224,
124126
};
125127

126128
struct LayerConfig {

0 commit comments

Comments
 (0)