From 9d40f0117e45c12eca0945ad7b35bb4f61869df4 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 19 Dec 2024 07:59:08 -0800 Subject: [PATCH] Added ability to load/save a complete model file, including tokenizer. PiperOrigin-RevId: 707914366 --- BUILD.bazel | 3 + backprop/backward_scalar_test.cc | 13 +- backprop/backward_test.cc | 13 +- compression/BUILD.bazel | 15 +- compression/compress-inl.h | 39 ++-- compression/compress.h | 211 +++++++++++++------- compression/compress_weights.cc | 40 +++- compression/fields.cc | 32 +++ compression/fields.h | 9 +- compression/fields_test.cc | 16 ++ compression/migrate_weights.cc | 62 ++++++ compression/python/BUILD.bazel | 1 + compression/python/compression_clif_aux.cc | 25 ++- compression/python/compression_clif_aux.h | 5 +- compression/python/compression_extension.cc | 4 +- compression/shared.h | 10 + evals/benchmark_helper.cc | 10 +- evals/benchmark_helper.h | 2 +- gemma/configs.cc | 177 +++++++++------- gemma/configs.h | 155 ++++++++++---- gemma/configs_test.cc | 12 ++ gemma/gemma.cc | 17 +- gemma/gemma.h | 15 +- gemma/kv_cache.cc | 2 +- gemma/tensor_index.cc | 2 + gemma/tensor_index_test.cc | 3 +- gemma/tokenizer.cc | 17 ++ gemma/tokenizer.h | 3 + gemma/weights.cc | 61 +++++- gemma/weights.h | 11 + ops/dot_test.cc | 1 - util/app.h | 41 ++-- 32 files changed, 770 insertions(+), 257 deletions(-) create mode 100644 compression/migrate_weights.cc diff --git a/BUILD.bazel b/BUILD.bazel index d06acd5b..1dc6a1f4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -245,6 +245,7 @@ cc_library( "gemma/tensor_index.h", ], deps = [ + "//compression:fields", "//compression:sfp", "@highway//:hwy", # base.h "@highway//:thread_pool", @@ -257,6 +258,7 @@ cc_test( deps = [ ":common", "@googletest//:gtest_main", + "@highway//:hwy", ], ) @@ -388,6 +390,7 @@ cc_library( ":ops", ":threading", "//compression:io", + "//compression:sfp", "@highway//:hwy", ], ) diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc index d99a0673..e40f3ed5 100644 --- a/backprop/backward_scalar_test.cc +++ b/backprop/backward_scalar_test.cc @@ -390,13 +390,12 @@ static ModelConfig TestConfig() { config.model_dim = 32; config.vocab_size = 12; config.seq_len = 18; - LayerConfig layer_config = { - .model_dim = config.model_dim, - .ff_hidden_dim = 48, - .heads = 3, - .kv_heads = 1, - .qkv_dim = 12, - }; + LayerConfig layer_config; + layer_config.model_dim = config.model_dim; + layer_config.ff_hidden_dim = 48; + layer_config.heads = 3; + layer_config.kv_heads = 1; + layer_config.qkv_dim = 12; config.layer_configs = {2, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.query_scale = QueryScaleType::SqrtKeySize; diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 2b82c120..0df079d2 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -191,13 +191,12 @@ static ModelConfig TestConfig() { config.model_dim = 32; config.vocab_size = 16; config.seq_len = 24; - LayerConfig layer_config = { - .model_dim = config.model_dim, - .ff_hidden_dim = 64, - .heads = 3, - .kv_heads = 1, - .qkv_dim = 16, - }; + LayerConfig layer_config; + layer_config.model_dim = config.model_dim; + layer_config.ff_hidden_dim = 64; + layer_config.heads = 3; + layer_config.kv_heads = 1; + layer_config.qkv_dim = 16; config.layer_configs = {2, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.query_scale = QueryScaleType::SqrtKeySize; diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index c4f8d6be..25df258e 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -58,7 +58,6 @@ cc_test( deps = [ ":fields", "@googletest//:gtest_main", # buildcleaner: keep - "@highway//:hwy", "@highway//:hwy_test_util", ], ) @@ -202,6 +201,7 @@ cc_library( deps = [ ":blob_store", ":distortion", + ":fields", ":io", ":nuq", ":sfp", @@ -210,7 +210,6 @@ cc_library( "//:common", "@highway//:hwy", "@highway//:nanobenchmark", - "@highway//:profiler", "@highway//:stats", "@highway//:thread_pool", ], @@ -261,6 +260,7 @@ cc_binary( "//:allocator", "//:args", "//:common", + "//:tokenizer", "//:weights", "@highway//:hwy", "@highway//:thread_pool", @@ -277,3 +277,14 @@ cc_binary( "@highway//:hwy_test_util", ], ) + +cc_binary( + name = "migrate_weights", + srcs = ["migrate_weights.cc"], + deps = [ + "//:app", + "//:args", + "//:benchmark_helper", + "//:gemma_lib", + ], +) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 651b8a22..e850a0e3 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -22,10 +22,13 @@ #include #include // lroundf, only if COMPRESS_STATS +#include +#include #include "compression/blob_store.h" #include "compression/compress.h" // IWYU pragma: export #include "compression/distortion.h" +#include "gemma/configs.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -673,36 +676,37 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan v, // their scaling factors to BlobStore. class Compressor { public: - explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {} + explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {} template void operator()(MatPtrT* compressed, const char* decorated_name, const float* HWY_RESTRICT weights) { size_t num_weights = compressed->NumElements(); + if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr) + return; size_t num_compressed = compressed->NumElements(); PackedSpan packed = MakeSpan(compressed->data(), num_compressed); fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name, num_weights / (1000 * 1000)); - Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_); - const size_t num_bytes = packed.num * sizeof(Packed); - writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes); + Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, + writer_.pool()); + writer_(compressed, decorated_name); + } + + void AddTokenizer(const std::string& tokenizer) { + writer_.AddTokenizer(tokenizer); } void AddScales(const float* scales, size_t len) { - if (len) { - MatPtrT scales_ptr("scales", 0, 1); - writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales, - len * sizeof(scales[0])); - } + writer_.AddScales(scales, len); } - BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) { - const BlobError err = writer_.WriteAll(pool, blob_filename); - if (err != 0) { - fprintf(stderr, "Failed to write blobs to %s (error %d)\n", - blob_filename.path.c_str(), err); - } - return err; + // Writes all blobs to disk in the given order. The config is optional and + // if given, it is written to the file, along with the TOC, making it + // single-file format. Otherwise, the file is written in the multi-file format + // without a TOC. + BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) { + return writer_.WriteAll(blob_filename, config); } // Returns the number of blobs added. @@ -710,8 +714,7 @@ class Compressor { private: CompressWorkingSet work_; - hwy::ThreadPool& pool_; - BlobWriter writer_; + WriteToBlobStore writer_; }; // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/compression/compress.h b/compression/compress.h index 8d4635b1..ddfd16cc 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -32,11 +33,13 @@ // IWYU pragma: begin_exports #include "compression/blob_store.h" +#include "compression/fields.h" #include "compression/io.h" #include "compression/shared.h" #include "gemma/tensor_index.h" #include "util/basics.h" // IWYU pragma: end_exports +#include "gemma/configs.h" #include "util/allocator.h" #include "hwy/per_target.h" #if COMPRESS_STATS @@ -55,7 +58,7 @@ namespace gcpp { // fixed inner dimension and type. // It is designed to be put in a vector, and has default copy and operator=, so // it is easy to read/write a blob_store file. -class MatPtr { +class MatPtr : public IFields { public: // Full constructor for dynamic sizing. MatPtr(const std::string& name, Type type, size_t element_size, size_t rows, @@ -73,36 +76,6 @@ class MatPtr { MatPtr() = default; virtual ~MatPtr(); - // Number of hwy::uint128_t in a TOC entry. - // Note that the old-style BlobStore files only have a list of keys and size. - // The new-style BlobStore files have an entry called "toc" that contains a - // vector of 4-tuples of - // (name, type, (num_elements, element_size), (rows, cols)). - // The listed blobs can be read directly into MatPtr from the BlobStore - // file, without needing any external knowledge of the number of elements, - // element size or type of the data. - static constexpr size_t kNumU128InTocEntry = 4; - - // Construct from a TOC entry. - MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1, - const hwy::uint128_t& key2, const hwy::uint128_t& key3) - : name_(StringFromKey(key0)), - type_(static_cast(key1.lo)), - element_size_(key2.hi), - num_elements_(key2.lo), - rows_(key3.lo), - cols_(key3.hi) { - stride_ = cols_; - } - - // Adds the contents entry to the table of contents. - void AddToToc(std::vector& toc) const { - toc.push_back(MakeKey(name_.c_str())); - toc.push_back({static_cast(type_), 0}); - toc.push_back({num_elements_, element_size_}); - toc.push_back({rows_, cols_}); - } - // Compatibility interface for CompressedArray. // TODO: remove. template @@ -124,7 +97,7 @@ class MatPtr { MatPtr& operator=(const MatPtr& other) = default; // Returns the name of the blob. - const std::string& Name() const { return name_; } + const char* Name() const override { return name_.c_str(); } void SetName(const std::string& name) { name_ = name; } // Returns the type of the blob. @@ -163,12 +136,6 @@ class MatPtr { return name; } - // Adds the blob to the writer. - void AddToWriter(BlobWriter& writer) const { - fprintf(stderr, "Adding %s to writer\n", name_.c_str()); - writer.Add(MakeKey(name_.c_str()), ptr_, SizeBytes()); - } - // Sets all data to zero. void ZeroInit() { if (ptr_ == nullptr) @@ -176,6 +143,17 @@ class MatPtr { hwy::ZeroBytes(ptr_, SizeBytes()); } + void VisitFields(IFieldsVisitor& visitor) override { + visitor(name_); + visitor(type_); + visitor(element_size_); + visitor(num_elements_); + visitor(rows_); + visitor(cols_); + visitor(scale_); + visitor(stride_); + } + // Calls func on the upcasted type. Since MatPtr by design is not templated, // here we provide a way to get to the derived type, provided that `Type()` // is one of the strings returned by `TypeName()`. @@ -188,13 +166,13 @@ class MatPtr { // Should be the result of TypeEnum for CallUpcasted() to work. Type type_; // sizeof(T) - size_t element_size_ = 0; + uint32_t element_size_ = 0; // Number of elements in the array. - size_t num_elements_ = 0; // In element_size units. + uint32_t num_elements_ = 0; // In element_size units. // Number of rows in the 2-d array (outer dimension). - size_t rows_ = 0; + uint32_t rows_ = 0; // Number of columns in the 2-d array (inner dimension). - size_t cols_ = 0; + uint32_t cols_ = 0; // Scaling to apply to each element. float scale_ = 1.0f; // Aligned data array. This is always a borrowed pointer. It should never be @@ -202,7 +180,7 @@ class MatPtr { // and must outlive this object. void* ptr_ = nullptr; - size_t stride_; + uint32_t stride_; }; // MatPtrT adds a single template argument to MatPtr for an explicit type. @@ -394,31 +372,28 @@ class BlobToc { public: BlobToc() = default; - // Adds all blobs to the blob writer. Note that the blobs must have unique - // names. - static void AddAllToBlobWriter(const std::vector& blobs, - BlobWriter& writer) { - std::vector toc; - for (const auto& blob : blobs) { - blob.AddToToc(toc); - blob.AddToWriter(writer); - } - writer.Add(MakeKey(kTocName), toc.data(), toc.size() * sizeof(toc[0])); - } - // Loads the table of contents from the given reader. BlobError LoadToc(BlobReader& reader) { hwy::uint128_t toc_key = MakeKey(kTocName); size_t toc_size = reader.BlobSize(toc_key); if (toc_size != 0) { - std::vector toc(toc_size / sizeof(hwy::uint128_t)); + std::vector toc(toc_size / sizeof(uint32_t)); BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size); if (err != 0) { fprintf(stderr, "Failed to read toc (error %d)\n", err); return err; } - for (size_t i = 0; i < toc.size(); i += MatPtr::kNumU128InTocEntry) { - AddToToc(MatPtr(toc[i], toc[i + 1], toc[i + 2], toc[i + 3])); + size_t consumed = 0; + size_t prev_consumed = static_cast(-1); + while (consumed < toc.size() && prev_consumed != consumed) { + MatPtr blob; + const IFields::ReadResult result = + blob.Read(hwy::Span(toc), consumed); + prev_consumed = consumed; + consumed = result.pos; + if (blob.NumElements() > 0) { + AddToToc(blob); + } } } return 0; @@ -437,11 +412,16 @@ class BlobToc { if (it == toc_map_.end()) return nullptr; return &toc_[it->second]; } - - private: // The name of the toc in the blob store file. static constexpr char kTocName[] = "toc"; + // The name of the config in the blob store file. + static constexpr char kConfigName[] = "config"; + + // The name of the tokenizer in the blob store file. + static constexpr char kTokenizerName[] = "tokenizer"; + + private: // Adds the blob to the table of contents. void AddToToc(const MatPtr& blob) { HWY_ASSERT(!Contains(blob.Name())); @@ -519,6 +499,68 @@ struct CompressWorkingSet { std::vector tls; }; +// Class to collect and write a set of tensors to a blob store file. +class WriteToBlobStore { + public: + explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {} + + template + void operator()(MatPtrT* compressed, const char* decorated_name) { + if (compressed->Ptr() == nullptr) return; + writer_.Add(MakeKey(decorated_name), compressed->Ptr(), + compressed->SizeBytes()); + MatPtr renamed_tensor(*compressed); + renamed_tensor.SetName(decorated_name); + renamed_tensor.AppendTo(toc_); + } + + void AddTokenizer(const std::string& tokenizer) { + writer_.Add(MakeKey(BlobToc::kTokenizerName), tokenizer.data(), + tokenizer.size() * sizeof(tokenizer[0])); + } + + void AddScales(const float* scales, size_t len) { + if (len) { + MatPtrT scales_ptr("scales", 0, 1); + writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales, + len * sizeof(scales[0])); + } + } + + // Writes all blobs to disk in the given order. The config is optional and + // if given, it is written to the file, along with the TOC, making it + // single-file format. Otherwise, the file is written in the multi-file format + // without a TOC. + BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) { + if (config) { + writer_.Add(MakeKey(BlobToc::kTocName), toc_.data(), + toc_.size() * sizeof(toc_[0])); + config_buffer_ = config->Write(); + writer_.Add(MakeKey(BlobToc::kConfigName), config_buffer_.data(), + config_buffer_.size() * sizeof(config_buffer_[0])); + } + const BlobError err = writer_.WriteAll(pool_, blob_filename); + if (err != 0) { + fprintf(stderr, "Failed to write blobs to %s (error %d)\n", + blob_filename.path.c_str(), err); + } + return err; + } + + // Returns the number of blobs added. + size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); } + + hwy::ThreadPool& pool() { return pool_; } + + protected: + hwy::ThreadPool& pool_; + + private: + std::vector toc_; + BlobWriter writer_; + std::vector config_buffer_; +}; + // Functor called for each tensor, which loads them and their scaling factors // from BlobStore. class ReadFromBlobStore { @@ -539,11 +581,40 @@ class ReadFromBlobStore { // Returns true if there is a TOC. bool HaveToc() const { return !file_toc_.Empty(); } + // Reads the config from the blob store file. + BlobError LoadConfig(ModelConfig& config) { + hwy::uint128_t config_key = MakeKey(BlobToc::kConfigName); + size_t config_size = reader_.BlobSize(config_key); + if (config_size == 0) return __LINE__; + std::vector config_buffer(config_size / sizeof(uint32_t)); + BlobError err = + reader_.ReadOne(config_key, config_buffer.data(), config_size); + if (err != 0) { + fprintf(stderr, "Failed to read config (error %d)\n", err); + return err; + } + config.Read(hwy::Span(config_buffer), 0); + return 0; + } + + // Reads the tokenizer from the blob store file. + BlobError LoadTokenizer(std::string& tokenizer) { + hwy::uint128_t key = MakeKey(BlobToc::kTokenizerName); + size_t tokenizer_size = reader_.BlobSize(key); + if (tokenizer_size == 0) return __LINE__; + tokenizer.resize(tokenizer_size); + ; + BlobError err = reader_.ReadOne(key, tokenizer.data(), tokenizer_size); + if (err != 0) { + fprintf(stderr, "Failed to read tokenizer (error %d)\n", err); + return err; + } + return 0; + } + // Called for each tensor, enqueues read requests. void operator()(const char* name, hwy::Span tensors) { if (file_toc_.Empty() || file_toc_.Contains(name)) { - if (tensors[0]->NumElements() == 0) - fprintf(stderr, "Zero elements for %s\n", name); model_toc_.push_back(tensors[0]); file_keys_.push_back(name); } @@ -579,12 +650,12 @@ class ReadFromBlobStore { fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str()); return __LINE__; } - MatStorage toc_blob_array(*toc_blob); - model_memory.push_back(std::move(toc_blob_array)); - } else { - model_memory.emplace_back(*blob); - model_memory.back().SetName(file_key); + std::string name = blob->Name(); + *blob = *toc_blob; + blob->SetName(name); } + model_memory.emplace_back(*blob); + model_memory.back().SetName(file_key); } // Allocate in parallel using the pool. pool.Run(0, model_memory.size(), @@ -594,12 +665,12 @@ class ReadFromBlobStore { }); // Enqueue the read requests. for (auto& blob : model_memory) { - err_ = reader_.Enqueue(MakeKey(blob.Name().c_str()), blob.data(), - blob.SizeBytes()); + err_ = + reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes()); if (err_ != 0) { fprintf(stderr, "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n", - blob.Name().c_str(), err_, blob.Rows(), blob.Cols(), + blob.Name(), err_, blob.Rows(), blob.Cols(), blob.ElementSize()); return err_; } diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc index 66e95ec8..cbf7e359 100644 --- a/compression/compress_weights.cc +++ b/compression/compress_weights.cc @@ -25,6 +25,7 @@ // After highway.h #include "compression/compress-inl.h" #include "gemma/configs.h" +#include "gemma/tokenizer.h" #ifndef GEMMA_COMPRESS_WEIGHTS_ONCE #define GEMMA_COMPRESS_WEIGHTS_ONCE @@ -99,6 +100,9 @@ struct Args : public ArgsBase { std::string model_type_str; std::string weight_type_str; size_t num_threads; + // If non-empty, whether to include the config and TOC in the output file, as + // well as the tokenizer. + Path tokenizer; template void ForEach(const Visitor& visitor) { @@ -123,6 +127,9 @@ struct Args : public ArgsBase { "Number of threads to use.\n Default = Estimate of the " "number of supported concurrent threads.", 2); + visitor(tokenizer, "tokenizer", Path(), + "Path to tokenizer file. If given, the config and TOC are also " + "added to the output file."); } // Uninitialized before Validate, must call after that. @@ -156,7 +163,8 @@ namespace HWY_NAMESPACE { template void CompressWeights(const Path& weights_path, const Path& compressed_weights_path, Model model_type, - hwy::ThreadPool& pool) { + Type weight_type, PromptWrapping wrapping, + const Path& tokenizer_path, hwy::ThreadPool& pool) { if (!weights_path.Exists()) { HWY_ABORT("The model weights file '%s' does not exist.", weights_path.path.c_str()); @@ -164,6 +172,8 @@ void CompressWeights(const Path& weights_path, printf("Compressing weights from %s to %s\n", weights_path.path.c_str(), compressed_weights_path.path.c_str()); ModelConfig config = ConfigFromModel(model_type); + config.weight = weight_type; + config.wrapping = wrapping; std::vector model_storage; ModelWeightsPtrs c_weights(config); c_weights.Allocate(model_storage, pool); @@ -185,6 +195,9 @@ void CompressWeights(const Path& weights_path, ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr); total_size += tensors[0]->SizeBytes(); }); + if (!tokenizer_path.path.empty()) { + uc_weights.AllocAndCopyWithTranspose(pool, model_storage); + } const bool scale_for_compression = config.num_tensor_scales > 0; std::vector scales; if (scale_for_compression) { @@ -193,14 +206,21 @@ void CompressWeights(const Path& weights_path, Compressor compressor(pool); ModelWeightsPtrs::ForEachTensor( {reinterpret_cast*>(&uc_weights), &c_weights}, - ForEachType::kLoadNoToc, + tokenizer_path.path.empty() ? ForEachType::kLoadNoToc + : ForEachType::kLoadWithToc, [&compressor](const char* name, hwy::Span tensors) { tensors[1]->CallUpcasted( compressor, name, reinterpret_cast(tensors[0]->Ptr())); }); - compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0])); - compressor.WriteAll(pool, compressed_weights_path); + if (!tokenizer_path.path.empty()) { + std::string tokenizer_proto = ReadFileToString(tokenizer_path); + compressor.AddTokenizer(tokenizer_proto); + } else { + compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0])); + } + compressor.WriteAll(compressed_weights_path, + tokenizer_path.path.empty() ? nullptr : &config); } } // namespace HWY_NAMESPACE @@ -220,19 +240,23 @@ void Run(Args& args) { switch (weight_type) { case Type::kF32: HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, pool); + (args.weights, args.compressed_weights, model_type, weight_type, + args.PromptWrappingType(), args.tokenizer, pool); break; case Type::kBF16: HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, pool); + (args.weights, args.compressed_weights, model_type, weight_type, + args.PromptWrappingType(), args.tokenizer, pool); break; case Type::kSFP: HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, pool); + (args.weights, args.compressed_weights, model_type, weight_type, + args.PromptWrappingType(), args.tokenizer, pool); break; case Type::kNUQ: HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, pool); + (args.weights, args.compressed_weights, model_type, weight_type, + args.PromptWrappingType(), args.tokenizer, pool); break; default: HWY_ABORT("Weight type %d unsupported.", static_cast(weight_type)); diff --git a/compression/fields.cc b/compression/fields.cc index de90ec86..8977af79 100644 --- a/compression/fields.cc +++ b/compression/fields.cc @@ -83,6 +83,14 @@ class PrintVisitor : public VisitorBase { fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value); } + void operator()(int32_t& value) override { + fprintf(stderr, "%sI32 %d\n", indent_.c_str(), value); + } + + void operator()(uint64_t& value) override { + fprintf(stderr, "%sU64 %zu\n", indent_.c_str(), value); + } + void operator()(float& value) override { fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value); } @@ -120,6 +128,21 @@ class ReadVisitor : public VisitorBase { value = span_[result_.pos++]; } + void operator()(int32_t& value) override { + if (HWY_UNLIKELY(SkipField())) return; + + value = static_cast(span_[result_.pos++]); + } + + void operator()(uint64_t& value) override { + if (HWY_UNLIKELY(SkipField())) return; + uint32_t lower = static_cast(value); + operator()(lower); + uint32_t upper = static_cast(value >> 32); + operator()(upper); + value = lower | (static_cast(upper) << 32); + } + void operator()(float& value) override { if (HWY_UNLIKELY(SkipField())) return; @@ -229,6 +252,15 @@ class WriteVisitor : public VisitorBase { void operator()(uint32_t& value) override { storage_.push_back(value); } + void operator()(int32_t& value) override { + storage_.push_back(static_cast(value)); + } + + void operator()(uint64_t& value) override { + storage_.push_back(static_cast(value)); + storage_.push_back(static_cast(value >> 32)); + } + void operator()(float& value) override { storage_.push_back(hwy::BitCastScalar(value)); CheckF32(value); diff --git a/compression/fields.h b/compression/fields.h index 2ee409a1..33f3a65c 100644 --- a/compression/fields.h +++ b/compression/fields.h @@ -55,8 +55,9 @@ class IFields; // breaks circular dependency // Visitors are internal-only, but their base class is visible to user code // because their `IFields::VisitFields` calls `visitor.operator()`. // -// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes -// derived from `IFields`, `bool`, `enum`, `std::vector`. +// Supported field types `T`: `uint32_t`, `int32_t`, `uint64_t`, `float`, +// `std::string`, +// classes derived from `IFields`, `bool`, `enum`, `std::vector`. class IFieldsVisitor { public: virtual ~IFieldsVisitor(); @@ -69,6 +70,8 @@ class IFieldsVisitor { // is out of range. A single generic/overloaded function is required to // support `std::vector`. virtual void operator()(uint32_t& value) = 0; + virtual void operator()(int32_t& value) = 0; + virtual void operator()(uint64_t& value) = 0; virtual void operator()(float& value) = 0; virtual void operator()(std::string& value) = 0; virtual void operator()(IFields& fields) = 0; // recurse into nested fields @@ -92,7 +95,7 @@ class IFieldsVisitor { uint32_t u32 = static_cast(value); operator()(u32); if (HWY_UNLIKELY(!EnumValid(static_cast(u32)))) { - return NotifyInvalid("Invalid enum %u\n"); + return NotifyInvalid("Invalid enum %u\n", u32); } value = static_cast(u32); } diff --git a/compression/fields_test.cc b/compression/fields_test.cc index d6c1c519..bfe7b037 100644 --- a/compression/fields_test.cc +++ b/compression/fields_test.cc @@ -97,6 +97,8 @@ struct OldFields : public IFields { visitor(old_str); visitor(old_nested); visitor(old1); + visitor(oldi); + visitor(oldl); visitor(old_vec_str); visitor(old_vec_nested); visitor(old_f); @@ -110,6 +112,8 @@ struct OldFields : public IFields { EXPECT_EQ(old_str, n.old_str); old_nested.CheckEqual(n.old_nested); EXPECT_EQ(old1, n.old1); + EXPECT_EQ(oldi, n.oldi); + EXPECT_EQ(oldl, n.oldl); CheckVectorEqual(old_vec_str, n.old_vec_str); CheckVectorEqual(old_vec_nested, n.old_vec_nested); EXPECT_EQ(old_f, n.old_f); @@ -120,6 +124,8 @@ struct OldFields : public IFields { std::string old_str = "old"; Nested old_nested = Nested(0); uint32_t old1 = 1; + int32_t oldi = -1; + uint64_t oldl = 1234567890123456789; std::vector old_vec_str = {"abc", "1234"}; std::vector old_vec_nested = {Nested(1), Nested(4)}; float old_f = 1.125f; @@ -134,6 +140,8 @@ struct NewFields : public IFields { visitor(old_str); visitor(old_nested); visitor(old1); + visitor(oldi); + visitor(oldl); visitor(old_vec_str); visitor(old_vec_nested); visitor(old_f); @@ -149,6 +157,8 @@ struct NewFields : public IFields { visitor(new_enum); visitor(new2); visitor(new_str); + visitor(new_i); + visitor(new_l); } void CheckEqual(const NewFields& n) const { @@ -176,6 +186,8 @@ struct NewFields : public IFields { std::string old_str = "old"; Nested old_nested = Nested(0); uint32_t old1 = 1; + int32_t oldi = -1; + uint64_t oldl = 1234567890123456789; std::vector old_vec_str = {"abc", "1234"}; std::vector old_vec_nested = {Nested(1), Nested(4)}; float old_f = 1.125f; @@ -190,6 +202,8 @@ struct NewFields : public IFields { Enum new_enum = Enum::k3; uint32_t new2 = 2; std::string new_str = std::string(); // empty is allowed + int32_t new_i = 123456789; + uint64_t new_l = 876543210987654321; }; // NewFields // Changes all fields to non-default values. @@ -212,6 +226,8 @@ NewFields ModifiedNewFields() { n.new_enum = Enum::k8; n.new2 = 22; n.new_str = "new and even longer"; + n.new_i = 246810121; + n.new_l = 1357913579113579135; return n; } diff --git a/compression/migrate_weights.cc b/compression/migrate_weights.cc new file mode 100644 index 00000000..7a9613e5 --- /dev/null +++ b/compression/migrate_weights.cc @@ -0,0 +1,62 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "evals/benchmark_helper.h" +#include "gemma/gemma.h" +#include "util/args.h" + +namespace gcpp { +namespace { + +struct WriterArgs : public ArgsBase { + // --output_weights is required. + WriterArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + + // Returns error string or nullptr if OK. + const char* Validate() { + if (output_weights.path.empty()) { + return "Missing --output_weights flag, a file for the model weights."; + } + return nullptr; + } + + Path output_weights; // weights file location + + template + void ForEach(const Visitor& visitor) { + visitor(output_weights, "output_weights", Path(), + "Path name of output weights (.sbs) file.\n Required argument."); + } +}; + +} // namespace +} // namespace gcpp + +int main(int argc, char** argv) { + // Loads a model in the multi-file format and saves it in single-file format. + gcpp::WriterArgs args(argc, argv); + if (const char* err = args.Validate()) { + fprintf(stderr, "Skipping model load because: %s\n", err); + return 1; + } + gcpp::GemmaEnv env(argc, argv, /*required=*/true); + hwy::ThreadPool pool(0); + env.GetModel()->Save(args.output_weights, pool); + return 0; +} diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 016f90f8..546705b5 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -16,6 +16,7 @@ cc_library( deps = [ "@abseil-cpp//absl/types:span", "//:common", + "//:tokenizer", "//compression:compress", "//compression:io", "@highway//:hwy", diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index fe0e128b..05a5110f 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -24,7 +24,9 @@ #include "absl/types/span.h" #include "compression/io.h" +#include "gemma/configs.h" #include "gemma/tensor_index.h" +#include "gemma/tokenizer.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -44,10 +46,11 @@ class WriterInterface { virtual void InsertFloat(std::string name, absl::Span weights) = 0; virtual void AddScales(const std::vector& scales) = 0; + virtual void AddTokenizer(const std::string& tokenizer_path) = 0; virtual size_t DebugNumBlobsAdded() const = 0; - virtual int Write(std::string path) = 0; + virtual int WriteWithConfig(std::string path, const ModelConfig* config) = 0; }; } // namespace gcpp @@ -133,14 +136,21 @@ class SbsWriterImpl : public WriterInterface { compressor_.AddScales(scales_.data(), scales_.size()); } + void AddTokenizer(const std::string& tokenizer_path) override { + Path path(tokenizer_path); + GemmaTokenizer tokenizer(path); + tokenizer_proto_ = tokenizer.Serialize(); + compressor_.AddTokenizer(tokenizer_proto_); + } + // Returns the number of blobs added. size_t DebugNumBlobsAdded() const { if (mode_ == CompressorMode::kTEST_ONLY) return model_memory_.size(); return compressor_.DebugNumBlobsAdded(); } - int Write(std::string path) override { - return compressor_.WriteAll(pool_, gcpp::Path(path)); + int WriteWithConfig(std::string path, const ModelConfig* config) override { + return compressor_.WriteAll(gcpp::Path(path), config); } hwy::ThreadPool pool_; @@ -149,6 +159,7 @@ class SbsWriterImpl : public WriterInterface { std::vector model_memory_; std::vector scales_; CompressorMode mode_; + std::string tokenizer_proto_; }; WriterInterface* NewSbsWriter(CompressorMode mode) { @@ -190,11 +201,17 @@ void SbsWriter::AddScales(const std::vector& scales) { impl_->AddScales(scales); } +void SbsWriter::AddTokenizer(const std::string& tokenizer_path) { + impl_->AddTokenizer(tokenizer_path); +} + size_t SbsWriter::DebugNumBlobsAdded() const { return impl_->DebugNumBlobsAdded(); } -int SbsWriter::Write(std::string path) { return impl_->Write(path); } +int SbsWriter::WriteWithConfig(std::string path, const ModelConfig* config) { + return impl_->WriteWithConfig(path, config); +} } // namespace gcpp #endif // HWY_ONCE diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index 72eb4e44..cb8eb8c6 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -8,6 +8,7 @@ #include "absl/types/span.h" #include "compression/shared.h" +#include "gemma/configs.h" #include "gemma/tensor_index.h" namespace gcpp { @@ -36,10 +37,12 @@ class SbsWriter { void InsertBfloat16(std::string name, absl::Span weights); void InsertFloat(std::string name, absl::Span weights); void AddScales(const std::vector& scales); + void AddTokenizer(const std::string& tokenizer_path); size_t DebugNumBlobsAdded() const; - int Write(std::string path); + int Write(std::string path) { return WriteWithConfig(path, nullptr); } + int WriteWithConfig(std::string path, const ModelConfig* config); private: // Isolates Highway-dispatched types and other internals from CLIF. diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc index 4669f626..17266e98 100644 --- a/compression/python/compression_extension.cc +++ b/compression/python/compression_extension.cc @@ -50,6 +50,8 @@ PYBIND11_MODULE(compression, m) { .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>) .def("insert_float", wrap_span<&SbsWriter::InsertFloat>) .def("add_scales", &SbsWriter::AddScales) + .def("add_tokenizer", &SbsWriter::AddTokenizer) .def("debug_num_blobs_added", &SbsWriter::DebugNumBlobsAdded) - .def("write", &SbsWriter::Write); + .def("write", &SbsWriter::Write) + .def("write_with_config", &SbsWriter::WriteWithConfig); } diff --git a/compression/shared.h b/compression/shared.h index 8a708730..c9ce1c0f 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -198,6 +198,11 @@ constexpr bool IsNuqStream() { // Instruction-tuned models require extra 'turn structure' tokens in prompts. enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA }; +inline bool EnumValid(PromptWrapping type) { + return static_cast(type) >= 0 && + static_cast(type) <= static_cast(PromptWrapping::PALIGEMMA); +} + // Tensor types for loading weights. Note that not all types are supported as // weights for a model, but can be used for other purposes, such as types for // ModelWeightsPtrs. When adding a new type that is supported, also @@ -206,6 +211,11 @@ enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 }; constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"}; +inline bool EnumValid(Type type) { + return static_cast(type) >= 0 && + static_cast(type) <= static_cast(Type::kU128); +} + // Returns a Type enum for the type of the template parameter. template Type TypeEnum() { diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 435dce01..d16e5ef9 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) { return AppArgs(argc, argv); } -GemmaEnv::GemmaEnv(int argc, char** argv) - : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv), - MakeAppArgs(argc, argv)) {} +GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required) + : GemmaEnv(LoaderArgs(argc, argv, model_type_required), + InferenceArgs(argc, argv), MakeAppArgs(argc, argv)) {} QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { QueryResult result; @@ -270,7 +270,9 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { "specify 3 required model loading arguments:\n" " --tokenizer\n" " --weights\n" - " --model.\n"; + " --model,\n" + " or with the newer weights format, specify just:\n" + " --weights\n"; std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " "--weights 2b-it-sfp.sbs --model 2b-it\n"; std::cerr << "\n*Model Loading Arguments*\n\n"; diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 7e7f1bf6..6a1302e7 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -44,7 +44,7 @@ struct QueryResult { class GemmaEnv { public: // Calls the other constructor with *Args arguments initialized from argv. - GemmaEnv(int argc, char** argv); + GemmaEnv(int argc, char** argv, bool model_type_required = false); GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, const AppArgs& app); diff --git a/gemma/configs.cc b/gemma/configs.cc index 8a714c18..7cb45747 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -15,6 +15,7 @@ #include "gemma/configs.h" +#include #include #include "hwy/base.h" @@ -22,9 +23,9 @@ namespace gcpp { static ModelConfig ConfigNoSSM() { - ModelConfig config = {.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", - "gr_lin_y_w", "gr_lin_out_w", - "gr_gate_w", "gating_ein", "linear_w"}}; + ModelConfig config; + config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w", + "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"}; return config; } @@ -37,6 +38,18 @@ static ModelConfig ConfigBaseGemmaV2() { return config; } +static LayerConfig LayerConfigGemma2_27B(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 16 * 4608 / 2; // = 36864 + config.heads = 32; + config.kv_heads = 16; + config.qkv_dim = 128; + config.optimized_gating = false; + config.post_norm = PostNormType::Scale; + return config; +} + static ModelConfig ConfigGemma2_27B() { ModelConfig config = ConfigBaseGemmaV2(); config.model_name = "Gemma2_27B"; @@ -44,13 +57,7 @@ static ModelConfig ConfigGemma2_27B() { config.model_dim = 4608; config.vocab_size = kVocabSize; config.seq_len = 8192; - LayerConfig layer_config = {.model_dim = config.model_dim, - .ff_hidden_dim = 16 * 4608 / 2, // = 36864 - .heads = 32, - .kv_heads = 16, - .qkv_dim = 128, - .optimized_gating = false, - .post_norm = PostNormType::Scale}; + LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim); config.layer_configs = {46, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads; @@ -59,6 +66,18 @@ static ModelConfig ConfigGemma2_27B() { return config; } +static LayerConfig LayerConfigGemma2_9B(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 8 * 3584 / 2; // = 14336 + config.heads = 16; + config.kv_heads = 8; + config.qkv_dim = 256; + config.optimized_gating = false; + config.post_norm = PostNormType::Scale; + return config; +} + static ModelConfig ConfigGemma2_9B() { ModelConfig config = ConfigBaseGemmaV2(); config.model_name = "Gemma2_9B"; @@ -66,13 +85,7 @@ static ModelConfig ConfigGemma2_9B() { config.model_dim = 3584; config.vocab_size = kVocabSize; config.seq_len = 8192; - LayerConfig layer_config = {.model_dim = config.model_dim, - .ff_hidden_dim = 8 * 3584 / 2, // = 14336 - .heads = 16, - .kv_heads = 8, - .qkv_dim = 256, - .optimized_gating = false, - .post_norm = PostNormType::Scale}; + LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim); config.layer_configs = {42, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.query_scale = QueryScaleType::SqrtKeySize; @@ -81,6 +94,18 @@ static ModelConfig ConfigGemma2_9B() { return config; } +static LayerConfig LayerConfigGemma2_2B(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 8 * 2304 / 2; // = 9216 + config.heads = 8; + config.kv_heads = 4; + config.qkv_dim = 256; + config.optimized_gating = false; + config.post_norm = PostNormType::Scale; + return config; +} + static ModelConfig ConfigGemma2_2B() { ModelConfig config = ConfigBaseGemmaV2(); config.model_name = "Gemma2_2B"; @@ -88,13 +113,7 @@ static ModelConfig ConfigGemma2_2B() { config.model_dim = 2304; config.vocab_size = kVocabSize; config.seq_len = 8192; - LayerConfig layer_config = {.model_dim = config.model_dim, - .ff_hidden_dim = 8 * 2304 / 2, // = 9216 - .heads = 8, - .kv_heads = 4, - .qkv_dim = 256, - .optimized_gating = false, - .post_norm = PostNormType::Scale}; + LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim); config.layer_configs = {26, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.query_scale = QueryScaleType::SqrtKeySize; @@ -103,6 +122,16 @@ static ModelConfig ConfigGemma2_2B() { return config; } +static LayerConfig LayerConfigGemma7B(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 16 * 3072 / 2; // = 24576 + config.heads = 16; + config.kv_heads = 16; + config.qkv_dim = 256; + return config; +} + static ModelConfig ConfigGemma7B() { ModelConfig config = ConfigBaseGemmaV1(); config.model_name = "Gemma7B"; @@ -110,13 +139,7 @@ static ModelConfig ConfigGemma7B() { config.model_dim = 3072; config.vocab_size = kVocabSize; config.seq_len = kSeqLen; - LayerConfig layer_config = { - .model_dim = config.model_dim, - .ff_hidden_dim = 16 * 3072 / 2, // = 24576 - .heads = 16, - .kv_heads = 16, - .qkv_dim = 256, - }; + LayerConfig layer_config = LayerConfigGemma7B(config.model_dim); config.layer_configs = {28, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.query_scale = QueryScaleType::SqrtKeySize; @@ -124,6 +147,16 @@ static ModelConfig ConfigGemma7B() { return config; } +static LayerConfig LayerConfigGemma2B(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 16 * 2048 / 2; // = 16384 + config.heads = 8; + config.kv_heads = 1; + config.qkv_dim = 256; + return config; +} + static ModelConfig ConfigGemma2B() { ModelConfig config = ConfigBaseGemmaV1(); config.model_name = "Gemma2B"; @@ -131,19 +164,23 @@ static ModelConfig ConfigGemma2B() { config.model_dim = 2048; config.vocab_size = kVocabSize; config.seq_len = kSeqLen; - LayerConfig layer_config = { - .model_dim = config.model_dim, - .ff_hidden_dim = 16 * 2048 / 2, // = 16384 - .heads = 8, - .kv_heads = 1, - .qkv_dim = 256, - }; + LayerConfig layer_config = LayerConfigGemma2B(config.model_dim); config.layer_configs = {18, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen); return config; } +static LayerConfig LayerConfigGemmaTiny(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 256; + config.heads = 4; + config.kv_heads = 1; + config.qkv_dim = 16; + return config; +} + static ModelConfig ConfigGemmaTiny() { ModelConfig config = ConfigNoSSM(); config.model_name = "GemmaTiny"; @@ -151,13 +188,7 @@ static ModelConfig ConfigGemmaTiny() { config.model_dim = 128; config.vocab_size = 64; config.seq_len = 32; - LayerConfig layer_config = { - .model_dim = config.model_dim, - .ff_hidden_dim = 256, - .heads = 4, - .kv_heads = 1, - .qkv_dim = 16, - }; + LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim); config.layer_configs = {3, layer_config}; config.num_tensor_scales = 4 * config.layer_configs.size(); config.query_scale = QueryScaleType::SqrtKeySize; @@ -167,6 +198,24 @@ static ModelConfig ConfigGemmaTiny() { return config; } +static LayerConfig LayerConfigGriffin2B(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.griffin_dim = model_dim; + config.ff_hidden_dim = 7680; + config.heads = 10; + config.kv_heads = 1; + config.qkv_dim = 256; + config.conv1d_width = 4; + config.ff_biases = true; + config.softmax_attn_output_biases = true; + config.optimized_gating = false; + config.type = LayerAttentionType::kGriffinRecurrentBlock; + config.activation = ActivationType::Gelu; + config.post_qk = PostQKType::HalfRope; + return config; +} + static ModelConfig ConfigGriffin2B() { ModelConfig config = ConfigNoSSM(); config.model_name = "Griffin2B"; @@ -176,21 +225,7 @@ static ModelConfig ConfigGriffin2B() { config.model_dim = 2560; config.vocab_size = kVocabSize; config.seq_len = 2048; - LayerConfig layer_config = { - .model_dim = config.model_dim, - .griffin_dim = config.model_dim, - .ff_hidden_dim = 7680, - .heads = 10, - .kv_heads = 1, - .qkv_dim = 256, - .conv1d_width = 4, - .ff_biases = true, - .softmax_attn_output_biases = true, - .optimized_gating = false, - .type = LayerAttentionType::kGriffinRecurrentBlock, - .activation = ActivationType::Gelu, - .post_qk = PostQKType::HalfRope, - }; + LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim); config.layer_configs = {26, layer_config}; for (size_t i = 2; i < config.layer_configs.size(); i += 3) { config.layer_configs[i].type = LayerAttentionType::kGemma; @@ -204,6 +239,18 @@ static ModelConfig ConfigGriffin2B() { return config; } +static LayerConfig LayerConfigVit(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 4304; + config.heads = 16; + config.kv_heads = 16; + config.qkv_dim = 72; + config.ff_biases = true; + config.type = LayerAttentionType::kVit; + return config; +} + // Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config. static void AddVitConfig(ModelConfig& config, size_t image_size = 224) { config.vit_model_dim = 1152; @@ -215,15 +262,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) { } const size_t num_patches = config.image_size / config.patch_width; config.vit_seq_len = num_patches * num_patches; - LayerConfig vit_layer_config = { - .model_dim = config.vit_model_dim, - .ff_hidden_dim = 4304, - .heads = 16, - .kv_heads = 16, - .qkv_dim = 72, - .ff_biases = true, - .type = LayerAttentionType::kVit, - }; + LayerConfig vit_layer_config = LayerConfigVit(config.vit_model_dim); config.vit_layer_configs = {27, vit_layer_config}; config.num_vit_scales = 4 * config.vit_layer_configs.size(); } diff --git a/gemma/configs.h b/gemma/configs.h index 9c33b178..aad5d32c 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -26,6 +26,7 @@ #include #include +#include "compression/fields.h" // IFieldsVisitor #include "compression/shared.h" // BF16 namespace gcpp { @@ -52,52 +53,83 @@ enum class LayerAttentionType { kVit, }; +inline bool EnumValid(LayerAttentionType type) { + return static_cast(type) >= 0 && + static_cast(type) <= static_cast(LayerAttentionType::kVit); +} + // Post attention and ffw normalization type. enum class PostNormType { None, Scale, }; +inline bool EnumValid(PostNormType type) { + return static_cast(type) >= 0 && + static_cast(type) <= static_cast(PostNormType::Scale); +} + // Post qk projection operation type. enum class PostQKType { Rope, HalfRope, }; +inline bool EnumValid(PostQKType type) { + return static_cast(type) >= 0 && + static_cast(type) <= static_cast(PostQKType::HalfRope); +} + // FFW activation function. enum class ActivationType { Gelu, }; +inline bool EnumValid(ActivationType type) { + return static_cast(type) >= 0 && + static_cast(type) <= static_cast(ActivationType::Gelu); +} + // Attention query scale. enum class QueryScaleType { SqrtKeySize, SqrtModelDimDivNumHeads, }; +inline bool EnumValid(QueryScaleType type) { + return static_cast(type) >= 0 && + static_cast(type) <= + static_cast(QueryScaleType::SqrtModelDimDivNumHeads); +} + // Residual connection type. enum class ResidualType { Add, }; +inline bool EnumValid(ResidualType type) { + return static_cast(type) >= 0 && + static_cast(type) <= static_cast(ResidualType::Add); +} + template std::vector FixedLayerConfig(LayerAttentionType type) { return std::vector(kNum, type); } -template -std::vector FixedAttentionWindowSizes(size_t window_size) { - return std::vector(kNum, window_size); +template +std::vector FixedAttentionWindowSizes(uint32_t window_size) { + return std::vector(kNum, window_size); } // Repeat window_size_pattern for kNum / kPatternSize times. -template -std::vector RepeatedAttentionWindowSizes( - const std::array& window_size_pattern) { +template +std::vector RepeatedAttentionWindowSizes( + const std::array& window_size_pattern) { static_assert(kNum % kPatternSize == 0, "kNum must be a multiple of kPatternSize"); - std::vector window_size_configs(kNum); - for (size_t i = 0; i < kNum; ++i) { + std::vector window_size_configs(kNum); + for (uint32_t i = 0; i < kNum; ++i) { window_size_configs[i] = window_size_pattern[i % kPatternSize]; } return window_size_configs; @@ -130,7 +162,14 @@ static constexpr Model kAllModels[] = { Model::PALIGEMMA2_10B_224, Model::PALIGEMMA2_10B_448, }; -struct LayerConfig { +inline bool EnumValid(Model model) { + for (Model m : kAllModels) { + if (m == model) return true; + } + return false; +} + +struct LayerConfig : public IFields { // Returns true if *this and other are equal. // If partial is true, then we don't check for items that are only set after // the tensors are loaded from the checkpoint. @@ -146,13 +185,32 @@ struct LayerConfig { // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. size_t QStride() const { return qkv_dim * (IsMHA() ? 3 : 1); } - size_t model_dim = 0; - size_t griffin_dim = 0; - size_t ff_hidden_dim = 0; - size_t heads = 0; - size_t kv_heads = 0; - size_t qkv_dim = 0; - size_t conv1d_width = 0; // griffin only + const char* Name() const override { return "LayerConfig"; } + + void VisitFields(IFieldsVisitor& visitor) override { + visitor(model_dim); + visitor(griffin_dim); + visitor(ff_hidden_dim); + visitor(heads); + visitor(kv_heads); + visitor(qkv_dim); + visitor(conv1d_width); + visitor(ff_biases); + visitor(softmax_attn_output_biases); + visitor(optimized_gating); + visitor(post_norm); + visitor(type); + visitor(activation); + visitor(post_qk); + } + + uint32_t model_dim = 0; + uint32_t griffin_dim = 0; + uint32_t ff_hidden_dim = 0; + uint32_t heads = 0; + uint32_t kv_heads = 0; + uint32_t qkv_dim = 0; + uint32_t conv1d_width = 0; // griffin only bool ff_biases = false; bool softmax_attn_output_biases = false; bool optimized_gating = true; @@ -162,7 +220,7 @@ struct LayerConfig { PostQKType post_qk = PostQKType::Rope; }; -struct ModelConfig { +struct ModelConfig : public IFields { // Returns true if *this and other are equal. // If partial is true, then we don't check for items that are only set after // the tensors are loaded from the checkpoint. @@ -191,39 +249,68 @@ struct ModelConfig { } size_t NumHeads() const { - size_t num_heads = 0; + uint32_t num_heads = 0; for (const auto& layer_config : layer_configs) { num_heads = std::max(num_heads, layer_config.heads); } return num_heads; } + const char* Name() const override { return "ModelConfig"; } + + void VisitFields(IFieldsVisitor& visitor) override { + visitor(model_family_version); + visitor(model_name); + visitor(model); + visitor(wrapping); + visitor(weight); + visitor(num_layers); + visitor(model_dim); + visitor(vocab_size); + visitor(seq_len); + visitor(num_tensor_scales); + visitor(att_cap); + visitor(final_cap); + visitor(absolute_pe); + visitor(use_local_attention); + visitor(query_scale); + visitor(layer_configs); + visitor(attention_window_sizes); + visitor(norm_num_groups); + visitor(vit_model_dim); + visitor(vit_seq_len); + visitor(num_vit_scales); + visitor(vit_layer_configs); + visitor(patch_width); + visitor(image_size); + } + std::string model_name; - Model model; - PromptWrapping wrapping; - Type weight; - size_t num_layers = 0; - size_t model_dim = 0; - size_t vit_model_dim = 0; - size_t vocab_size = 0; - size_t seq_len = 0; - size_t vit_seq_len = 0; - size_t num_tensor_scales = 0; - size_t num_vit_scales = 0; + Model model = Model::UNKNOWN; + PromptWrapping wrapping = PromptWrapping::GEMMA_PT; + Type weight = Type::kUnknown; + uint32_t num_layers = 0; + uint32_t model_dim = 0; + uint32_t vit_model_dim = 0; + uint32_t vocab_size = 0; + uint32_t seq_len = 0; + uint32_t vit_seq_len = 0; + uint32_t num_tensor_scales = 0; + uint32_t num_vit_scales = 0; float att_cap = 0.0f; float final_cap = 0.0f; bool absolute_pe = false; bool use_local_attention = false; // griffin only QueryScaleType query_scale = QueryScaleType::SqrtKeySize; std::vector layer_configs; - std::vector attention_window_sizes; + std::vector attention_window_sizes; std::vector vit_layer_configs; std::unordered_set scale_names; - int norm_num_groups = 1; - int model_family_version = 1; + uint32_t norm_num_groups = 1; + uint32_t model_family_version = 1; // Dimensions related to image processing. - size_t patch_width = 14; - size_t image_size = 224; + uint32_t patch_width = 14; + uint32_t image_size = 224; }; // Returns the config for the given model. diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index fa8d8700..456d5fa5 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -2,9 +2,12 @@ #include #include +#include #include +#include #include "gtest/gtest.h" +#include "hwy/aligned_allocator.h" namespace gcpp { @@ -412,8 +415,17 @@ void AssertMatch(const ModelConfig& config) { ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales); } +ModelConfig RoundTripSerialize(const ModelConfig& config) { + std::vector config_buffer = config.Write(); + ModelConfig deserialized; + deserialized.Read(hwy::Span(config_buffer), 0); + return deserialized; +} + TEST(ConfigsTest, OldConfigGemma2B) { AssertMatch>(ConfigFromModel(Model::GEMMA_2B)); + ModelConfig config = RoundTripSerialize(ConfigFromModel(Model::GEMMA_2B)); + AssertMatch>(config); } TEST(ConfigsTest, OldConfigGemma7B) { diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 739bddb7..328144b8 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -23,6 +23,7 @@ #include #include +#include #include // std::move #include @@ -40,13 +41,21 @@ namespace gcpp { Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, NestedPools& pools) - : pools_(pools), tokenizer_(tokenizer_path), info_(info) { - model_.Load(weights, info.model, info.weight, pools_.Pool()); + : pools_(pools), tokenizer_(tokenizer_path) { + model_.Load(weights, info.model, info.weight, info.wrapping, pools_.Pool(), + /*tokenizer_proto=*/nullptr); +} + +Gemma::Gemma(const Path& weights, NestedPools& pools) : pools_(pools) { + std::string tokenizer_proto; + model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT, + pools_.Pool(), &tokenizer_proto); + tokenizer_.Deserialize(tokenizer_proto); } Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools) - : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) { + : pools_(pools), tokenizer_(std::move(tokenizer)) { HWY_ASSERT(info.weight == Type::kF32); model_.Allocate(info.model, info.weight, pools_.Pool()); } @@ -166,7 +175,7 @@ void RangeChecks(const ModelConfig& weights_config, if (!weights_config.use_local_attention) { if (max_generated_tokens > weights_config.seq_len) { fprintf(stderr, - "WARNING: max_generated_tokens %zu > kSeqLen %zu, truncating.\n", + "WARNING: max_generated_tokens %zu > kSeqLen %u, truncating.\n", max_generated_tokens, weights_config.seq_len); max_generated_tokens = weights_config.seq_len; } diff --git a/gemma/gemma.h b/gemma/gemma.h index 1ad87176..15d22a18 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -190,18 +190,28 @@ struct TimingInfo { class Gemma { public: + // Reads old format weights file and tokenizer file. Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, NestedPools& pools); - + // Reads new format weights file that contains everything in a single file. + Gemma(const Path& weights, NestedPools& pools); // Allocates weights, caller is responsible for filling them. Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools); ~Gemma(); const ModelConfig& GetModelConfig() const { return model_.Config(); } - const ModelInfo& Info() const { return info_; } + ModelInfo Info() const { + return ModelInfo({.model = model_.Config().model, + .wrapping = model_.Config().wrapping, + .weight = model_.Config().weight}); + } const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const ModelWeightsStorage& Weights() const { return model_; } ModelWeightsStorage& MutableWeights() { return model_; } + void Save(const Path& weights, hwy::ThreadPool& pool) { + std::string tokenizer_proto = tokenizer_.Serialize(); + model_.Save(tokenizer_proto, weights, pool); + } // `pos` is the position in the KV cache. Users are responsible for // incrementing it in the `*StreamFunc`, or setting to zero for single-turn. @@ -241,7 +251,6 @@ class Gemma { GemmaTokenizer tokenizer_; // Type-erased so that this can be defined in the header. ModelWeightsStorage model_; - ModelInfo info_; }; // Adds BOS token and possibly 'turn' annotations, which depend on `info` diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 82ee01d8..4e94ad65 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -53,7 +53,7 @@ KVCache KVCache::Create(const ModelConfig& weights_config, LayerAttentionType::kGriffinRecurrentBlock); // TODO(patrickms): Add query batching support for Griffin. if (num_griffin_layers > 0) { - size_t conv1d_width = 0; + uint32_t conv1d_width = 0; for (const auto& layer_config : weights_config.layer_configs) { conv1d_width = std::max(conv1d_width, layer_config.conv1d_width); } diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc index b6afa4d9..50c7af0c 100644 --- a/gemma/tensor_index.cc +++ b/gemma/tensor_index.cc @@ -482,6 +482,8 @@ std::vector LLMLayerTensors(const ModelConfig& config, .name = "att_w", .source_names = {"attn/attn_vec_einsum/w", "attention_block/proj_final/kernel"}, + .preshape = {layer_config.heads, layer_config.qkv_dim, + config.model_dim}, .axes = {2, 0, 1}, .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim}, .cols_take_extra_dims = true, diff --git a/gemma/tensor_index_test.cc b/gemma/tensor_index_test.cc index 8928ad4e..43eaa4ea 100644 --- a/gemma/tensor_index_test.cc +++ b/gemma/tensor_index_test.cc @@ -56,7 +56,8 @@ TEST(TensorIndexTest, FindName) { // Test that the MatPtr can be constructed from the TensorInfo, // and that the dimensions match. MatPtrT mat_ptr(tensor.Name(), tensor_index); - EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name; + EXPECT_STREQ(tensor.Name(), mat_ptr.Name()) + << "on tensor " << name; EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name; EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name; ++num_found; diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc index ffd71ae1..9e0b827b 100644 --- a/gemma/tokenizer.cc +++ b/gemma/tokenizer.cc @@ -44,6 +44,17 @@ class GemmaTokenizer::Impl { HWY_ABORT("Failed to load the tokenizer file."); } } + // Loads the tokenizer from a serialized proto. + explicit Impl(const std::string& tokenizer_proto) { + PROFILER_ZONE("Startup.tokenizer"); + spp_ = std::make_unique(); + if (!spp_->LoadFromSerializedProto(tokenizer_proto).ok()) { + fprintf(stderr, "serialized proto size=%zu.\n", tokenizer_proto.size()); + HWY_ABORT("Failed to load the tokenizer from serialized proto."); + } + } + + std::string Serialize() const { return spp_->serialized_model_proto(); } bool Encode(const std::string& input, std::vector* pieces) const { @@ -81,6 +92,12 @@ GemmaTokenizer::~GemmaTokenizer() = default; GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default; GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default; +std::string GemmaTokenizer::Serialize() const { return impl_->Serialize(); } + +void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) { + impl_ = std::make_unique(tokenizer_proto); +} + bool GemmaTokenizer::Encode(const std::string& input, std::vector* pieces) const { return impl_->Encode(input, pieces); diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h index f0bb0fcd..e2bb6115 100644 --- a/gemma/tokenizer.h +++ b/gemma/tokenizer.h @@ -41,6 +41,9 @@ class GemmaTokenizer { GemmaTokenizer(GemmaTokenizer&& other); GemmaTokenizer& operator=(GemmaTokenizer&& other); + std::string Serialize() const; + void Deserialize(const std::string& tokenizer_proto); + bool Encode(const std::string& input, std::vector* pieces) const; bool Encode(const std::string& input, std::vector* ids) const; bool Decode(const std::vector& ids, std::string* detokenized) const; diff --git a/gemma/weights.cc b/gemma/weights.cc index 6fd84802..568edb81 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -19,11 +19,13 @@ #include #include #include +#include #include #include "compression/blob_store.h" #include "compression/compress.h" #include "compression/io.h" // Path +#include "compression/shared.h" #include "gemma/common.h" #include "gemma/configs.h" #include "hwy/aligned_allocator.h" @@ -47,7 +49,9 @@ struct TensorLoader { }; BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type, - Type weight_type, hwy::ThreadPool& pool) { + Type weight_type, PromptWrapping wrapping, + hwy::ThreadPool& pool, + std::string* tokenizer_proto) { PROFILER_ZONE("Startup.LoadModelWeightsPtrs"); if (!weights.Exists()) { HWY_ABORT("The model weights file '%s' does not exist.", @@ -56,17 +60,36 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type, ReadFromBlobStore loader(weights); ForEachType fet = loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc; + std::vector scales; if (fet == ForEachType::kLoadWithToc) { - // TODO(rays): Load the config from the file. - HWY_ABORT("TOC not supported yet."); + BlobError err = loader.LoadConfig(config_); + if (err != 0 || config_.model_dim == 0) { + fprintf(stderr, "Failed to load model config: %d\n", err); + return err; + } + if (tokenizer_proto != nullptr) { + err = loader.LoadTokenizer(*tokenizer_proto); + if (err != 0) { + fprintf(stderr, "Failed to load tokenizer: %d\n", err); + return err; + } + } } else { + if (weight_type == Type::kUnknown || model_type == Model::UNKNOWN) { + fprintf(stderr, + "weight type (%d) and model type (%d) must be specified when " + "no config is present in weights file\n", + static_cast(weight_type), static_cast(model_type)); + return __LINE__; + } // No Toc-> no config. config_ = ConfigFromModel(model_type); config_.weight = weight_type; + config_.wrapping = wrapping; + scales.resize(config_.num_tensor_scales + config_.num_vit_scales); } - CreateForType(weight_type, pool); + CreateForType(config_.weight, pool); CallForModelWeightT(fet, loader); - std::vector scales(config_.num_tensor_scales + config_.num_vit_scales); if (!scales.empty()) { loader.LoadScales(scales.data(), scales.size()); } @@ -85,6 +108,34 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type, return 0; } +template +struct TensorSaver { + // Adds all the tensors to the blob writer. + void operator()(ModelWeightsPtrs& weights, ForEachType fet, + WriteToBlobStore& writer) { + weights.ForEachTensor( + {&weights}, fet, + [&writer](const char* name, hwy::Span tensors) { + tensors[0]->CallUpcasted(writer, name); + }); + } +}; + +BlobError ModelWeightsStorage::Save(const std::string& tokenizer, + const Path& weights, + hwy::ThreadPool& pool) { + WriteToBlobStore writer(pool); + ForEachType fet = ForEachType::kLoadWithToc; + CallForModelWeightT(fet, writer); + writer.AddTokenizer(tokenizer); + int err = writer.WriteAll(weights, &config_); + if (err != 0) { + fprintf(stderr, "Failed to load model weights: %d\n", err); + return err; + } + return 0; +} + void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type, hwy::ThreadPool& pool) { PROFILER_ZONE("Startup.AllocateModelWeightsPtrs"); diff --git a/gemma/weights.h b/gemma/weights.h index 2db98111..33da8429 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -522,7 +522,18 @@ class ModelWeightsStorage { ModelWeightsStorage() = default; ~ModelWeightsStorage() = default; + // Loads the weights from a blob store file. Supports multi-file or + // single-file format. If the weights file contains a TOC, then it is in + // single-file format, and model_type, weight_type, training are ignored, + // and tokenizer_proto is required and written to. + // With a multi-file format, file, model_type, weight_type, training are + // required and tokenizer_proto is ignored. BlobError Load(const Path& weights, Model model_type, Type weight_type, + PromptWrapping wrapping, hwy::ThreadPool& pool, + std::string* tokenizer_proto); + // Writes the weights to a blob store file, using the single-file format with + // a TOC and config included. + BlobError Save(const std::string& tokenizer, const Path& weights, hwy::ThreadPool& pool); void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) { Allocate(ConfigFromModel(model_type), weight_type, pool); diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 73d9f8f3..547cc63d 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -26,7 +26,6 @@ #include #include -#include "compression/compress.h" #include "compression/shared.h" #include "util/allocator.h" #include "util/test_util.h" diff --git a/util/app.h b/util/app.h index aa175671..6c66d2c3 100644 --- a/util/app.h +++ b/util/app.h @@ -25,6 +25,7 @@ #include #include "compression/io.h" // Path +#include "compression/shared.h" #include "gemma/common.h" #include "gemma/gemma.h" // For CreateGemma #include "ops/matmul.h" @@ -125,7 +126,10 @@ static inline NestedPools CreatePools(const AppArgs& app) { } struct LoaderArgs : public ArgsBase { - LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } + LoaderArgs(int argc, char* argv[], bool required = true) + : model_type_required(required) { + InitAndParse(argc, argv); + } LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path, const std::string& model) { Init(); // Init sets to defaults, so assignments must come after Init(). @@ -136,18 +140,24 @@ struct LoaderArgs : public ArgsBase { // Returns error string or nullptr if OK. const char* Validate() { + info_.model = Model::UNKNOWN; + info_.wrapping = PromptWrapping::GEMMA_PT; + info_.weight = Type::kUnknown; if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model, info_.wrapping)) { - return err; + if (model_type_required) return err; } if (const char* err = ParseType(weight_type_str, info_.weight)) { - return err; - } - if (tokenizer.path.empty()) { - return "Missing --tokenizer flag, a file for the tokenizer is required."; + if (model_type_required) return err; } - if (!tokenizer.Exists()) { - return "Can't open file specified with --tokenizer flag."; + if (model_type_required) { + if (tokenizer.path.empty()) { + return "Missing --tokenizer flag, a file for the tokenizer is " + "required."; + } + if (!tokenizer.Exists()) { + return "Can't open file specified with --tokenizer flag."; + } } if (!compressed_weights.path.empty()) { if (weights.path.empty()) { @@ -172,11 +182,12 @@ struct LoaderArgs : public ArgsBase { Path compressed_weights; std::string model_type_str; std::string weight_type_str; + bool model_type_required = true; template void ForEach(const Visitor& visitor) { visitor(tokenizer, "tokenizer", Path(), - "Path name of tokenizer model file.\n Required argument."); + "Path name of tokenizer model file."); visitor(weights, "weights", Path(), "Path name of model weights (.sbs) file.\n Required argument."); visitor(compressed_weights, "compressed_weights", Path(), @@ -186,11 +197,9 @@ struct LoaderArgs : public ArgsBase { "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " "gr2b-it = griffin 2B parameters, instruction-tuned\n " - "gr2b-pt = griffin 2B parameters, pretrained\n " - " Required argument."); + "gr2b-pt = griffin 2B parameters, pretrained."); visitor(weight_type_str, "weight_type", std::string("sfp"), - "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP\n" - " Required argument."); + "Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP."); } // Uninitialized before Validate, must call after that. @@ -208,6 +217,12 @@ static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) { static inline std::unique_ptr AllocateGemma(const LoaderArgs& loader, NestedPools& pools) { + if (Type::kUnknown == loader.Info().weight || + Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) { + // Newer weights file format doesn't need tokenizer path or model/weight + // info. + return std::make_unique(loader.weights, pools); + } return std::make_unique(loader.tokenizer, loader.weights, loader.Info(), pools); }