Skip to content

Commit

Permalink
Added ability to load/save a complete model file, including tokenizer.
Browse the repository at this point in the history.
PiperOrigin-RevId: 706704368
Authored by theraysmith; committed by copybara-github on Dec 17, 2024.
1 parent 73766e8 commit bfab7dd
Show file tree
Hide file tree
Showing 32 changed files with 754 additions and 265 deletions.
3 changes: 3 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ cc_library(
"gemma/tensor_index.h",
],
deps = [
"//compression:fields",
"//compression:sfp",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
Expand All @@ -256,6 +257,7 @@ cc_test(
deps = [
":common",
"@googletest//:gtest_main",
"@highway//:hwy",
],
)

Expand Down Expand Up @@ -387,6 +389,7 @@ cc_library(
":ops",
":threading",
"//compression:io",
"//compression:sfp",
"@highway//:hwy",
],
)
Expand Down
13 changes: 6 additions & 7 deletions backprop/backward_scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,12 @@ static ModelConfig TestConfig() {
config.model_dim = 32;
config.vocab_size = 12;
config.seq_len = 18;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 48,
.heads = 3,
.kv_heads = 1,
.qkv_dim = 12,
};
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 48;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 12;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
Expand Down
13 changes: 6 additions & 7 deletions backprop/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,12 @@ static ModelConfig TestConfig() {
config.model_dim = 32;
config.vocab_size = 16;
config.seq_len = 24;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 64,
.heads = 3,
.kv_heads = 1,
.qkv_dim = 16,
};
LayerConfig layer_config;
layer_config.model_dim = config.model_dim;
layer_config.ff_hidden_dim = 64;
layer_config.heads = 3;
layer_config.kv_heads = 1;
layer_config.qkv_dim = 16;
config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
Expand Down
15 changes: 13 additions & 2 deletions compression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ cc_test(
deps = [
":fields",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
Expand Down Expand Up @@ -196,6 +195,7 @@ cc_library(
deps = [
":blob_store",
":distortion",
":fields",
":io",
":nuq",
":sfp",
Expand All @@ -204,7 +204,6 @@ cc_library(
"//:common",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
Expand Down Expand Up @@ -255,6 +254,7 @@ cc_binary(
"//:allocator",
"//:args",
"//:common",
"//:tokenizer",
"//:weights",
"@highway//:hwy",
"@highway//:thread_pool",
Expand All @@ -271,3 +271,14 @@ cc_binary(
"@highway//:hwy_test_util",
],
)

# Command-line tool for rewriting model weight files; presumably converts
# existing weights into the complete single-file model format added in this
# commit (load/save including tokenizer) — see reformat_weights.cc.
# TODO(review): confirm intended usage against the binary's --help/args.
cc_binary(
    name = "reformat_weights",
    srcs = ["reformat_weights.cc"],
    deps = [
        "//:app",
        "//:args",
        "//:benchmark_helper",
        "//:gemma_lib",
    ],
)
33 changes: 7 additions & 26 deletions compression/compress-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
#include <stdio.h>

#include <cmath> // lroundf, only if COMPRESS_STATS
#include <vector>

#include "compression/blob_store.h"
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
Expand Down Expand Up @@ -671,47 +673,26 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,

// Functor called for each tensor, which compresses and stores them along with
// their scaling factors to BlobStore.
// NOTE(review): This span is GitHub diff residue with the +/- markers
// stripped, so the pre-change and post-change versions of Compressor are
// interleaved below (note the two class heads and two constructors).
// It is NOT compilable as-is; recover the intended version from the
// original commit (bfab7dd) before using this text as source.
class Compressor {
// New version: delegates pooling/writing to the WriteToBlobStore base.
class Compressor : public WriteToBlobStore {
 public:
  explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
  explicit Compressor(hwy::ThreadPool& pool) : WriteToBlobStore(pool) {}

  // Compresses `weights` into `compressed` and queues the packed bytes for
  // writing. Skips empty tensors and null inputs. (Old inline body and the
  // new base-class delegation are both visible below — diff residue.)
  template <typename Packed>
  void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
                  const float* HWY_RESTRICT weights) {
    size_t num_weights = compressed->NumElements();
    if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
      return;
    size_t num_compressed = compressed->NumElements();
    PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
    fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
            num_weights / (1000 * 1000));
    Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_);
    const size_t num_bytes = packed.num * sizeof(Packed);
    writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes);
    WriteToBlobStore::operator()(compressed, decorated_name);
  }

  // Queues `len` scaling factors as a "scales" blob; no-op when len == 0.
  void AddScales(const float* scales, size_t len) {
    if (len) {
      MatPtrT<float> scales_ptr("scales", 0, 1);
      writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
                  len * sizeof(scales[0]));
    }
  }

  // Writes all queued blobs to `blob_filename`; logs to stderr on failure and
  // returns the BlobStore error code (non-zero indicates failure).
  BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
    const BlobError err = writer_.WriteAll(pool, blob_filename);
    if (err != 0) {
      fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
              blob_filename.path.c_str(), err);
    }
    return err;
  }

  // Returns the number of blobs added.
  size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }

 private:
  CompressWorkingSet work_;
  hwy::ThreadPool& pool_;
  BlobWriter writer_;
};

// NOLINTNEXTLINE(google-readability-namespace-comments)
Expand Down
Loading

0 comments on commit bfab7dd

Please sign in to comment.