Skip to content

Commit cac561a

Browse files
danielkeysers authored and copybara-github committed
Allow interactive use with new single-file weight format.
Add section about new weights format to README.md. Remove model_type_required parameter. Update error handling for flags. PiperOrigin-RevId: 715750530
1 parent b93231a commit cac561a

File tree

8 files changed

+59
-35
lines changed

8 files changed

+59
-35
lines changed

README.md

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,24 @@ A tall tree stands in front of the building, and a window on the building is
305305
visible from the water. The water is green, and the sky is blue.
306306
```
307307

308+
### Migrating to single-file format
309+
310+
There is now a new format for the weights file, which is a single file that
311+
can contain the tokenizer (and the model type) directly. A tool to migrate
312+
from the multi-file format to the single-file format is available.
313+
314+
```sh
315+
compression/migrate_weights \
316+
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
317+
--model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
318+
```
319+
320+
After migration, you can use the new weights file with gemma.cpp like this:
321+
322+
```sh
323+
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
324+
```
325+
308326
### Troubleshooting and FAQs
309327

310328
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
@@ -331,9 +349,8 @@ and not a pre-trained model (any model with a `-pt` suffix).
331349

332350
**How do I convert my fine-tune to a `.sbs` compressed model file?**
333351

334-
We're working on a python script to convert a standard model format to `.sbs`,
335-
and hope to have it available soon. Follow
336-
[this issue](https://github.com/google/gemma.cpp/issues/11) for updates.
352+
See compression/convert_weights.py to convert a pytorch checkpoint. (The code may
353+
need updates to work with Gemma-2 models.)
337354

338355
**What are some easy ways to make the model run faster?**
339356

compression/migrate_weights.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ int main(int argc, char** argv) {
5555
fprintf(stderr, "Skipping model load because: %s\n", err);
5656
return 1;
5757
}
58-
gcpp::GemmaEnv env(argc, argv, /*required=*/true);
58+
gcpp::GemmaEnv env(argc, argv);
5959
hwy::ThreadPool pool(0);
6060
env.GetModel()->Save(args.output_weights, pool);
6161
return 0;

evals/benchmark_helper.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) {
9292
return AppArgs(argc, argv);
9393
}
9494

95-
GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required)
96-
: GemmaEnv(LoaderArgs(argc, argv, model_type_required),
97-
InferenceArgs(argc, argv), MakeAppArgs(argc, argv)) {}
95+
GemmaEnv::GemmaEnv(int argc, char** argv)
96+
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
97+
MakeAppArgs(argc, argv)) {}
9898

9999
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
100100
QueryResult result;

evals/benchmark_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct QueryResult {
4444
class GemmaEnv {
4545
public:
4646
// Calls the other constructor with *Args arguments initialized from argv.
47-
GemmaEnv(int argc, char** argv, bool model_type_required = false);
47+
GemmaEnv(int argc, char** argv);
4848
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
4949
const AppArgs& app);
5050

evals/gemma_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
// This test can be run manually with the downloaded gemma weights.
2929
// To run the test, pass the following flags:
3030
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
31+
// or just use the single-file weights file with --weights <weights_path>.
3132
// It should pass for the following models:
3233
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
3334
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,

gemma/weights.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,9 +525,9 @@ class ModelWeightsStorage {
525525

526526
// Loads the weights from a blob store file. Supports multi-file or
527527
// single-file format. If the weights file contains a TOC, then it is in
528-
// single-file format, and model_type, weight_type, training are ignored,
528+
// single-file format, and model_type, weight_type, wrapping are ignored,
529529
// and tokenizer_proto is required and written to.
530-
// With a multi-file format, file, model_type, weight_type, training are
530+
// With a multi-file format, file, model_type, weight_type, wrapping are
531531
// required and tokenizer_proto is ignored.
532532
BlobError Load(const Path& weights, Model model_type, Type weight_type,
533533
PromptWrapping wrapping, hwy::ThreadPool& pool,

paligemma/paligemma_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
// This test can be run manually with the downloaded PaliGemma weights.
2828
// To run the test, pass the following flags:
2929
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
30+
// or just use the single-file weights file with --weights <weights_path>.
3031
// It should pass for the following models:
3132
// paligemma-3b-mix-224, paligemma2-3b-pt-448
3233

util/app.h

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ static inline NestedPools CreatePools(const AppArgs& app) {
126126
}
127127

128128
struct LoaderArgs : public ArgsBase<LoaderArgs> {
129-
LoaderArgs(int argc, char* argv[], bool required = true)
130-
: model_type_required(required) {
129+
LoaderArgs(int argc, char* argv[]) {
131130
InitAndParse(argc, argv);
132131
}
133132
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
@@ -140,25 +139,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
140139

141140
// Returns error string or nullptr if OK.
142141
const char* Validate() {
143-
info_.model = Model::UNKNOWN;
144-
info_.wrapping = PromptWrapping::GEMMA_PT;
145-
info_.weight = Type::kUnknown;
146-
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
147-
info_.wrapping)) {
148-
if (model_type_required) return err;
149-
}
150-
if (const char* err = ParseType(weight_type_str, info_.weight)) {
151-
if (model_type_required) return err;
152-
}
153-
if (model_type_required) {
154-
if (tokenizer.path.empty()) {
155-
return "Missing --tokenizer flag, a file for the tokenizer is "
156-
"required.";
157-
}
158-
if (!tokenizer.Exists()) {
159-
return "Can't open file specified with --tokenizer flag.";
160-
}
161-
}
162142
if (!compressed_weights.path.empty()) {
163143
if (weights.path.empty()) {
164144
weights = compressed_weights;
@@ -174,6 +154,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
174154
if (!weights.Exists()) {
175155
return "Can't open file specified with --weights flag.";
176156
}
157+
info_.model = Model::UNKNOWN;
158+
info_.wrapping = PromptWrapping::GEMMA_PT;
159+
info_.weight = Type::kUnknown;
160+
if (!model_type_str.empty()) {
161+
const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
162+
info_.wrapping);
163+
if (err != nullptr) return err;
164+
}
165+
if (!weight_type_str.empty()) {
166+
const char* err = ParseType(weight_type_str, info_.weight);
167+
if (err != nullptr) return err;
168+
}
169+
if (!tokenizer.path.empty()) {
170+
if (!tokenizer.Exists()) {
171+
return "Can't open file specified with --tokenizer flag.";
172+
}
173+
}
174+
// model_type and tokenizer must be either both present or both absent.
175+
// Further checks happen on weight loading.
176+
if (model_type_str.empty() != tokenizer.path.empty()) {
177+
return "Missing or extra flags for model_type or tokenizer.";
178+
}
177179
return nullptr;
178180
}
179181

@@ -182,7 +184,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
182184
Path compressed_weights;
183185
std::string model_type_str;
184186
std::string weight_type_str;
185-
bool model_type_required = true;
186187

187188
template <class Visitor>
188189
void ForEach(const Visitor& visitor) {
@@ -199,7 +200,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
199200
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
200201
"gr2b-pt = griffin 2B parameters, pretrained.");
201202
visitor(weight_type_str, "weight_type", std::string("sfp"),
202-
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP.");
203+
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
203204
}
204205

205206
// Uninitialized before Validate, must call after that.
@@ -212,15 +213,19 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
212213
};
213214

214215
static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
216+
if (Type::kUnknown == loader.Info().weight ||
217+
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
218+
// New weights file format doesn't need tokenizer path or model/weight info.
219+
return Gemma(loader.weights, pools);
220+
}
215221
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
216222
}
217223

218224
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
219225
NestedPools& pools) {
220226
if (Type::kUnknown == loader.Info().weight ||
221227
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
222-
// Newer weights file format doesn't need tokenizer path or model/weight
223-
// info.
228+
// New weights file format doesn't need tokenizer path or model/weight info.
224229
return std::make_unique<Gemma>(loader.weights, pools);
225230
}
226231
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,

0 commit comments

Comments
 (0)