Skip to content

Commit

Permalink
Matmul rewrite: fp64 sums, hierarchical parallelization, cache-blocking, autotuning
Browse files Browse the repository at this point in the history

Remove empty matmul_unit_test.
Up to 25 TFLOP/s on 2xZen4 for 512,3072,24576.

PiperOrigin-RevId: 684398694
  • Loading branch information
jan-wassenberg authored and copybara-github committed Feb 12, 2025
1 parent f173aa7 commit c803fb9
Show file tree
Hide file tree
Showing 16 changed files with 2,510 additions and 568 deletions.
28 changes: 8 additions & 20 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ test_suite(

cc_library(
name = "ops",
srcs = [
"ops/matmul.cc",
],
hdrs = [
"ops/matmul.h",
"ops/ops.h",
Expand All @@ -103,12 +106,14 @@ cc_library(
":threading",
"//compression:compress",
"@highway//:algo",
"@highway//:bit_set",
"@highway//:hwy",
"@highway//:math",
"@highway//:matvec",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
"@highway//:topology",
],
)

Expand All @@ -126,6 +131,7 @@ cc_test(
":test_util",
":threading",
"@googletest//:gtest_main", # buildcleaner: keep
"//:app",
"//compression:compress",
"//compression:test_util",
"@highway//:hwy",
Expand All @@ -151,6 +157,7 @@ cc_test(
":ops",
":test_util",
"@googletest//:gtest_main", # buildcleaner: keep
"//:app",
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
Expand All @@ -176,26 +183,6 @@ cc_test(
],
)

cc_test(
name = "matmul_unit_test",
size = "small",
timeout = "long",
srcs = ["ops/matmul_unit_test.cc"],
local_defines = ["HWY_IS_TEST"],
# for test_suite.
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
":ops",
":test_util",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)

cc_test(
name = "matmul_test",
size = "small",
Expand Down Expand Up @@ -652,6 +639,7 @@ cc_test(
":sampler",
":weights",
"@googletest//:gtest_main",
"//:threading",
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ set(SOURCES
gemma/weights.h
ops/dot-inl.h
ops/matmul-inl.h
ops/matmul.cc
ops/matmul.h
ops/matvec-inl.h
ops/ops-inl.h
ops/ops.h
Expand Down Expand Up @@ -168,7 +170,6 @@ set(GEMMA_TEST_FILES
ops/dot_test.cc
ops/gemma_matvec_test.cc
ops/matmul_test.cc
ops/matmul_unit_test.cc
ops/ops_test.cc
paligemma/image_test.cc
paligemma/paligemma_test.cc
Expand Down
27 changes: 18 additions & 9 deletions backprop/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "backprop/test_util.h"
#include "gemma/configs.h"
#include "ops/ops.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

Expand All @@ -58,7 +59,9 @@ void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
hwy::ThreadPool pool(8);
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols);
MatStorageT<float> x("x", kTokens, kCols);
Expand All @@ -85,7 +88,7 @@ void TestMatMulVJP() {

grad.ZeroInit();
MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
grad.data(), dx.data(), pool);
grad.data(), dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);

Expand All @@ -102,7 +105,9 @@ void TestMultiHeadMatMulVJP() {
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
hwy::ThreadPool pool(8);
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols * kHeads);
MatStorageT<float> x("x", kTokens, kCols * kHeads);
Expand Down Expand Up @@ -130,7 +135,7 @@ void TestMultiHeadMatMulVJP() {

grad.ZeroInit();
MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
kRows, kTokens, grad.data(), dx.data(), pool);
kRows, kTokens, grad.data(), dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);

Expand All @@ -145,7 +150,9 @@ void TestMultiHeadMatMulVJP() {
void TestRMSNormVJP() {
static const size_t K = 2;
static const size_t N = 64;
hwy::ThreadPool pool(8);
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 8));
Allocator::Init(pools.Topology());
std::mt19937 gen(42);
MatStorageT<float> weights("weights", N, 1);
MatStorageT<float> x("x", K, N);
Expand All @@ -172,7 +179,7 @@ void TestRMSNormVJP() {

grad.ZeroInit();
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
dx.data(), pool);
dx.data(), pools.Pool());
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);

Expand Down Expand Up @@ -209,7 +216,9 @@ static ModelConfig TestConfig() {

void TestEndToEnd() {
std::mt19937 gen(42);
hwy::ThreadPool pool(0);
gcpp::NestedPools pools(1, /*pin=*/Tristate::kFalse, BoundedSlice(0, 1),
BoundedSlice(0, 1));
Allocator::Init(pools.Topology());
ModelConfig config = TestConfig();
WeightsWrapper<float> weights(config);
WeightsWrapper<float> grad(config);
Expand All @@ -234,13 +243,13 @@ void TestEndToEnd() {

float loss1 = CrossEntropyLossForwardPass(
prompt.tokens, prompt.context_size, weights.get(), forward1,
inv_timescale, pool);
inv_timescale, pools.Pool());

EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);

grad.ZeroInit();
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
backward, inv_timescale, pool);
backward, inv_timescale, pools.Pool());

Complexify(weights.get(), c_weights.get());
auto func = [&]() {
Expand Down
4 changes: 4 additions & 0 deletions gemma/gemma.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ class Gemma {
void GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens);

// Forwards the caller's verbosity level to the MatMul environment.
// A level of 2 or higher switches on `env_.print_best`; lower levels leave
// the flag as it was (this method never clears it).
void SetMatMulVerbosity(int verbosity) {
  const bool wants_best = verbosity >= 2;
  if (wants_best) {
    env_.print_best = true;
  }
}

private:
MatMulEnv env_;

Expand Down
5 changes: 2 additions & 3 deletions gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
}
if (end_of_turn_seen && abs_pos > 0) {
// If we have seen an end_of_turn token, we need to rewind abs_pos by one
// more, because we will pre-pend it again to the prompt in
// more, because we will prepend it again to the prompt in
// WrapAndTokenize.
abs_pos--;
}
Expand All @@ -236,14 +236,13 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");

// TODO: remove once MatMul is updated.
app.max_packages = 1;
// Note that num_threads is an upper bound; we also limit to the number of
// detected and enabled cores.
NestedPools pools = CreatePools(app);
Allocator::Init(pools.Topology());

Gemma model = CreateGemma(loader, pools);
model.SetMatMulVerbosity(app.verbosity);
KVCache kv_cache =
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);

