Allow interactive use with new single-file weight format.
Add section about new weights format to README.md.
Remove model_type_required parameter.
Update error handling for flags.

PiperOrigin-RevId: 715750530
danielkeysers authored and copybara-github committed Jan 15, 2025
1 parent b93231a commit cac561a
Showing 8 changed files with 59 additions and 35 deletions.
23 changes: 20 additions & 3 deletions README.md
@@ -305,6 +305,24 @@ A tall tree stands in front of the building, and a window on the building is
visible from the water. The water is green, and the sky is blue.
```

### Migrating to single-file format

There is now a new format for the weights file: a single file that can
contain the tokenizer (and the model type) directly. A tool to migrate
from the multi-file format to the single-file format is available.

```sh
compression/migrate_weights \
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
--model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
```

After migration, you can use the new weights file with gemma.cpp like this:

```sh
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
```
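
With the single-file format, the tokenizer and the model type are read from the
weights file itself, so the `--tokenizer` and `--model` flags are no longer
needed.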

### Troubleshooting and FAQs

**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
@@ -331,9 +349,8 @@ and not a pre-trained model (any model with a `-pt` suffix).

**How do I convert my fine-tune to a `.sbs` compressed model file?**

We're working on a Python script to convert a standard model format to `.sbs`,
and hope to have it available soon. Follow
[this issue](https://github.com/google/gemma.cpp/issues/11) for updates.
See `compression/convert_weights.py` to convert a PyTorch checkpoint. (The code may
need updates to work with Gemma-2 models.)
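
A minimal sketch of what an invocation might look like; the flag names and
paths below are assumptions for illustration, so check the script's source for
its actual interface:

```sh
# Illustrative sketch only: flag names are assumptions, not the script's
# confirmed interface; see compression/convert_weights.py for the real flags.
python3 compression/convert_weights.py \
  --tokenizer .../tokenizer.spm \
  --checkpoint .../my-finetune.ckpt \
  --output_weights .../my-finetune-sfp.sbs
```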

**What are some easy ways to make the model run faster?**

2 changes: 1 addition & 1 deletion compression/migrate_weights.cc
@@ -55,7 +55,7 @@ int main(int argc, char** argv) {
fprintf(stderr, "Skipping model load because: %s\n", err);
return 1;
}
gcpp::GemmaEnv env(argc, argv, /*required=*/true);
gcpp::GemmaEnv env(argc, argv);
hwy::ThreadPool pool(0);
env.GetModel()->Save(args.output_weights, pool);
return 0;
6 changes: 3 additions & 3 deletions evals/benchmark_helper.cc
@@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) {
return AppArgs(argc, argv);
}

GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required)
: GemmaEnv(LoaderArgs(argc, argv, model_type_required),
InferenceArgs(argc, argv), MakeAppArgs(argc, argv)) {}
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
MakeAppArgs(argc, argv)) {}

QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;
2 changes: 1 addition & 1 deletion evals/benchmark_helper.h
@@ -44,7 +44,7 @@ struct QueryResult {
class GemmaEnv {
public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv, bool model_type_required = false);
GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);

1 change: 1 addition & 0 deletions evals/gemma_test.cc
@@ -28,6 +28,7 @@
// This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags:
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
// or just use the single-file weights file with --weights <weights_path>.
// It should pass for the following models:
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
4 changes: 2 additions & 2 deletions gemma/weights.h
@@ -525,9 +525,9 @@ class ModelWeightsStorage {

// Loads the weights from a blob store file. Supports multi-file or
// single-file format. If the weights file contains a TOC, then it is in
// single-file format, and model_type, weight_type, training are ignored,
// single-file format, and model_type, weight_type, wrapping are ignored,
// and tokenizer_proto is required and written to.
// With a multi-file format, file, model_type, weight_type, training are
// With a multi-file format, file, model_type, weight_type, wrapping are
// required and tokenizer_proto is ignored.
BlobError Load(const Path& weights, Model model_type, Type weight_type,
PromptWrapping wrapping, hwy::ThreadPool& pool,
1 change: 1 addition & 0 deletions paligemma/paligemma_test.cc
@@ -27,6 +27,7 @@
// This test can be run manually with the downloaded PaliGemma weights.
// To run the test, pass the following flags:
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
// or just use the single-file weights file with --weights <weights_path>.
// It should pass for the following models:
// paligemma-3b-mix-224, paligemma2-3b-pt-448

55 changes: 30 additions & 25 deletions util/app.h
@@ -126,8 +126,7 @@ static inline NestedPools CreatePools(const AppArgs& app) {
}

struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[], bool required = true)
: model_type_required(required) {
LoaderArgs(int argc, char* argv[]) {
InitAndParse(argc, argv);
}
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
@@ -140,25 +139,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {

// Returns error string or nullptr if OK.
const char* Validate() {
info_.model = Model::UNKNOWN;
info_.wrapping = PromptWrapping::GEMMA_PT;
info_.weight = Type::kUnknown;
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.wrapping)) {
if (model_type_required) return err;
}
if (const char* err = ParseType(weight_type_str, info_.weight)) {
if (model_type_required) return err;
}
if (model_type_required) {
if (tokenizer.path.empty()) {
return "Missing --tokenizer flag, a file for the tokenizer is "
"required.";
}
if (!tokenizer.Exists()) {
return "Can't open file specified with --tokenizer flag.";
}
}
if (!compressed_weights.path.empty()) {
if (weights.path.empty()) {
weights = compressed_weights;
@@ -174,6 +154,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
info_.model = Model::UNKNOWN;
info_.wrapping = PromptWrapping::GEMMA_PT;
info_.weight = Type::kUnknown;
if (!model_type_str.empty()) {
const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.wrapping);
if (err != nullptr) return err;
}
if (!weight_type_str.empty()) {
const char* err = ParseType(weight_type_str, info_.weight);
if (err != nullptr) return err;
}
if (!tokenizer.path.empty()) {
if (!tokenizer.Exists()) {
return "Can't open file specified with --tokenizer flag.";
}
}
// model_type and tokenizer must be either both present or both absent.
// Further checks happen on weight loading.
if (model_type_str.empty() != tokenizer.path.empty()) {
return "Missing or extra flags for model_type or tokenizer.";
}
return nullptr;
}

@@ -182,7 +184,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
Path compressed_weights;
std::string model_type_str;
std::string weight_type_str;
bool model_type_required = true;

template <class Visitor>
void ForEach(const Visitor& visitor) {
@@ -199,7 +200,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained.");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP.");
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
}

// Uninitialized before Validate, must call after that.
@@ -212,15 +213,19 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
};

static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
if (Type::kUnknown == loader.Info().weight ||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
// New weights file format doesn't need tokenizer path or model/weight info.
return Gemma(loader.weights, pools);
}
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
}

static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
NestedPools& pools) {
if (Type::kUnknown == loader.Info().weight ||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
// Newer weights file format doesn't need tokenizer path or model/weight
// info.
// New weights file format doesn't need tokenizer path or model/weight info.
return std::make_unique<Gemma>(loader.weights, pools);
}
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
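
A quick sketch of the two flag combinations the new `Validate()` logic accepts
(paths are placeholders):

```sh
# Single-file format: tokenizer and model type are embedded in the weights.
./gemma --weights .../gemma2-2b-it-sfp-single.sbs

# Multi-file format: --tokenizer and --model must both be given.
./gemma --tokenizer .../tokenizer.spm --model gemma2-2b-it \
  --weights .../gemma2-2b-it-sfp.sbs
```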
