Skip to content

Commit 9d40f01

Browse files
theraysmith authored and copybara-github committed
Added ability to load/save a complete model file, including tokenizer.
PiperOrigin-RevId: 707914366
1 parent 5bc356f commit 9d40f01

32 files changed

+770
-257
lines changed

BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ cc_library(
245245
"gemma/tensor_index.h",
246246
],
247247
deps = [
248+
"//compression:fields",
248249
"//compression:sfp",
249250
"@highway//:hwy", # base.h
250251
"@highway//:thread_pool",
@@ -257,6 +258,7 @@ cc_test(
257258
deps = [
258259
":common",
259260
"@googletest//:gtest_main",
261+
"@highway//:hwy",
260262
],
261263
)
262264

@@ -388,6 +390,7 @@ cc_library(
388390
":ops",
389391
":threading",
390392
"//compression:io",
393+
"//compression:sfp",
391394
"@highway//:hwy",
392395
],
393396
)

backprop/backward_scalar_test.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,13 +390,12 @@ static ModelConfig TestConfig() {
390390
config.model_dim = 32;
391391
config.vocab_size = 12;
392392
config.seq_len = 18;
393-
LayerConfig layer_config = {
394-
.model_dim = config.model_dim,
395-
.ff_hidden_dim = 48,
396-
.heads = 3,
397-
.kv_heads = 1,
398-
.qkv_dim = 12,
399-
};
393+
LayerConfig layer_config;
394+
layer_config.model_dim = config.model_dim;
395+
layer_config.ff_hidden_dim = 48;
396+
layer_config.heads = 3;
397+
layer_config.kv_heads = 1;
398+
layer_config.qkv_dim = 12;
400399
config.layer_configs = {2, layer_config};
401400
config.num_tensor_scales = 4 * config.layer_configs.size();
402401
config.query_scale = QueryScaleType::SqrtKeySize;

backprop/backward_test.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,12 @@ static ModelConfig TestConfig() {
191191
config.model_dim = 32;
192192
config.vocab_size = 16;
193193
config.seq_len = 24;
194-
LayerConfig layer_config = {
195-
.model_dim = config.model_dim,
196-
.ff_hidden_dim = 64,
197-
.heads = 3,
198-
.kv_heads = 1,
199-
.qkv_dim = 16,
200-
};
194+
LayerConfig layer_config;
195+
layer_config.model_dim = config.model_dim;
196+
layer_config.ff_hidden_dim = 64;
197+
layer_config.heads = 3;
198+
layer_config.kv_heads = 1;
199+
layer_config.qkv_dim = 16;
201200
config.layer_configs = {2, layer_config};
202201
config.num_tensor_scales = 4 * config.layer_configs.size();
203202
config.query_scale = QueryScaleType::SqrtKeySize;

compression/BUILD.bazel

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ cc_test(
5858
deps = [
5959
":fields",
6060
"@googletest//:gtest_main", # buildcleaner: keep
61-
"@highway//:hwy",
6261
"@highway//:hwy_test_util",
6362
],
6463
)
@@ -202,6 +201,7 @@ cc_library(
202201
deps = [
203202
":blob_store",
204203
":distortion",
204+
":fields",
205205
":io",
206206
":nuq",
207207
":sfp",
@@ -210,7 +210,6 @@ cc_library(
210210
"//:common",
211211
"@highway//:hwy",
212212
"@highway//:nanobenchmark",
213-
"@highway//:profiler",
214213
"@highway//:stats",
215214
"@highway//:thread_pool",
216215
],
@@ -261,6 +260,7 @@ cc_binary(
261260
"//:allocator",
262261
"//:args",
263262
"//:common",
263+
"//:tokenizer",
264264
"//:weights",
265265
"@highway//:hwy",
266266
"@highway//:thread_pool",
@@ -277,3 +277,14 @@ cc_binary(
277277
"@highway//:hwy_test_util",
278278
],
279279
)
280+
281+
cc_binary(
282+
name = "migrate_weights",
283+
srcs = ["migrate_weights.cc"],
284+
deps = [
285+
"//:app",
286+
"//:args",
287+
"//:benchmark_helper",
288+
"//:gemma_lib",
289+
],
290+
)

compression/compress-inl.h

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@
2222
#include <stdio.h>
2323

2424
#include <cmath> // lroundf, only if COMPRESS_STATS
25+
#include <string>
26+
#include <vector>
2527

2628
#include "compression/blob_store.h"
2729
#include "compression/compress.h" // IWYU pragma: export
2830
#include "compression/distortion.h"
31+
#include "gemma/configs.h"
2932
#include "hwy/aligned_allocator.h"
3033
#include "hwy/base.h"
3134
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -673,45 +676,45 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
673676
// their scaling factors to BlobStore.
674677
class Compressor {
675678
public:
676-
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
679+
explicit Compressor(hwy::ThreadPool& pool) : writer_(pool) {}
677680

678681
template <typename Packed>
679682
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
680683
const float* HWY_RESTRICT weights) {
681684
size_t num_weights = compressed->NumElements();
685+
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
686+
return;
682687
size_t num_compressed = compressed->NumElements();
683688
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
684689
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
685690
num_weights / (1000 * 1000));
686-
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_);
687-
const size_t num_bytes = packed.num * sizeof(Packed);
688-
writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes);
691+
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
692+
writer_.pool());
693+
writer_(compressed, decorated_name);
694+
}
695+
696+
void AddTokenizer(const std::string& tokenizer) {
697+
writer_.AddTokenizer(tokenizer);
689698
}
690699

691700
void AddScales(const float* scales, size_t len) {
692-
if (len) {
693-
MatPtrT<float> scales_ptr("scales", 0, 1);
694-
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
695-
len * sizeof(scales[0]));
696-
}
701+
writer_.AddScales(scales, len);
697702
}
698703

699-
BlobError WriteAll(hwy::ThreadPool& pool, const Path& blob_filename) {
700-
const BlobError err = writer_.WriteAll(pool, blob_filename);
701-
if (err != 0) {
702-
fprintf(stderr, "Failed to write blobs to %s (error %d)\n",
703-
blob_filename.path.c_str(), err);
704-
}
705-
return err;
704+
// Writes all blobs to disk in the given order. The config is optional and
705+
// if given, it is written to the file, along with the TOC, making it
706+
// single-file format. Otherwise, the file is written in the multi-file format
707+
// without a TOC.
708+
BlobError WriteAll(const Path& blob_filename, const ModelConfig* config) {
709+
return writer_.WriteAll(blob_filename, config);
706710
}
707711

708712
// Returns the number of blobs added.
709713
size_t DebugNumBlobsAdded() const { return writer_.DebugNumBlobsAdded(); }
710714

711715
private:
712716
CompressWorkingSet work_;
713-
hwy::ThreadPool& pool_;
714-
BlobWriter writer_;
717+
WriteToBlobStore writer_;
715718
};
716719

717720
// NOLINTNEXTLINE(google-readability-namespace-comments)

0 commit comments

Comments (0)