Expand Down
91 changes: 55 additions & 36 deletions ops/bench_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,33 +117,37 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
}

// Prints the observed throughput for one MatMul shape to stderr.
//
// NOTE(review): this listing is a rendered diff, so the previous and the
// updated form of several lines (signature tail, ratio variable, fprintf)
// appear back to back below.
//
// `A_extents` / `B_extents`: extents of the A and B matrices; only
//   `A_extents.rows` and `B_extents.Area()` are read here.
// `times`: per-call durations; sorted in place (ascending) by this function.
//   It is indexed at [0] and [size / 2] without a check, so it presumably
//   must be non-empty - confirm against the caller.
// `per_key`: accepted by the updated signature but not read in the body
//   shown here.
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
std::vector<double>& times) {
std::vector<double>& times, MMPerKey* per_key) {
std::sort(times.begin(), times.end());
// bench_dnn reports the best and average, but the median seems more
// consistent and resistant to outliers.
const double elapsed = times[times.size() / 2];
// After sorting, times[0] is the smallest sample; the 1E-6 term keeps the
// divisor non-zero.
const double ratio = elapsed / (times[0] + 1E-6); // vs best, avoid / 0
const double vs_best = elapsed / (times[0] + 1E-6); // avoid / 0

const size_t num_b = B_extents.Area();
// FMA counts as two FLOP.
fprintf(stderr, "%.1f\t(med %.3f ms = %0.2fx min)\n",
2 * 1E-9 * A_extents.rows * num_b / elapsed, elapsed * 1E3, ratio);
const double flops = 2 * A_extents.rows * num_b / elapsed; // FMA = 2 ops

// Columns: rate scaled by 1E-9, median scaled by 1E3, and the ratio of the
// median to the smallest sample.
fprintf(stderr, "\t%.1f GFLOPS %.3f ms %0.2fx\n", flops * 1E-9, elapsed * 1E3,
vs_best);
}

// Generates inputs and prints observed throughput of MatMul.
// M = A rows, K = A cols, N = C cols.
template <typename MatTA, typename MatTB = MatTA>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
fprintf(stderr, "\nBenchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
M, K, N, add, TypeName<MatTA>(), TypeName<MatTB>());
if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n");
}
fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N,
add, TypeName<MatTA>(), TypeName<MatTB>());

