
Commit 34e7bdf

jan-wassenberg authored and copybara-github committed
Minor cleanup, on-demand NUQ buffer allocation
threading_context: add profiler
compress-inl: add constexpr, on-demand alloc NUQ buffer
gemma_py: model->gemma
Move ScaleWeights to compress.cc
Move PromptWrapping to configs.h

PiperOrigin-RevId: 748277092
1 parent 7164a5e commit 34e7bdf

17 files changed: +188 -94 lines

.github/workflows/build.yml
Lines changed: 1 addition & 1 deletion

@@ -82,7 +82,7 @@ jobs:
 subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"])
 subprocess.run(["chmod", "700", "/kaggle/working/gemma"])
 subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"])
-output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--compressed_weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
+output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
 assert("write an email to the moon." not in output.lower());
 assert("moon" in output.lower());
 EOF

BUILD.bazel
Lines changed: 4 additions & 0 deletions

@@ -92,6 +92,8 @@ cc_library(
         ":basics",
         ":threading",
         ":topology",
+        "@highway//:hwy",
+        "@highway//:profiler",
     ],
 )
 
@@ -180,6 +182,7 @@ cc_library(
         "//compression:shared",
         "@highway//:hwy",
        "@highway//:profiler",
+        "@highway//:thread_pool",
     ],
 )
 
@@ -664,6 +667,7 @@ cc_test(
         ":mat",
         ":prompt",
         ":sampler",
+        ":threading_context",
         ":weights",
         "@googletest//:gtest_main", # buildcleaner: keep
         "@highway//:thread_pool",

compression/BUILD.bazel
Lines changed: 3 additions & 2 deletions

@@ -70,6 +70,7 @@ cc_library(
     hdrs = ["blob_store.h"],
     deps = [
         ":io",
+        "//:basics",
         "//:threading_context",
         "@highway//:hwy",
         "@highway//:thread_pool",
@@ -130,7 +131,6 @@ cc_library(
     textual_hdrs = ["sfp-inl.h"],
     deps = [
         ":shared",
-        "//:basics",
         "@highway//:hwy",
     ],
 )
@@ -195,7 +195,6 @@ cc_test(
     deps = [
         ":distortion",
         ":nuq",
-        ":sfp",
         "@googletest//:gtest_main", # buildcleaner: keep
         "//:test_util",
         "@highway//:hwy",
@@ -225,6 +224,7 @@ cc_library(
         "//:mat",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
+        "@highway//:profiler",
         "@highway//:stats",
         "@highway//:thread_pool",
     ],
@@ -259,6 +259,7 @@ cc_library(
     deps = [
         ":nuq",
         ":sfp",
+        ":shared",
         "@highway//:hwy",
         "@highway//:stats",
         "@highway//:thread_pool",

compression/compress-inl.h
Lines changed: 13 additions & 9 deletions

@@ -21,8 +21,7 @@
 #include <stdint.h>
 #include <stdio.h>
 
-#include <cmath> // lroundf, only if COMPRESS_STATS
-#include <string>
+#include <memory>
 #include <vector>
 
 #include "compression/blob_store.h"
@@ -35,6 +34,10 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
 
+#if COMPRESS_STATS
+#include <cmath> // lroundf
+#endif
+
 #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_
 
 // Include guard for (potentially) SIMD code.
@@ -388,7 +391,7 @@ struct CompressTraits<SfpStream> {
                               const size_t packed_ofs) {
     SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);
 
-    if (COMPRESS_STATS) {
+    if constexpr (COMPRESS_STATS) {
       const hn::Repartition<BF16, DF> dbf;
       auto distorted =
          hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
@@ -432,9 +435,10 @@ struct CompressTraits<NuqStream> {
                               size_t num, CompressPerThread& tls,
                               const PackedSpan<Packed>& packed,
                               const size_t packed_ofs) {
-    NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
+    if (!tls.buf) tls.buf = std::make_unique<NuqStream::ClusterBuf>();
+    NuqCodec::Enc(df, raw, num, *tls.buf, packed, packed_ofs);
 
-    if (COMPRESS_STATS) {
+    if constexpr (COMPRESS_STATS) {
       for (size_t i = 0; i < num; ++i) {
         tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
       }
@@ -478,7 +482,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
                            const size_t packed_ofs, hwy::ThreadPool& pool) {
   packed.BoundsCheck(packed_ofs, num);
   work.tls.resize(pool.NumWorkers());
-  if (COMPRESS_STATS) {
+  if constexpr (COMPRESS_STATS) {
    for (auto& tls : work.tls) {
      tls.stats.Reset();
    }
@@ -487,7 +491,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
   const bool want_bench = COMPRESS_STATS || !kIsTest;
   const double t0 = want_bench ? hwy::platform::Now() : 0.0;
 
-  using Traits = CompressTraits<Packed>;
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
   constexpr size_t kBatch = 8192;
   const size_t num_batches = hwy::DivCeil(num, kBatch);
   pool.Run(0, num_batches,
@@ -508,7 +512,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
     fprintf(stderr, "Compress %.1f MB/s\n", mbps);
   }
 
-  if (COMPRESS_STATS) {
+  if constexpr (COMPRESS_STATS) {
    for (size_t i = 1; i < work.tls.size(); ++i) {
      work.tls[0].stats.Assimilate(work.tls[i].stats);
    }
@@ -534,7 +538,7 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
                const size_t packed_ofs) {
   static_assert(hwy::IsSameEither<Packed, float, BF16>());
   packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
-  using Traits = CompressTraits<Packed>;
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
   Traits::Store2(df, raw0, raw1, packed, packed_ofs);
 }
 

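Note on the compress-inl.h hunks above: `if constexpr (COMPRESS_STATS)` guarantees the stats branch is resolved at compile time (and, inside these templates, not instantiated) when stats are disabled, and the per-thread NUQ cluster buffer is now heap-allocated only the first time a thread encodes a NUQ tensor. The following is a minimal, self-contained sketch of that on-demand allocation pattern; `ClusterBuf`, `PerThread`, `EncodeNuq` and the buffer size are illustrative stand-ins, not the real gemma.cpp types.

// Sketch: lazily allocate a large per-thread scratch buffer on first use.
#include <cstdio>
#include <memory>
#include <vector>

struct ClusterBuf {             // stand-in for NuqStream::ClusterBuf
  float centers[256 * 16];      // illustrative size only
};

struct PerThread {
  std::unique_ptr<ClusterBuf> buf;  // empty until the first NUQ encode
};

void EncodeNuq(PerThread& tls /*, inputs ... */) {
  if (!tls.buf) tls.buf = std::make_unique<ClusterBuf>();  // on-demand alloc
  // ... clustering would use *tls.buf here ...
}

int main() {
  std::vector<PerThread> tls(8);  // one slot per pool worker
  EncodeNuq(tls[0]);              // only worker 0 pays for the buffer
  std::printf("worker0 allocated: %d, worker1 allocated: %d\n",
              tls[0].buf != nullptr, tls[1].buf != nullptr);
}

Workers that only ever compress SFP or bf16 tensors keep `buf == nullptr`, which is the point of the change.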
compression/compress.cc
Lines changed: 27 additions & 1 deletion

@@ -15,8 +15,34 @@
 
 #include "compression/compress.h"
 
+#include <stddef.h>
+#include <stdint.h>
+
+#include "util/mat.h"
+#include "hwy/base.h"
+#include "hwy/profiler.h"
+
 namespace gcpp {
 
-// TODO: move ScaleWeights here.
+float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
+  PROFILER_FUNC;
+
+  float maxabs = 0.0;
+  for (size_t i = 0; i < num; ++i) {
+    maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
+  }
+  if (maxabs <= SfpStream::kMax) {
+    return 1.0f;
+  }
+  const float scale = maxabs / SfpStream::kMax;
+  const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
+  for (size_t i = 0; i < num; ++i) {
+    // Clamp because kMax may still be exceeded.
+    const float magn =
+        HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
+    raw[i] = hwy::ScalarCopySign(magn, raw[i]);
+  }
+  return scale;
+}
 
 }  // namespace gcpp

compression/compress.h
Lines changed: 10 additions & 10 deletions

@@ -17,26 +17,19 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
 
-#include "hwy/base.h"
 #define COMPRESS_STATS 0
 
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
 
-#include <cstdio>
-#include <cstring>
-#include <string>
-#include <unordered_map>
-#include <utility>
+#include <memory>
 #include <vector>
 
-// IWYU pragma: begin_exports
 #include "compression/blob_store.h"
 #include "compression/fields.h"
 #include "compression/io.h"
-#include "compression/shared.h"
-#include "gemma/tensor_index.h"
+#include "compression/shared.h" // NuqStream::ClusterBuf
 #include "util/basics.h"
 // IWYU pragma: end_exports
 #include "gemma/configs.h"
@@ -174,7 +167,8 @@ struct CompressStats {
 #endif // COMPRESS_STATS
 
 struct CompressPerThread {
-  NuqStream::ClusterBuf buf;
+  // Allocated the first time NUQ is used.
+  std::unique_ptr<NuqStream::ClusterBuf> buf;
   CompressStats stats;
 };
 
@@ -375,5 +369,11 @@ class ReadFromBlobStore {
   std::vector<std::string> file_keys_;
 };
 
+// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
+// them such that the largest magnitude is `SfpStream::kMax`, and returns the
+// multiplier with which to restore the original values. This is only necessary
+// before compressing to `SfpStream` and `NuqStream`.
+float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
+
 }  // namespace gcpp
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_

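The new declaration above documents the ScaleWeights contract: scale the values in place so the largest magnitude fits, and return the multiplier that restores the originals after decompression. Below is a simplified scalar re-implementation for illustration only; the kMax value of 1.875 and the names ScaleWeightsDemo and w are assumptions, not taken from the diff, and the real function additionally computes the inverse scale in double and clamps via HWY_MIN/ScalarCopySign.

// Standalone demo of the scale-then-restore contract.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

constexpr float kMax = 1.875f;  // assumed SFP max magnitude (illustrative)

float ScaleWeightsDemo(float* raw, std::size_t num) {
  float maxabs = 0.0f;
  for (std::size_t i = 0; i < num; ++i) {
    maxabs = std::max(maxabs, std::fabs(raw[i]));
  }
  if (maxabs <= kMax) return 1.0f;    // already in range: no-op
  const float scale = maxabs / kMax;  // multiplier that restores originals
  for (std::size_t i = 0; i < num; ++i) raw[i] /= scale;
  return scale;
}

int main() {
  std::vector<float> w = {0.5f, -7.5f, 3.0f};
  const float scale = ScaleWeightsDemo(w.data(), w.size());  // scale == 4
  // w is now {0.125, -1.875, 0.75}; compress these, then multiply the
  // decompressed values by `scale` to recover the original magnitudes.
  std::printf("scale=%g, w[1]=%g, restored=%g\n", scale, w[1], w[1] * scale);
}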
compression/shared.h
Lines changed: 1 addition & 40 deletions

@@ -13,8 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-// Definitions shared between the public compress-inl.h interface and the
-// sfp-inl.h and nuq-inl.h implementation details.
+// Types shared between tensor definitions and `compress-inl.h`.
 
 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
@@ -63,30 +62,6 @@ struct SfpStream {
 };
 #pragma pack(pop)
 
-// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
-// such that the largest magnitude is SfpStream::kMax, and returns the
-// multiplier with which to restore the original values. This is only necessary
-// before compressing to SfpStream.
-// TODO: vectorize
-static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
-  float maxabs = 0.0;
-  for (size_t i = 0; i < num; ++i) {
-    maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
-  }
-  if (maxabs <= SfpStream::kMax) {
-    return 1.0f;
-  }
-  const float scale = maxabs / SfpStream::kMax;
-  const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
-  for (size_t i = 0; i < num; ++i) {
-    // Clamp because kMax may still be exceeded.
-    const float magn =
-        HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
-    raw[i] = hwy::ScalarCopySign(magn, raw[i]);
-  }
-  return scale;
-}
-
 // Non-uniform quantization: a compressed representation of f32 inputs that
 // supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
 // two vectors (for `Decompress2`), and decoding to bf16/f32.
@@ -185,20 +160,6 @@ constexpr bool IsNuqStream() {
   return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
 }
 
-// Instruction-tuned models require extra 'turn structure' tokens in prompts.
-enum class PromptWrapping {
-  GEMMA_IT,
-  GEMMA_PT,
-  GEMMA_VLM,
-  PALIGEMMA,
-  kSentinel  // must be last
-};
-
-inline bool EnumValid(PromptWrapping type) {
-  return static_cast<int>(type) >= 0 &&
-         static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
-}
-
 // Tensor types for loading weights. Note that not all types are supported as
 // weights for a model, but can be used for other purposes, such as types for
 // `WeightsPtrs`. When adding a new type that is supported, also

gemma/configs.h
Lines changed: 14 additions & 0 deletions

@@ -49,6 +49,20 @@ static constexpr size_t kMaxConv1DWidth = 4;
 
 using EmbedderInputT = BF16;
 
+// Instruction-tuned models require extra 'turn structure' tokens in prompts.
+enum class PromptWrapping {
+  GEMMA_IT,
+  GEMMA_PT,
+  GEMMA_VLM,
+  PALIGEMMA,
+  kSentinel  // must be last
+};
+
+static inline bool EnumValid(PromptWrapping wrapping) {
+  return static_cast<size_t>(wrapping) <
+         static_cast<size_t>(PromptWrapping::kSentinel);
+}
+
 enum class LayerAttentionType {
   kGemma,
   kGriffinRecurrentBlock,

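EnumValid(PromptWrapping) is typically useful when a wrapping value arrives as a raw integer (for example from a serialized config) and must be checked before use. A small sketch follows; the enum and EnumValid mirror the lines added above, while reading the value from argv is purely hypothetical.

// Sketch: validate an untrusted integer before treating it as PromptWrapping.
#include <cstddef>
#include <cstdio>
#include <cstdlib>

enum class PromptWrapping { GEMMA_IT, GEMMA_PT, GEMMA_VLM, PALIGEMMA, kSentinel };

static inline bool EnumValid(PromptWrapping wrapping) {
  return static_cast<std::size_t>(wrapping) <
         static_cast<std::size_t>(PromptWrapping::kSentinel);
}

int main(int argc, char** argv) {
  const int raw = (argc > 1) ? std::atoi(argv[1]) : 0;  // hypothetical source
  if (raw < 0 || !EnumValid(static_cast<PromptWrapping>(raw))) {
    std::fprintf(stderr, "invalid PromptWrapping %d\n", raw);
    return 1;
  }
  std::printf("PromptWrapping %d is valid\n", raw);
}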
ops/dot-inl.h
Lines changed: 0 additions & 3 deletions

@@ -15,9 +15,6 @@
 
 #include <stddef.h>
 
-#include "compression/compress.h"
-#include "util/mat.h"
-#include "hwy/base.h"
 #include "hwy/profiler.h"
 
 // Include guard for (potentially) SIMD code.

paligemma/BUILD.bazel
Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ cc_test(
     ],
     deps = [
         "@googletest//:gtest_main", # buildcleaner: keep
+        "//:allocator",
         "//:benchmark_helper",
         "//:common",
         "//:gemma_lib",

paligemma/paligemma_test.cc
Lines changed: 11 additions & 8 deletions

@@ -20,7 +20,9 @@
 #include "compression/shared.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/gemma.h"
+#include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"
 
@@ -50,17 +52,18 @@ class PaliGemmaTest : public ::testing::Test {
 
 void PaliGemmaTest::InitVit(const std::string& path) {
   ASSERT_NE(s_env->GetGemma(), nullptr);
-  Gemma& model = *(s_env->GetGemma());
-  image_tokens_ =
-      ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
-                            model.GetModelConfig().model_dim));
+  const Allocator2& allocator = s_env->Env().ctx.allocator;
+  Gemma& gemma = *(s_env->GetGemma());
+  image_tokens_ = ImageTokens(
+      allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,
+                           gemma.GetModelConfig().model_dim));
   Image image;
-  HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
+  HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
-  const size_t image_size = model.GetModelConfig().vit_config.image_size;
+  const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
   image.Resize(image_size, image_size);
   RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
-  model.GenerateImageTokens(runtime_config, image, image_tokens_);
+  gemma.GenerateImageTokens(runtime_config, image, image_tokens_);
 }
 
 std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
@@ -124,7 +127,7 @@ TEST_F(PaliGemmaTest, General) {
   };
   const char* (*qa)[2];
   size_t num;
-  switch (s_env->GetGemma()->Info().model) {
+  switch (s_env->GetGemma()->GetModelConfig().model) {
    case Model::PALIGEMMA_224:
      qa = kQA_3B_mix_224;
      num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);

python/BUILD.bazel
Lines changed: 1 addition & 1 deletion

@@ -21,10 +21,10 @@ pybind_extension(
     name = "gemma",
     srcs = ["gemma_py.cc"],
     deps = [
-        "//:allocator",
         "//:benchmark_helper",
         "//:gemma_args",
         "//:gemma_lib",
+        "//:threading_context",
         "//compression:shared",
         "@highway//:hwy",
     ],
