Skip to content

Commit

Permalink
Added ability to load/save a complete model file, including tokenizer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707914366
  • Loading branch information
theraysmith authored and copybara-github committed Dec 19, 2024
1 parent 5bc356f commit 9d40f01
Show file tree
Hide file tree
Showing 32 changed files with 770 additions and 257 deletions.
3 changes: 3 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ cc_library(
"gemma/tensor_index.h",
],
deps = [
"//compression:fields",
"//compression:sfp",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
Expand All @@ -257,6 +258,7 @@ cc_test(
deps = [
":common",
"@googletest//:gtest_main",
"@highway//:hwy",
],
)

Expand Down Expand Up @@ -388,6 +390,7 @@ cc_library(
":ops",
":threading",
"//compression:io",
"//compression:sfp",
"@highway//:hwy",
],
)
Expand Down
13 changes: 6 additions & 7 deletions backprop/backward_scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,12 @@ static ModelConfig TestConfig() {
config.model_dim = 32;
config.vocab_size = 12;
config.seq_len = 18;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 48,
.heads = 3,
.kv_heads = 1,
.qkv_dim = 12,
};
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 48;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 12;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
Expand Down
13 changes: 6 additions & 7 deletions backprop/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,12 @@ static ModelConfig TestConfig() {
config.model_dim = 32;
config.vocab_size = 16;
config.seq_len = 24;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 64,
.heads = 3,
.kv_heads = 1,
.qkv_dim = 16,
};
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 64;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 16;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
Expand Down
15 changes: 13 additions & 2 deletions compression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ cc_test(
deps = [
":fields",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
Expand Down Expand Up @@ -202,6 +201,7 @@ cc_library(
deps = [
":blob_store",
":distortion",
":fields",
":io",
":nuq",
":sfp",
Expand All @@ -210,7 +210,6 @@ cc_library(
"//:common",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
Expand Down Expand Up @@ -261,6 +260,7 @@ cc_binary(
"//:allocator",
"//:args",
"//:common",
"//:tokenizer",
"//:weights",
"@highway//:hwy",
"@highway//:thread_pool",
Expand All @@ -277,3 +277,14 @@ cc_binary(
"@highway//:hwy_test_util",
],
)

cc_binary(
name = "migrate_weights",
srcs = ["migrate_weights.cc"],
deps = [
"//:app",
"//:args",
"//:benchmark_helper",
"//:gemma_lib",
],
)
39 changes: 21 additions & 18 deletions compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
#include <stdio.h>

#include <cmath> // lroundf, only if COMPRESS_STATS
#include <string>
#include <vector>

#include "compression/blob_store.h"
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
Expand Down Expand Up @@ -673,45 +676,45 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
// their scaling factors to BlobStore.
class Compressor {
public:
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}

template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
const float* HWY_RESTRICT weights) {
size_t num_weights = compressed->NumElements();
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
return;
size_t num_compressed = compressed->NumElements();
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
num_weights / (1000 * 1000));
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_);
const size_t num_bytes = packed.num * sizeof(Packed);
writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes);
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
writer_.pool());
writer_(compressed, decorated_name);
}

void AddTokenizer(const std::string& tokenizer) {
writer_.AddTokenizer(tokenizer);
}

void AddScales(const float* scales, size_t len) {
if (len) {
MatPtrT<float> scales_ptr("scales", 0, 1);
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
len * sizeof(scales[0]));
}
writer_.AddScales(scales, len);
}

BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
const BlobError err = writer_.WriteAll(pool, blob_filename);
if (err != 0) {
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
blob_filename.path.c_str(), err);
}
return err;
// Writes all blobs to disk in the given order. The config is optional and
// if given, it is written to the file, along with the TOC, making it
// single-file format. Otherwise, the file is written in the multi-file format
// without a TOC.
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
return writer_.WriteAll(blob_filename, config);
}

// Returns the number of blobs added.
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }

private:
CompressWorkingSet work_;
hwy::ThreadPool& pool_;
BlobWriter writer_;
WriteToBlobStore writer_;
};

// NOLINTNEXTLINE(google-readability-namespace-comments)
Expand Down
Loading

0 comments on commit 9d40f01

Please sign in to comment.