const Extents2D A_extents(M, K);
const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N);

RowVectorBatch<float> c_slow_batch(C_extents);
RowVectorBatch<float> c_batch(C_extents);
RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);

std::unique_ptr<MatStorageT<float>> add_storage;
if (add) {
Expand All @@ -161,57 +165,72 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtrF C = RowPtrFromBatch(c_batch);

constexpr size_t kSamples = 20;
// Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12;
std::vector<double> times;
times.reserve(kSamples);
times.reserve(num_samples);

// Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op.
BindB(B_extents.rows, B, env.parallel);
BindC(A_extents.rows, C, env.parallel);

Tristate use_spinning = Tristate::kDefault;
env.parallel.Pools().MaybeStartSpinning(use_spinning);

// env.print_config = true;
// env.print_measurement = true;
env.print_best = true;

double keep = 0.0;
MMPerKey* per_key;
// Until enough samples collected *after* autotuning finished:
while (times.size() < kSamples) {
while (times.size() < num_samples) {
const double t0 = hwy::platform::Now();
MatMul(A, B, add_row, env, C);
per_key = MatMul(A, B, add_row, env, C);
const double t1 = hwy::platform::Now();
double elapsed = t1 - t0;
keep += C.Row(0)[hwy::Unpredictable1()];

times.push_back(elapsed);
// Only record times after autotuning finished.
if (per_key->autotune.Best()) times.push_back(elapsed);
}
hwy::PreventElision(keep);
env.parallel.Pools().MaybeStopSpinning(use_spinning);
PrintSpeed(A_extents, B_extents, times);
PrintSpeed(A_extents, B_extents, times, per_key);
}

using F32 = float;
using SFP = SfpStream;

void BenchAllMatMul() {
if (first_target == 0) first_target = HWY_TARGET;
if (HWY_TARGET != first_target) return;

for (size_t max_packages : {/*1,*/ 2}) {
const size_t max_threads = 0; // no limit
NestedPools pools(max_threads, Tristate::kDefault,
BoundedSlice(0, max_packages));
#if GEMMA_DISABLE_TOPOLOGY
if (max_packages == 2) break; // we only have one package
#else
// If less than the limit, we have already tested all num_packages.
if (pools.Topology().FullTopology().packages.size() < max_packages) break;
#endif
fprintf(stderr, "BenchAllMatMul %zu: %s %s\n", max_packages,
pools.TopologyString(), pools.PinString());
// Disable the best-target-only limitation.
// if (HWY_TARGET != first_target) return;

Allocator::Init(pools.Topology());
MatMulEnv env(pools);
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||
HWY_TARGET == HWY_SSE2) {
return;
}

for (size_t batch_size : {1, 4, 128, 512}) {
constexpr bool kAdd = false;
BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
}
const size_t max_threads = 0; // no limit
const BoundedSlice package_slice(0, 1); // all packages/sockets
const BoundedSlice cluster_slice(0, 1); // all clusters/CCX
const BoundedSlice lp_slice(0, 1); // default to all cores (per package).
NestedPools pools(max_threads, Tristate::kDefault, package_slice,
cluster_slice, lp_slice);
fprintf(stderr, "BenchAllMatMul %s %s\n", pools.TopologyString(),
pools.PinString());

Allocator::Init(pools.Topology(), /*enable_bind=*/true);
MatMulEnv env(pools);

for (size_t batch_size : {1, 4, 128, 512}) {
constexpr bool kAdd = false;
BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
}

PROFILER_PRINT_RESULTS();
Expand Down
3 changes: 3 additions & 0 deletions ops/dot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "compression/shared.h"
#include "util/allocator.h"
#include "util/app.h"
#include "util/test_util.h"
#include "util/threading.h"
#include "hwy/base.h"
Expand Down Expand Up @@ -999,6 +1000,8 @@ struct TestShortDotsT {
const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot

NestedPools pools = CreatePools(AppArgs());
Allocator::Init(pools.Topology());
CompressWorkingSet work;
std::mt19937 rng;
rng.seed(12345);
Expand Down
Loading

0 comments on commit c803fb9

Please sign in to comment.